aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r--verif/generator/tosa_test_gen.py28
1 files changed, 18 insertions, 10 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index b85dd03..88dd17a 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -264,9 +264,9 @@ class TosaTestGen:
mode = gtu.ComplianceMode.DOT_PRODUCT
compliance_tens["dot_product_info"] = {
"s": argsDict["s"],
- "ks": int(argsDict["ksb"])
- if "ksb" in argsDict
- else int(argsDict["ks"]),
+ "ks": (
+ int(argsDict["ksb"]) if "ksb" in argsDict else int(argsDict["ks"])
+ ),
}
elif argsDict["dg_type"] == gtu.DataGenType.FP_SPECIAL:
mode = gtu.ComplianceMode.FP_SPECIAL
@@ -1028,11 +1028,10 @@ class TosaTestGen:
accum_dtype = args_dict["acc_type"]
strides = args_dict["stride"]
out_pad = args_dict["pad"]
- output_shape = args_dict["out_shape"]
assert len(out_pad) == 4
result_tensor = OutputShaper.transposeConv2DOp(
- self.ser, rng, ifm, output_shape, accum_dtype, error_name
+ self.ser, rng, ifm, filter, accum_dtype, strides, out_pad, error_name
)
# Ensure new output type has correct qinfo
@@ -1081,7 +1080,7 @@ class TosaTestGen:
attr = ts.TosaSerializerAttribute()
attr.TransposeConvAttribute(
- out_pad, strides, output_shape, qinfo[0], qinfo[1], local_bound, accum_dtype
+ out_pad, strides, qinfo[0], qinfo[1], local_bound, accum_dtype
)
self.ser.addOperator(op["op"], input_list, output_list, attr)
@@ -5942,14 +5941,23 @@ class OutputShaper:
return ser.addOutput(val.shape, out_dtype)
@staticmethod
- def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
+ def transposeConv2DOp(
+ ser, rng, ifm, filter, accum_dtype, strides, padding, error_name=None
+ ):
+
+ h = (ifm.shape[1] - 1) * strides[0] + padding[0] + padding[1] + filter.shape[1]
+
+ w = (ifm.shape[2] - 1) * strides[1] + padding[2] + padding[3] + filter.shape[2]
+
if error_name == ErrorIf.ConvOutputShapeMismatch:
choices = [1, 2, 3]
change = rng.choice(choices)
if change in [1, 3]:
- output_shape[1] = output_shape[1] + rng.choice(choices)
+ h = h + rng.choice(choices)
if change in [2, 3]:
- output_shape[2] = output_shape[2] + rng.choice(choices)
+ w = w + rng.choice(choices)
+
+ ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]
if error_name == ErrorIf.WrongInputType:
# Pick some potentially correct output dtype if input type is incorrect
@@ -5965,7 +5973,7 @@ class OutputShaper:
wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
out_dtype = rng.choice(wrong_dtypes)
- return ser.addOutput(output_shape, out_dtype)
+ return ser.addOutput(ofm_shape, out_dtype)
@staticmethod
def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):