From 8b39043c70332e1e4c95ee6a9616aec40dd3baf1 Mon Sep 17 00:00:00 2001 From: James Ward Date: Fri, 12 Aug 2022 20:48:56 +0100 Subject: Reference model changes for fp16 support Change-Id: I72f21fcfa153046274969d327313e3349981dbe6 Signed-off-by: James Ward --- verif/generator/tosa_utils.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) (limited to 'verif/generator/tosa_utils.py') diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py index 6a689d0..7fa31e7 100644 --- a/verif/generator/tosa_utils.py +++ b/verif/generator/tosa_utils.py @@ -84,3 +84,42 @@ def product(shape): for n in shape: value *= n return value + + +def get_accum_dtype_from_tgTypes(dtypes): + # Get accumulate data-type from the test generator's defined types + if isinstance(dtypes, list) or isinstance(dtypes, tuple): + return dtypes[-1] + else: + return dtypes + + +def get_wrong_output_type(op_name, rng, input_dtype): + if op_name == "fully_connected" or op_name == "matmul": + if input_dtype == DType.INT8: + incorrect_types = ( + DType.INT4, + DType.INT8, + DType.INT16, + DType.INT48, + DType.FLOAT, + DType.FP16, + ) + elif input_dtype == DType.INT16: + incorrect_types = ( + DType.INT4, + DType.INT8, + DType.INT16, + DType.INT32, + DType.FLOAT, + DType.FP16, + ) + elif input_dtype == DType.FLOAT or input_dtype == DType.FP16: + incorrect_types = ( + DType.INT4, + DType.INT8, + DType.INT16, + DType.INT32, + DType.INT48, + ) + return rng.choice(a=incorrect_types) -- cgit v1.2.1