diff options
Diffstat (limited to 'verif/generator/tosa_utils.py')
-rw-r--r-- | verif/generator/tosa_utils.py | 39 |
1 files changed, 39 insertions, 0 deletions
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py index 6a689d0..7fa31e7 100644 --- a/verif/generator/tosa_utils.py +++ b/verif/generator/tosa_utils.py @@ -84,3 +84,42 @@ def product(shape): for n in shape: value *= n return value + + +def get_accum_dtype_from_tgTypes(dtypes): + # Get accumulate data-type from the test generator's defined types + if isinstance(dtypes, list) or isinstance(dtypes, tuple): + return dtypes[-1] + else: + return dtypes + + +def get_wrong_output_type(op_name, rng, input_dtype): + if op_name == "fully_connected" or op_name == "matmul": + if input_dtype == DType.INT8: + incorrect_types = ( + DType.INT4, + DType.INT8, + DType.INT16, + DType.INT48, + DType.FLOAT, + DType.FP16, + ) + elif input_dtype == DType.INT16: + incorrect_types = ( + DType.INT4, + DType.INT8, + DType.INT16, + DType.INT32, + DType.FLOAT, + DType.FP16, + ) + elif input_dtype == DType.FLOAT or input_dtype == DType.FP16: + incorrect_types = ( + DType.INT4, + DType.INT8, + DType.INT16, + DType.INT32, + DType.INT48, + ) + return rng.choice(a=incorrect_types) |