aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_test_gen.py
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2024-01-10 14:16:39 +0000
committerJeremy Johnson <jeremy.johnson@arm.com>2024-01-30 11:50:54 +0000
commit95a6710ffb8cadcb8658a967ab29cac1bffad930 (patch)
tree6320e5d34441626b1e7a956886bd1fee88dbf4a1 /verif/generator/tosa_test_gen.py
parent4f931307a6319d9d99b3afce4ca6e1cd30d77f01 (diff)
downloadreference_model-95a6710ffb8cadcb8658a967ab29cac1bffad930.tar.gz
Main Compliance: TRANSPOSE_CONV2D support
Update data generator for main compliance values. Add test generation support. Fixed test set by including large 65k tests that were missing. Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: I8668c774e01c17e5d999aadf99c317e2dd893857
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r--verif/generator/tosa_test_gen.py43
1 files changed, 27 insertions, 16 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 6867979..39b064d 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -324,6 +324,7 @@ class TosaTestGen:
Op.CONV2D,
Op.FULLY_CONNECTED,
Op.DEPTHWISE_CONV2D,
+ Op.TRANSPOSE_CONV2D,
)
if (
errorName
@@ -987,19 +988,21 @@ class TosaTestGen:
def build_transpose_conv2d(
self,
op,
- ifm,
- filter,
- bias,
- accum_dtype,
- stride,
- out_pad,
- output_shape,
+ inputs,
+ args_dict,
validator_fcns=None,
error_name=None,
qinfo=None,
):
+ assert len(inputs) == 3
+ ifm, filter, bias = inputs
+ accum_dtype = args_dict["acc_type"]
+ strides = args_dict["stride"]
+ out_pad = args_dict["pad"]
+ output_shape = args_dict["out_shape"]
+
assert len(out_pad) == 4
- result_tens = OutputShaper.transposeConv2DOp(
+ result_tensor = OutputShaper.transposeConv2DOp(
self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
)
@@ -1010,12 +1013,12 @@ class TosaTestGen:
):
qinfo = [
TosaQuantGen.getZeroPoint(self, ifm.dtype),
- TosaQuantGen.getZeroPoint(self, result_tens.dtype),
+ TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
]
# Invalidate Input/Output list for error_if checks.
input_list = [ifm.name, filter.name, bias.name]
- output_list = [result_tens.name]
+ output_list = [result_tensor.name]
num_operands = sum(op["operands"])
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
self, error_name, input_list, output_list
@@ -1028,16 +1031,16 @@ class TosaTestGen:
op=op,
input_dtype=ifm.dtype,
weight_dtype=filter.dtype,
- output_dtype=result_tens.dtype,
+ output_dtype=result_tensor.dtype,
qinfo=qinfo,
input_list=input_list,
num_operands=num_operands,
output_list=output_list,
pad=out_pad,
- stride=stride,
+ stride=strides,
input_shape=ifm.shape,
weight_shape=filter.shape,
- output_shape=result_tens.shape,
+ output_shape=result_tensor.shape,
):
return None
@@ -1046,11 +1049,16 @@ class TosaTestGen:
attr = ts.TosaSerializerAttribute()
attr.TransposeConvAttribute(
- out_pad, stride, output_shape, qinfo[0], qinfo[1], local_bound
+ out_pad, strides, output_shape, qinfo[0], qinfo[1], local_bound
)
self.ser.addOperator(op["op"], input_list, output_list, attr)
- return result_tens
+
+ compliance = self.tensorComplianceMetaData(
+ op, ifm.dtype, args_dict, result_tensor, error_name
+ )
+
+ return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_depthwise_conv2d(
self,
@@ -3307,7 +3315,7 @@ class TosaTestGen:
"build_fcn": (
build_transpose_conv2d,
TosaTensorGen.tgTransposeConv2D,
- TosaTensorValuesGen.tvgDefault,
+ TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agTransposeConv2D,
),
"qgen": TosaQuantGen.qgConv,
@@ -3328,6 +3336,9 @@ class TosaTestGen:
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evConvOutputShapeMismatch,
),
+ "data_gen": {
+ "fp": (gtu.DataGenType.DOT_PRODUCT,),
+ },
"template": True,
},
# Activation functions