aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--reference_model/src/ops/tensor_ops.cc15
m---------thirdparty/serialization_lib0
-rw-r--r--verif/generator/tosa_error_if.py6
-rw-r--r--verif/generator/tosa_test_gen.py28
4 files changed, 19 insertions, 30 deletions
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index edc1793..f38f486 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -1958,12 +1958,6 @@ int OpTransposeConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::checkTensorAttr
return 1;
}
- if (attribute->output_shape().size() != 4)
- {
- printNodeValidationError("OpTransposeConv2d: illegal size for attribute output_shape");
- return 1;
- }
-
for (int32_t i : attribute->stride())
{
if (i < 1)
@@ -1973,15 +1967,6 @@ int OpTransposeConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::checkTensorAttr
}
}
- for (int d = 0; d < 4; d++)
- {
- if (attribute->output_shape()[d] != this->output->getShape()[d])
- {
- printNodeValidationError("OpTransposeConv2d: illegal size for attribute output_shape");
- return 1;
- }
- }
-
int32_t IH = input->getShape()[1];
int32_t IW = input->getShape()[2];
int32_t OH = output->getShape()[1];
diff --git a/thirdparty/serialization_lib b/thirdparty/serialization_lib
-Subproject 57d781883142db8a45fe98ac1a1dfacc49cba78
+Subproject 50256e168c3e759f03445bb872d0a43da1a6ba3
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index 53a3199..9fd13d2 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -2788,7 +2788,6 @@ class TosaInvalidValidator:
if opName.startswith("transpose_conv2d"):
# transpose_conv2d
- output_shape = args_dict["out_shape"]
filter_shape = inputShapes[1]
kernel_shape = filter_shape[1:-1]
@@ -2810,10 +2809,7 @@ class TosaInvalidValidator:
padding[2],
padding[3],
)
- if output_shape[1] == h and output_shape[2] == w:
- return False
- # output shape does not match the expected shape
- return True
+ return h < 1 or w < 1
if "conv2d" in opName or "conv3d" in opName:
# conv2d, conv3d, depthwise_conv2d
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index b85dd03..88dd17a 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -264,9 +264,9 @@ class TosaTestGen:
mode = gtu.ComplianceMode.DOT_PRODUCT
compliance_tens["dot_product_info"] = {
"s": argsDict["s"],
- "ks": int(argsDict["ksb"])
- if "ksb" in argsDict
- else int(argsDict["ks"]),
+ "ks": (
+ int(argsDict["ksb"]) if "ksb" in argsDict else int(argsDict["ks"])
+ ),
}
elif argsDict["dg_type"] == gtu.DataGenType.FP_SPECIAL:
mode = gtu.ComplianceMode.FP_SPECIAL
@@ -1028,11 +1028,10 @@ class TosaTestGen:
accum_dtype = args_dict["acc_type"]
strides = args_dict["stride"]
out_pad = args_dict["pad"]
- output_shape = args_dict["out_shape"]
assert len(out_pad) == 4
result_tensor = OutputShaper.transposeConv2DOp(
- self.ser, rng, ifm, output_shape, accum_dtype, error_name
+ self.ser, rng, ifm, filter, accum_dtype, strides, out_pad, error_name
)
# Ensure new output type has correct qinfo
@@ -1081,7 +1080,7 @@ class TosaTestGen:
attr = ts.TosaSerializerAttribute()
attr.TransposeConvAttribute(
- out_pad, strides, output_shape, qinfo[0], qinfo[1], local_bound, accum_dtype
+ out_pad, strides, qinfo[0], qinfo[1], local_bound, accum_dtype
)
self.ser.addOperator(op["op"], input_list, output_list, attr)
@@ -5942,14 +5941,23 @@ class OutputShaper:
return ser.addOutput(val.shape, out_dtype)
@staticmethod
- def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
+ def transposeConv2DOp(
+ ser, rng, ifm, filter, accum_dtype, strides, padding, error_name=None
+ ):
+
+ h = (ifm.shape[1] - 1) * strides[0] + padding[0] + padding[1] + filter.shape[1]
+
+ w = (ifm.shape[2] - 1) * strides[1] + padding[2] + padding[3] + filter.shape[2]
+
if error_name == ErrorIf.ConvOutputShapeMismatch:
choices = [1, 2, 3]
change = rng.choice(choices)
if change in [1, 3]:
- output_shape[1] = output_shape[1] + rng.choice(choices)
+ h = h + rng.choice(choices)
if change in [2, 3]:
- output_shape[2] = output_shape[2] + rng.choice(choices)
+ w = w + rng.choice(choices)
+
+ ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]
if error_name == ErrorIf.WrongInputType:
# Pick some potentially correct output dtype if input type is incorrect
@@ -5965,7 +5973,7 @@ class OutputShaper:
wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
out_dtype = rng.choice(wrong_dtypes)
- return ser.addOutput(output_shape, out_dtype)
+ return ser.addOutput(ofm_shape, out_dtype)
@staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):