From b5fcfc0cfeffeb3fe9f05c32dc678ab09b2cee31 Mon Sep 17 00:00:00 2001 From: Suraj Sudhir Date: Tue, 16 Apr 2024 16:14:36 -0700 Subject: [reference_model] Remove output_shape from transpose_conv2d Signed-off-by: Suraj Sudhir Change-Id: Ib2b95e73b226d64c4db5ad1ed22c123e04d7e6f9 --- verif/generator/tosa_test_gen.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) (limited to 'verif/generator/tosa_test_gen.py') diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index b85dd03..88dd17a 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -264,9 +264,9 @@ class TosaTestGen: mode = gtu.ComplianceMode.DOT_PRODUCT compliance_tens["dot_product_info"] = { "s": argsDict["s"], - "ks": int(argsDict["ksb"]) - if "ksb" in argsDict - else int(argsDict["ks"]), + "ks": ( + int(argsDict["ksb"]) if "ksb" in argsDict else int(argsDict["ks"]) + ), } elif argsDict["dg_type"] == gtu.DataGenType.FP_SPECIAL: mode = gtu.ComplianceMode.FP_SPECIAL @@ -1028,11 +1028,10 @@ class TosaTestGen: accum_dtype = args_dict["acc_type"] strides = args_dict["stride"] out_pad = args_dict["pad"] - output_shape = args_dict["out_shape"] assert len(out_pad) == 4 result_tensor = OutputShaper.transposeConv2DOp( - self.ser, rng, ifm, output_shape, accum_dtype, error_name + self.ser, rng, ifm, filter, accum_dtype, strides, out_pad, error_name ) # Ensure new output type has correct qinfo @@ -1081,7 +1080,7 @@ class TosaTestGen: attr = ts.TosaSerializerAttribute() attr.TransposeConvAttribute( - out_pad, strides, output_shape, qinfo[0], qinfo[1], local_bound, accum_dtype + out_pad, strides, qinfo[0], qinfo[1], local_bound, accum_dtype ) self.ser.addOperator(op["op"], input_list, output_list, attr) @@ -5942,14 +5941,23 @@ class OutputShaper: return ser.addOutput(val.shape, out_dtype) @staticmethod - def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None): + def transposeConv2DOp( + ser, rng, ifm, filter, accum_dtype, strides, padding, error_name=None + ): + + h = (ifm.shape[1] - 1) * strides[0] + padding[0] + padding[1] + filter.shape[1] + + w = (ifm.shape[2] - 1) * strides[1] + padding[2] + padding[3] + filter.shape[2] + if error_name == ErrorIf.ConvOutputShapeMismatch: choices = [1, 2, 3] change = rng.choice(choices) if change in [1, 3]: - output_shape[1] = output_shape[1] + rng.choice(choices) + h = h + rng.choice(choices) if change in [2, 3]: - output_shape[2] = output_shape[2] + rng.choice(choices) + w = w + rng.choice(choices) + + ofm_shape = [ifm.shape[0], h, w, filter.shape[0]] if error_name == ErrorIf.WrongInputType: # Pick some potentially correct output dtype if input type is incorrect @@ -5965,7 +5973,7 @@ class OutputShaper: wrong_dtypes = list(gtu.usableDTypes(excludes=excludes)) out_dtype = rng.choice(wrong_dtypes) - return ser.addOutput(output_shape, out_dtype) + return ser.addOutput(ofm_shape, out_dtype) @staticmethod def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None): -- cgit v1.2.1