diff options
author | James Ward <james.ward@arm.com> | 2022-08-12 20:48:56 +0100 |
---|---|---|
committer | James Ward <james.ward@arm.com> | 2022-10-11 11:56:02 +0100 |
commit | 8b39043c70332e1e4c95ee6a9616aec40dd3baf1 (patch) | |
tree | fea519246b698eb944b9d58537fc90bc30481d11 /verif/generator/tosa_utils.py | |
parent | ba5fad356a926d5e1c6e0fe6b546a310230cc5a8 (diff) | |
download | reference_model-8b39043c70332e1e4c95ee6a9616aec40dd3baf1.tar.gz |
Reference model changes for fp16 support
Change-Id: I72f21fcfa153046274969d327313e3349981dbe6
Signed-off-by: James Ward <james.ward@arm.com>
Diffstat (limited to 'verif/generator/tosa_utils.py')
-rw-r--r-- | verif/generator/tosa_utils.py | 39 |
1 files changed, 39 insertions, 0 deletions
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py index 6a689d0..7fa31e7 100644 --- a/verif/generator/tosa_utils.py +++ b/verif/generator/tosa_utils.py @@ -84,3 +84,42 @@ def product(shape): for n in shape: value *= n return value + + +def get_accum_dtype_from_tgTypes(dtypes): + # Get accumulate data-type from the test generator's defined types + if isinstance(dtypes, list) or isinstance(dtypes, tuple): + return dtypes[-1] + else: + return dtypes + + +def get_wrong_output_type(op_name, rng, input_dtype): + if op_name == "fully_connected" or op_name == "matmul": + if input_dtype == DType.INT8: + incorrect_types = ( + DType.INT4, + DType.INT8, + DType.INT16, + DType.INT48, + DType.FLOAT, + DType.FP16, + ) + elif input_dtype == DType.INT16: + incorrect_types = ( + DType.INT4, + DType.INT8, + DType.INT16, + DType.INT32, + DType.FLOAT, + DType.FP16, + ) + elif input_dtype == DType.FLOAT or input_dtype == DType.FP16: + incorrect_types = ( + DType.INT4, + DType.INT8, + DType.INT16, + DType.INT32, + DType.INT48, + ) + return rng.choice(a=incorrect_types) |