aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_utils.py
diff options
context:
space:
mode:
authorJames Ward <james.ward@arm.com>2022-08-12 20:48:56 +0100
committerJames Ward <james.ward@arm.com>2022-10-11 11:56:02 +0100
commit8b39043c70332e1e4c95ee6a9616aec40dd3baf1 (patch)
treefea519246b698eb944b9d58537fc90bc30481d11 /verif/generator/tosa_utils.py
parentba5fad356a926d5e1c6e0fe6b546a310230cc5a8 (diff)
downloadreference_model-8b39043c70332e1e4c95ee6a9616aec40dd3baf1.tar.gz
Reference model changes for fp16 support
Change-Id: I72f21fcfa153046274969d327313e3349981dbe6 Signed-off-by: James Ward <james.ward@arm.com>
Diffstat (limited to 'verif/generator/tosa_utils.py')
-rw-r--r--verif/generator/tosa_utils.py39
1 files changed, 39 insertions, 0 deletions
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py
index 6a689d0..7fa31e7 100644
--- a/verif/generator/tosa_utils.py
+++ b/verif/generator/tosa_utils.py
@@ -84,3 +84,42 @@ def product(shape):
for n in shape:
value *= n
return value
+
+
+def get_accum_dtype_from_tgTypes(dtypes):
+ # Get accumulate data-type from the test generator's defined types
+ if isinstance(dtypes, list) or isinstance(dtypes, tuple):
+ return dtypes[-1]
+ else:
+ return dtypes
+
+
+def get_wrong_output_type(op_name, rng, input_dtype):
+ if op_name == "fully_connected" or op_name == "matmul":
+ if input_dtype == DType.INT8:
+ incorrect_types = (
+ DType.INT4,
+ DType.INT8,
+ DType.INT16,
+ DType.INT48,
+ DType.FLOAT,
+ DType.FP16,
+ )
+ elif input_dtype == DType.INT16:
+ incorrect_types = (
+ DType.INT4,
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.FLOAT,
+ DType.FP16,
+ )
+ elif input_dtype == DType.FLOAT or input_dtype == DType.FP16:
+ incorrect_types = (
+ DType.INT4,
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ )
+ return rng.choice(a=incorrect_types)