aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_utils.py')
-rw-r--r--verif/generator/tosa_utils.py39
1 files changed, 39 insertions, 0 deletions
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py
index 6a689d0..7fa31e7 100644
--- a/verif/generator/tosa_utils.py
+++ b/verif/generator/tosa_utils.py
@@ -84,3 +84,42 @@ def product(shape):
for n in shape:
value *= n
return value
+
+
+def get_accum_dtype_from_tgTypes(dtypes):
+ # Get accumulate data-type from the test generator's defined types
+ if isinstance(dtypes, list) or isinstance(dtypes, tuple):
+ return dtypes[-1]
+ else:
+ return dtypes
+
+
+def get_wrong_output_type(op_name, rng, input_dtype):
+ if op_name == "fully_connected" or op_name == "matmul":
+ if input_dtype == DType.INT8:
+ incorrect_types = (
+ DType.INT4,
+ DType.INT8,
+ DType.INT16,
+ DType.INT48,
+ DType.FLOAT,
+ DType.FP16,
+ )
+ elif input_dtype == DType.INT16:
+ incorrect_types = (
+ DType.INT4,
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.FLOAT,
+ DType.FP16,
+ )
+ elif input_dtype == DType.FLOAT or input_dtype == DType.FP16:
+ incorrect_types = (
+ DType.INT4,
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ )
+ return rng.choice(a=incorrect_types)