diff options
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r-- | verif/generator/tosa_test_gen.py | 42 |
1 files changed, 34 insertions, 8 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 7702753..c867070 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -896,6 +896,7 @@ class TosaTestGen: input_shape=ifm.shape, weight_shape=filter.shape, output_shape=result_tensor.shape, + accum_dtype=accum_dtype, ): return None @@ -903,7 +904,9 @@ class TosaTestGen: local_bound = False attr = ts.TosaSerializerAttribute() - attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound) + attr.ConvAttribute( + padding, strides, dilations, qinfo[0], qinfo[1], local_bound, accum_dtype + ) self.ser.addOperator(op["op"], input_list, output_list, attr) @@ -981,6 +984,7 @@ class TosaTestGen: input_shape=ifm.shape, weight_shape=filter.shape, output_shape=result_tensor.shape, + accum_dtype=accum_dtype, ): return None @@ -988,7 +992,9 @@ class TosaTestGen: local_bound = False attr = ts.TosaSerializerAttribute() - attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound) + attr.ConvAttribute( + padding, strides, dilations, qinfo[0], qinfo[1], local_bound, accum_dtype + ) self.ser.addOperator(op["op"], input_list, output_list, attr) @@ -1057,6 +1063,7 @@ class TosaTestGen: input_shape=ifm.shape, weight_shape=filter.shape, output_shape=result_tensor.shape, + accum_dtype=accum_dtype, ): return None @@ -1065,7 +1072,7 @@ class TosaTestGen: attr = ts.TosaSerializerAttribute() attr.TransposeConvAttribute( - out_pad, strides, output_shape, qinfo[0], qinfo[1], local_bound + out_pad, strides, output_shape, qinfo[0], qinfo[1], local_bound, accum_dtype ) self.ser.addOperator(op["op"], input_list, output_list, attr) @@ -1143,6 +1150,7 @@ class TosaTestGen: input_shape=ifm.shape, weight_shape=filter.shape, output_shape=result_tensor.shape, + accum_dtype=accum_dtype, ): return None @@ -1150,7 +1158,9 @@ class TosaTestGen: local_bound = False attr = ts.TosaSerializerAttribute() - attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound) + attr.ConvAttribute( + padding, strides, dilations, qinfo[0], qinfo[1], local_bound, accum_dtype + ) self.ser.addOperator(op["op"], input_list, output_list, attr) @@ -3385,6 +3395,7 @@ class TosaTestGen: TosaErrorValidator.evWrongRank, TosaErrorValidator.evConvOutputShapeMismatch, TosaErrorValidator.evConvOutputShapeNonInteger, + TosaErrorValidator.evWrongAccumulatorType, ), "data_gen": { "fp": (gtu.DataGenType.DOT_PRODUCT,), @@ -3418,6 +3429,7 @@ class TosaTestGen: TosaErrorValidator.evWrongRank, TosaErrorValidator.evConvOutputShapeMismatch, TosaErrorValidator.evConvOutputShapeNonInteger, + TosaErrorValidator.evWrongAccumulatorType, ), "data_gen": { "fp": (gtu.DataGenType.DOT_PRODUCT,), @@ -3452,6 +3464,7 @@ class TosaTestGen: TosaErrorValidator.evWrongRank, TosaErrorValidator.evConvOutputShapeMismatch, TosaErrorValidator.evConvOutputShapeNonInteger, + TosaErrorValidator.evWrongAccumulatorType, ), "data_gen": { "fp": (gtu.DataGenType.DOT_PRODUCT,), @@ -3564,6 +3577,7 @@ class TosaTestGen: TosaErrorValidator.evStrideSmallerOne, TosaErrorValidator.evWrongRank, TosaErrorValidator.evConvOutputShapeMismatch, + TosaErrorValidator.evWrongAccumulatorType, ), "data_gen": { "fp": (gtu.DataGenType.DOT_PRODUCT,), @@ -5290,6 +5304,18 @@ class OutputShaper: return ser.addOutput(shape, outputDType) @staticmethod + def _get_conv_output_type(input_dtype): + if input_dtype in (DType.FP16, DType.BF16, DType.FP32): + return input_dtype + elif input_dtype in (DType.FP8E4M3, DType.FP8E5M2): + return DType.FP16 + elif input_dtype in (DType.INT8, DType.INT4): + return DType.INT32 + elif input_dtype in (DType.INT16,): + return DType.INT48 + assert True, f"Unsupported convolution data type {input_dtype}" + + @staticmethod def conv2dOp( ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None ): @@ -5329,7 +5355,7 @@ class OutputShaper: # Pick some potentially correct output dtype if input type is incorrect out_dtype = DType.INT32 else: - out_dtype = accum_dtype + out_dtype = OutputShaper._get_conv_output_type(ifm.dtype) if error_name == ErrorIf.WrongOutputType: if ifm.dtype == DType.FP16: @@ -5393,7 +5419,7 @@ class OutputShaper: # Pick some potentially correct output dtype if input type is incorrect out_dtype = DType.INT32 else: - out_dtype = accum_dtype + out_dtype = OutputShaper._get_conv_output_type(ifm.dtype) if error_name == ErrorIf.WrongOutputType: if ifm.dtype == DType.FP16: @@ -5444,7 +5470,7 @@ class OutputShaper: # Pick some potentially correct output dtype if input type is incorrect out_dtype = DType.INT32 else: - out_dtype = accum_dtype + out_dtype = OutputShaper._get_conv_output_type(ifm.dtype) if error_name == ErrorIf.WrongOutputType: if ifm.dtype == DType.FP16: @@ -5958,7 +5984,7 @@ class OutputShaper: # Pick some potentially correct output dtype if input type is incorrect out_dtype = DType.INT32 else: - out_dtype = accum_dtype + out_dtype = OutputShaper._get_conv_output_type(ifm.dtype) if error_name == ErrorIf.WrongOutputType: if ifm.dtype == DType.FP16: |