aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_utils.py')
-rw-r--r--verif/generator/tosa_utils.py14
1 files changed, 11 insertions, 3 deletions
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py
index cfe7cc6..4a4f6bb 100644
--- a/verif/generator/tosa_utils.py
+++ b/verif/generator/tosa_utils.py
@@ -164,10 +164,18 @@ def product(shape):
return value
-def get_accum_dtype_from_tgTypes(dtypes):
- # Get accumulate data-type from the test generator's defined types
+def get_accum_dtypes_from_tgTypes(dtypes):
+ # Get accumulate data-types from the test generator's defined types
assert isinstance(dtypes, list) or isinstance(dtypes, tuple)
- return dtypes[-1]
+ input_dtype = dtypes[0]
+ output_dtype = dtypes[-1]
+ # by default, accum_dtypes contains only output_dtype
+ accum_dtypes = [output_dtype]
+ if input_dtype == DType.FP16 and output_dtype == DType.FP16:
+ accum_dtypes = [DType.FP16, DType.FP32]
+ elif output_dtype == DType.BF16:
+ accum_dtypes = [DType.FP32]
+ return accum_dtypes
def get_wrong_output_type(op_name, rng, input_dtype):