From b5fcfc0cfeffeb3fe9f05c32dc678ab09b2cee31 Mon Sep 17 00:00:00 2001 From: Suraj Sudhir Date: Tue, 16 Apr 2024 16:14:36 -0700 Subject: [reference_model] Remove output_shape from transpose_conv2d Signed-off-by: Suraj Sudhir Change-Id: Ib2b95e73b226d64c4db5ad1ed22c123e04d7e6f9 --- reference_model/src/ops/tensor_ops.cc | 15 --------------- thirdparty/serialization_lib | 2 +- verif/generator/tosa_error_if.py | 6 +----- verif/generator/tosa_test_gen.py | 28 ++++++++++++++++++---------- 4 files changed, 20 insertions(+), 31 deletions(-) diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc index edc1793..f38f486 100644 --- a/reference_model/src/ops/tensor_ops.cc +++ b/reference_model/src/ops/tensor_ops.cc @@ -1958,12 +1958,6 @@ int OpTransposeConv2d::checkTensorAttr return 1; } - if (attribute->output_shape().size() != 4) - { - printNodeValidationError("OpTransposeConv2d: illegal size for attribute output_shape"); - return 1; - } - for (int32_t i : attribute->stride()) { if (i < 1) @@ -1973,15 +1967,6 @@ int OpTransposeConv2d::checkTensorAttr } } - for (int d = 0; d < 4; d++) - { - if (attribute->output_shape()[d] != this->output->getShape()[d]) - { - printNodeValidationError("OpTransposeConv2d: illegal size for attribute output_shape"); - return 1; - } - } - int32_t IH = input->getShape()[1]; int32_t IW = input->getShape()[2]; int32_t OH = output->getShape()[1]; diff --git a/thirdparty/serialization_lib b/thirdparty/serialization_lib index 57d7818..50256e1 160000 --- a/thirdparty/serialization_lib +++ b/thirdparty/serialization_lib @@ -1 +1 @@ -Subproject commit 57d781883142db8a45fe98ac1a1dfacc49cba78a +Subproject commit 50256e168c3e759f03445bb872d0a43da1a6ba30 diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py index 53a3199..9fd13d2 100644 --- a/verif/generator/tosa_error_if.py +++ b/verif/generator/tosa_error_if.py @@ -2788,7 +2788,6 @@ class TosaInvalidValidator: if opName.startswith("transpose_conv2d"): # transpose_conv2d - output_shape = args_dict["out_shape"] filter_shape = inputShapes[1] kernel_shape = filter_shape[1:-1] @@ -2810,10 +2809,7 @@ class TosaInvalidValidator: padding[2], padding[3], ) - if output_shape[1] == h and output_shape[2] == w: - return False - # output shape does not match the expected shape - return True + return h < 1 or w < 1 if "conv2d" in opName or "conv3d" in opName: # conv2d, conv3d, depthwise_conv2d diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index b85dd03..88dd17a 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -264,9 +264,9 @@ class TosaTestGen: mode = gtu.ComplianceMode.DOT_PRODUCT compliance_tens["dot_product_info"] = { "s": argsDict["s"], - "ks": int(argsDict["ksb"]) - if "ksb" in argsDict - else int(argsDict["ks"]), + "ks": ( + int(argsDict["ksb"]) if "ksb" in argsDict else int(argsDict["ks"]) + ), } elif argsDict["dg_type"] == gtu.DataGenType.FP_SPECIAL: mode = gtu.ComplianceMode.FP_SPECIAL @@ -1028,11 +1028,10 @@ class TosaTestGen: accum_dtype = args_dict["acc_type"] strides = args_dict["stride"] out_pad = args_dict["pad"] - output_shape = args_dict["out_shape"] assert len(out_pad) == 4 result_tensor = OutputShaper.transposeConv2DOp( - self.ser, rng, ifm, output_shape, accum_dtype, error_name + self.ser, rng, ifm, filter, accum_dtype, strides, out_pad, error_name ) # Ensure new output type has correct qinfo @@ -1081,7 +1080,7 @@ class TosaTestGen: attr = ts.TosaSerializerAttribute() attr.TransposeConvAttribute( - out_pad, strides, output_shape, qinfo[0], qinfo[1], local_bound, accum_dtype + out_pad, strides, qinfo[0], qinfo[1], local_bound, accum_dtype ) self.ser.addOperator(op["op"], input_list, output_list, attr) @@ -5942,14 +5941,23 @@ class OutputShaper: return ser.addOutput(val.shape, out_dtype) @staticmethod - def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None): + def transposeConv2DOp( + ser, rng, ifm, filter, accum_dtype, strides, padding, error_name=None + ): + + h = (ifm.shape[1] - 1) * strides[0] + padding[0] + padding[1] + filter.shape[1] + + w = (ifm.shape[2] - 1) * strides[1] + padding[2] + padding[3] + filter.shape[2] + if error_name == ErrorIf.ConvOutputShapeMismatch: choices = [1, 2, 3] change = rng.choice(choices) if change in [1, 3]: - output_shape[1] = output_shape[1] + rng.choice(choices) + h = h + rng.choice(choices) if change in [2, 3]: - output_shape[2] = output_shape[2] + rng.choice(choices) + w = w + rng.choice(choices) + + ofm_shape = [ifm.shape[0], h, w, filter.shape[0]] if error_name == ErrorIf.WrongInputType: # Pick some potentially correct output dtype if input type is incorrect @@ -5965,7 +5973,7 @@ class OutputShaper: wrong_dtypes = list(gtu.usableDTypes(excludes=excludes)) out_dtype = rng.choice(wrong_dtypes) - return ser.addOutput(output_shape, out_dtype) + return ser.addOutput(ofm_shape, out_dtype) @staticmethod def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None): -- cgit v1.2.1