From 95a6710ffb8cadcb8658a967ab29cac1bffad930 Mon Sep 17 00:00:00 2001 From: Jeremy Johnson Date: Wed, 10 Jan 2024 14:16:39 +0000 Subject: Main Compliance: TRANSPOSE_CONV2D support Update data generator for main compliance values. Add test generation support. Fixed test set by including large 65k tests that were missing. Signed-off-by: Jeremy Johnson Change-Id: I8668c774e01c17e5d999aadf99c317e2dd893857 --- verif/generator/tosa_test_gen.py | 43 +++++++++++++++++++++++++--------------- 1 file changed, 27 insertions(+), 16 deletions(-) (limited to 'verif/generator/tosa_test_gen.py') diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 6867979..39b064d 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -324,6 +324,7 @@ class TosaTestGen: Op.CONV2D, Op.FULLY_CONNECTED, Op.DEPTHWISE_CONV2D, + Op.TRANSPOSE_CONV2D, ) if ( errorName @@ -987,19 +988,21 @@ class TosaTestGen: def build_transpose_conv2d( self, op, - ifm, - filter, - bias, - accum_dtype, - stride, - out_pad, - output_shape, + inputs, + args_dict, validator_fcns=None, error_name=None, qinfo=None, ): + assert len(inputs) == 3 + ifm, filter, bias = inputs + accum_dtype = args_dict["acc_type"] + strides = args_dict["stride"] + out_pad = args_dict["pad"] + output_shape = args_dict["out_shape"] + assert len(out_pad) == 4 - result_tens = OutputShaper.transposeConv2DOp( + result_tensor = OutputShaper.transposeConv2DOp( self.ser, self.rng, ifm, output_shape, accum_dtype, error_name ) @@ -1010,12 +1013,12 @@ class TosaTestGen: ): qinfo = [ TosaQuantGen.getZeroPoint(self, ifm.dtype), - TosaQuantGen.getZeroPoint(self, result_tens.dtype), + TosaQuantGen.getZeroPoint(self, result_tensor.dtype), ] # Invalidate Input/Output list for error_if checks. input_list = [ifm.name, filter.name, bias.name] - output_list = [result_tens.name] + output_list = [result_tensor.name] num_operands = sum(op["operands"]) input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( self, error_name, input_list, output_list @@ -1028,16 +1031,16 @@ class TosaTestGen: op=op, input_dtype=ifm.dtype, weight_dtype=filter.dtype, - output_dtype=result_tens.dtype, + output_dtype=result_tensor.dtype, qinfo=qinfo, input_list=input_list, num_operands=num_operands, output_list=output_list, pad=out_pad, - stride=stride, + stride=strides, input_shape=ifm.shape, weight_shape=filter.shape, - output_shape=result_tens.shape, + output_shape=result_tensor.shape, ): return None @@ -1046,11 +1049,16 @@ class TosaTestGen: attr = ts.TosaSerializerAttribute() attr.TransposeConvAttribute( - out_pad, stride, output_shape, qinfo[0], qinfo[1], local_bound + out_pad, strides, output_shape, qinfo[0], qinfo[1], local_bound ) self.ser.addOperator(op["op"], input_list, output_list, attr) - return result_tens + + compliance = self.tensorComplianceMetaData( + op, ifm.dtype, args_dict, result_tensor, error_name + ) + + return TosaTestGen.BuildInfo(result_tensor, compliance) def build_depthwise_conv2d( self, @@ -3307,7 +3315,7 @@ class TosaTestGen: "build_fcn": ( build_transpose_conv2d, TosaTensorGen.tgTransposeConv2D, - TosaTensorValuesGen.tvgDefault, + TosaTensorValuesGen.tvgLazyGenDefault, TosaArgGen.agTransposeConv2D, ), "qgen": TosaQuantGen.qgConv, @@ -3328,6 +3336,9 @@ class TosaTestGen: TosaErrorValidator.evWrongRank, TosaErrorValidator.evConvOutputShapeMismatch, ), + "data_gen": { + "fp": (gtu.DataGenType.DOT_PRODUCT,), + }, "template": True, }, # Activation functions -- cgit v1.2.1