aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r--verif/generator/tosa_test_gen.py42
1 files changed, 34 insertions, 8 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 7702753..c867070 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -896,6 +896,7 @@ class TosaTestGen:
input_shape=ifm.shape,
weight_shape=filter.shape,
output_shape=result_tensor.shape,
+ accum_dtype=accum_dtype,
):
return None
@@ -903,7 +904,9 @@ class TosaTestGen:
local_bound = False
attr = ts.TosaSerializerAttribute()
- attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)
+ attr.ConvAttribute(
+ padding, strides, dilations, qinfo[0], qinfo[1], local_bound, accum_dtype
+ )
self.ser.addOperator(op["op"], input_list, output_list, attr)
@@ -981,6 +984,7 @@ class TosaTestGen:
input_shape=ifm.shape,
weight_shape=filter.shape,
output_shape=result_tensor.shape,
+ accum_dtype=accum_dtype,
):
return None
@@ -988,7 +992,9 @@ class TosaTestGen:
local_bound = False
attr = ts.TosaSerializerAttribute()
- attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)
+ attr.ConvAttribute(
+ padding, strides, dilations, qinfo[0], qinfo[1], local_bound, accum_dtype
+ )
self.ser.addOperator(op["op"], input_list, output_list, attr)
@@ -1057,6 +1063,7 @@ class TosaTestGen:
input_shape=ifm.shape,
weight_shape=filter.shape,
output_shape=result_tensor.shape,
+ accum_dtype=accum_dtype,
):
return None
@@ -1065,7 +1072,7 @@ class TosaTestGen:
attr = ts.TosaSerializerAttribute()
attr.TransposeConvAttribute(
- out_pad, strides, output_shape, qinfo[0], qinfo[1], local_bound
+ out_pad, strides, output_shape, qinfo[0], qinfo[1], local_bound, accum_dtype
)
self.ser.addOperator(op["op"], input_list, output_list, attr)
@@ -1143,6 +1150,7 @@ class TosaTestGen:
input_shape=ifm.shape,
weight_shape=filter.shape,
output_shape=result_tensor.shape,
+ accum_dtype=accum_dtype,
):
return None
@@ -1150,7 +1158,9 @@ class TosaTestGen:
local_bound = False
attr = ts.TosaSerializerAttribute()
- attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)
+ attr.ConvAttribute(
+ padding, strides, dilations, qinfo[0], qinfo[1], local_bound, accum_dtype
+ )
self.ser.addOperator(op["op"], input_list, output_list, attr)
@@ -3385,6 +3395,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evConvOutputShapeMismatch,
TosaErrorValidator.evConvOutputShapeNonInteger,
+ TosaErrorValidator.evWrongAccumulatorType,
),
"data_gen": {
"fp": (gtu.DataGenType.DOT_PRODUCT,),
@@ -3418,6 +3429,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evConvOutputShapeMismatch,
TosaErrorValidator.evConvOutputShapeNonInteger,
+ TosaErrorValidator.evWrongAccumulatorType,
),
"data_gen": {
"fp": (gtu.DataGenType.DOT_PRODUCT,),
@@ -3452,6 +3464,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evConvOutputShapeMismatch,
TosaErrorValidator.evConvOutputShapeNonInteger,
+ TosaErrorValidator.evWrongAccumulatorType,
),
"data_gen": {
"fp": (gtu.DataGenType.DOT_PRODUCT,),
@@ -3564,6 +3577,7 @@ class TosaTestGen:
TosaErrorValidator.evStrideSmallerOne,
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evConvOutputShapeMismatch,
+ TosaErrorValidator.evWrongAccumulatorType,
),
"data_gen": {
"fp": (gtu.DataGenType.DOT_PRODUCT,),
@@ -5290,6 +5304,18 @@ class OutputShaper:
return ser.addOutput(shape, outputDType)
@staticmethod
+ def _get_conv_output_type(input_dtype):
+ if input_dtype in (DType.FP16, DType.BF16, DType.FP32):
+ return input_dtype
+ elif input_dtype in (DType.FP8E4M3, DType.FP8E5M2):
+ return DType.FP16
+ elif input_dtype in (DType.INT8, DType.INT4):
+ return DType.INT32
+ elif input_dtype in (DType.INT16,):
+ return DType.INT48
+ assert True, f"Unsupported convolution data type {input_dtype}"
+
+ @staticmethod
def conv2dOp(
ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
@@ -5329,7 +5355,7 @@ class OutputShaper:
# Pick some potentially correct output dtype if input type is incorrect
out_dtype = DType.INT32
else:
- out_dtype = accum_dtype
+ out_dtype = OutputShaper._get_conv_output_type(ifm.dtype)
if error_name == ErrorIf.WrongOutputType:
if ifm.dtype == DType.FP16:
@@ -5393,7 +5419,7 @@ class OutputShaper:
# Pick some potentially correct output dtype if input type is incorrect
out_dtype = DType.INT32
else:
- out_dtype = accum_dtype
+ out_dtype = OutputShaper._get_conv_output_type(ifm.dtype)
if error_name == ErrorIf.WrongOutputType:
if ifm.dtype == DType.FP16:
@@ -5444,7 +5470,7 @@ class OutputShaper:
# Pick some potentially correct output dtype if input type is incorrect
out_dtype = DType.INT32
else:
- out_dtype = accum_dtype
+ out_dtype = OutputShaper._get_conv_output_type(ifm.dtype)
if error_name == ErrorIf.WrongOutputType:
if ifm.dtype == DType.FP16:
@@ -5958,7 +5984,7 @@ class OutputShaper:
# Pick some potentially correct output dtype if input type is incorrect
out_dtype = DType.INT32
else:
- out_dtype = accum_dtype
+ out_dtype = OutputShaper._get_conv_output_type(ifm.dtype)
if error_name == ErrorIf.WrongOutputType:
if ifm.dtype == DType.FP16: