From 47ab1762d1c15a7b4c0c068d7294111c5c5f92a2 Mon Sep 17 00:00:00 2001 From: evacha01 Date: Mon, 29 Jan 2024 13:23:23 +0000 Subject: Main Compliance testing for CONV3D Signed-off-by: evacha01 Change-Id: Ie05f88db15cd07fd5483ab669329d7048bd3349c --- verif/generator/tosa_test_gen.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) (limited to 'verif/generator') diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 68a4e94..9c3cd32 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -325,6 +325,7 @@ class TosaTestGen: Op.FULLY_CONNECTED, Op.DEPTHWISE_CONV2D, Op.TRANSPOSE_CONV2D, + Op.CONV3D, ) if ( errorName @@ -952,7 +953,7 @@ class TosaTestGen: dilations = args_dict["dilation"] assert len(padding) == 6 - result_tens = OutputShaper.conv3dOp( + result_tensor = OutputShaper.conv3dOp( self.ser, self.rng, ifm, @@ -971,12 +972,12 @@ class TosaTestGen: ): qinfo = [ TosaQuantGen.getZeroPoint(self, ifm.dtype), - TosaQuantGen.getZeroPoint(self, result_tens.dtype), + TosaQuantGen.getZeroPoint(self, result_tensor.dtype), ] # Invalidate Input/Output list for error_if checks. input_list = [ifm.name, filter.name, bias.name] - output_list = [result_tens.name] + output_list = [result_tensor.name] num_operands = sum(op["operands"]) input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( self, error_name, input_list, output_list @@ -989,7 +990,7 @@ class TosaTestGen: op=op, input_dtype=ifm.dtype, weight_dtype=filter.dtype, - output_dtype=result_tens.dtype, + output_dtype=result_tensor.dtype, qinfo=qinfo, input_list=input_list, num_operands=num_operands, @@ -999,7 +1000,7 @@ class TosaTestGen: dilation=dilations, input_shape=ifm.shape, weight_shape=filter.shape, - output_shape=result_tens.shape, + output_shape=result_tensor.shape, ): return None @@ -1010,7 +1011,12 @@ class TosaTestGen: attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound) self.ser.addOperator(op["op"], input_list, output_list, attr) - return result_tens + + compliance = self.tensorComplianceMetaData( + op, ifm.dtype, args_dict, result_tensor, error_name + ) + + return TosaTestGen.BuildInfo(result_tensor, compliance) def build_transpose_conv2d( self, @@ -3254,6 +3260,9 @@ class TosaTestGen: TosaErrorValidator.evConvOutputShapeMismatch, TosaErrorValidator.evConvOutputShapeNonInteger, ), + "data_gen": { + "fp": (gtu.DataGenType.DOT_PRODUCT,), + }, "template": True, }, # Templated operator. Filled in by createDynamicOpLists -- cgit v1.2.1