diff options
author | Tai Ly <tai.ly@arm.com> | 2024-03-14 16:21:29 +0000 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2024-03-20 00:02:15 +0000 |
commit | f36f25619cc3a34c75e78637ed244a2ca54ab3f4 (patch) | |
tree | b1aa6a7314ef598561f0259c4d614a4169451031 /verif/generator/tosa_utils.py | |
parent | 0a6d1deef02f2bd76b3068d615565f20c46075a5 (diff) | |
download | reference_model-f36f25619cc3a34c75e78637ed244a2ca54ab3f4.tar.gz |
[ref model] Add acc_type to Conv Ops
This patch implements changes required by the new acc_type field in
ConvAttribute and TransposeConvAttribute
Signed-off-by: Tai Ly <tai.ly@arm.com>
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: Ib13dbeec4d8920e0ddbcca02b727e7277f2c8d62
Diffstat (limited to 'verif/generator/tosa_utils.py')
-rw-r--r-- | verif/generator/tosa_utils.py | 14 |
1 files changed, 11 insertions, 3 deletions
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py index cfe7cc6..4a4f6bb 100644 --- a/verif/generator/tosa_utils.py +++ b/verif/generator/tosa_utils.py @@ -164,10 +164,18 @@ def product(shape): return value -def get_accum_dtype_from_tgTypes(dtypes): - # Get accumulate data-type from the test generator's defined types +def get_accum_dtypes_from_tgTypes(dtypes): + # Get accumulate data-types from the test generator's defined types assert isinstance(dtypes, list) or isinstance(dtypes, tuple) - return dtypes[-1] + input_dtype = dtypes[0] + output_dtype = dtypes[-1] + # by default, accum_dtypes contains only output_dtype + accum_dtypes = [output_dtype] + if input_dtype == DType.FP16 and output_dtype == DType.FP16: + accum_dtypes = [DType.FP16, DType.FP32] + elif output_dtype == DType.BF16: + accum_dtypes = [DType.FP32] + return accum_dtypes def get_wrong_output_type(op_name, rng, input_dtype): |