diff options
author | Tai Ly <tai.ly@arm.com> | 2024-03-14 16:21:29 +0000 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2024-03-20 00:02:15 +0000 |
commit | f36f25619cc3a34c75e78637ed244a2ca54ab3f4 (patch) | |
tree | b1aa6a7314ef598561f0259c4d614a4169451031 /verif/generator/tosa_error_if.py | |
parent | 0a6d1deef02f2bd76b3068d615565f20c46075a5 (diff) | |
download | reference_model-f36f25619cc3a34c75e78637ed244a2ca54ab3f4.tar.gz |
[ref model] Add acc_type to Conv Ops
This patch implements changes required by the new acc_type field in
ConvAttribute and TransposeConvAttribute
Signed-off-by: Tai Ly <tai.ly@arm.com>
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: Ib13dbeec4d8920e0ddbcca02b727e7277f2c8d62
Diffstat (limited to 'verif/generator/tosa_error_if.py')
-rw-r--r-- | verif/generator/tosa_error_if.py | 34 |
1 files changed, 32 insertions, 2 deletions
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py index e557f06..916b4f9 100644 --- a/verif/generator/tosa_error_if.py +++ b/verif/generator/tosa_error_if.py @@ -649,9 +649,9 @@ class TosaErrorValidator: or input_dtype == DType.INT16 and output_dtype != DType.INT48 or input_dtype == DType.FP16 - and output_dtype not in (DType.FP16, DType.FP32) + and output_dtype != DType.FP16 or input_dtype == DType.BF16 - and output_dtype != DType.FP32 + and output_dtype != DType.BF16 or input_dtype == DType.FP32 and output_dtype != DType.FP32 or input_dtype == DType.FP8E4M3 @@ -2682,6 +2682,36 @@ class TosaErrorValidator: ): error_result = True + elif op["op"] in { + Op.CONV2D, + Op.CONV3D, + Op.DEPTHWISE_CONV2D, + Op.TRANSPOSE_CONV2D, + }: + if input_dtype == DType.INT8 and accum_dtype != DType.INT32: + error_result = True + elif input_dtype == DType.INT16 and accum_dtype != DType.INT48: + error_result = True + elif ( + input_dtype + in ( + DType.FP32, + DType.BF16, + ) + and accum_dtype != DType.FP32 + ): + error_result = True + elif input_dtype == DType.FP16 and accum_dtype not in ( + DType.FP16, + DType.FP32, + ): + error_result = True + elif ( + input_dtype in (DType.FP8E4M3, DType.FP8E5M2) + and accum_dtype != DType.FP16 + ): + error_result = True + info_dict = { "error_name": error_name, "error_result": error_result, |