diff options
author | Jeremy Johnson <jeremy.johnson@arm.com> | 2024-01-04 17:05:24 +0000 |
---|---|---|
committer | Jeremy Johnson <jeremy.johnson@arm.com> | 2024-01-30 11:49:56 +0000 |
commit | 4f931307a6319d9d99b3afce4ca6e1cd30d77f01 (patch) | |
tree | 5661b63bd087b210403e3b50dbc0ce0a9f8a41b4 /verif/generator/tosa_test_gen.py | |
parent | 2d7e4b13d2c3022ae8176d59e2a11d5584ea1d0b (diff) | |
download | reference_model-4f931307a6319d9d99b3afce4ca6e1cd30d77f01.tar.gz |
Main Compliance: DEPTHWISE_CONV2D support
Added DEPTHWISE_CONV2D data generation.
Updated test generation for FP16 and FP32.
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: I0471d0a1e4e279a27233f4d285082906ceea1bff
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r-- | verif/generator/tosa_test_gen.py | 29 |
1 files changed, 21 insertions, 8 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 49d9f1b..6867979 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -318,8 +318,13 @@ class TosaTestGen: def tensorComplianceMetaData( self, op, inputType, argsDict, outputTensor, errorName ): - # TODO - Dot product Ops with FP16 or BF16 inputs that produce FP32 outputs are not supported yet - UNSUPPORTED_NON_FP32_INPUT_OPS = (Op.MATMUL, Op.CONV2D, Op.FULLY_CONNECTED) + # TODO - Dot product Ops with BF16 inputs that produce FP32 outputs are not supported yet + UNSUPPORTED_NON_FP32_INPUT_OPS = ( + Op.MATMUL, + Op.CONV2D, + Op.FULLY_CONNECTED, + Op.DEPTHWISE_CONV2D, + ) if ( errorName or not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype) @@ -1063,7 +1068,7 @@ class TosaTestGen: padding = args_dict["pad"] dilations = args_dict["dilation"] - result_tens = OutputShaper.depthwiseConv2dOp( + result_tensor = OutputShaper.depthwiseConv2dOp( self.ser, self.rng, ifm, @@ -1082,12 +1087,12 @@ class TosaTestGen: ): qinfo = [ TosaQuantGen.getZeroPoint(self, ifm.dtype), - TosaQuantGen.getZeroPoint(self, result_tens.dtype), + TosaQuantGen.getZeroPoint(self, result_tensor.dtype), ] # Invalidate Input/Output list for error_if checks. input_list = [ifm.name, filter.name, bias.name] - output_list = [result_tens.name] + output_list = [result_tensor.name] num_operands = sum(op["operands"]) input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( self, error_name, input_list, output_list @@ -1100,7 +1105,7 @@ class TosaTestGen: op=op, input_dtype=ifm.dtype, weight_dtype=filter.dtype, - output_dtype=result_tens.dtype, + output_dtype=result_tensor.dtype, qinfo=qinfo, input_list=input_list, num_operands=num_operands, @@ -1110,7 +1115,7 @@ class TosaTestGen: dilation=dilations, input_shape=ifm.shape, weight_shape=filter.shape, - output_shape=result_tens.shape, + output_shape=result_tensor.shape, ): return None @@ -1121,7 +1126,12 @@ class TosaTestGen: attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound) self.ser.addOperator(op["op"], input_list, output_list, attr) - return result_tens + + compliance = self.tensorComplianceMetaData( + op, ifm.dtype, args_dict, result_tensor, error_name + ) + + return TosaTestGen.BuildInfo(result_tensor, compliance) def build_fully_connected( self, @@ -3206,6 +3216,9 @@ class TosaTestGen: TosaErrorValidator.evConvOutputShapeMismatch, TosaErrorValidator.evConvOutputShapeNonInteger, ), + "data_gen": { + "fp": (gtu.DataGenType.DOT_PRODUCT,), + }, "template": True, }, "fully_connected": { |