aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_error_if.py
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2024-03-14 16:21:29 +0000
committerEric Kunze <eric.kunze@arm.com>2024-03-20 00:02:15 +0000
commitf36f25619cc3a34c75e78637ed244a2ca54ab3f4 (patch)
treeb1aa6a7314ef598561f0259c4d614a4169451031 /verif/generator/tosa_error_if.py
parent0a6d1deef02f2bd76b3068d615565f20c46075a5 (diff)
downloadreference_model-f36f25619cc3a34c75e78637ed244a2ca54ab3f4.tar.gz
[ref model] Add acc_type to Conv Ops
This patch implements changes required by the new acc_type field in ConvAttribute and TransposeConvAttribute Signed-off-by: Tai Ly <tai.ly@arm.com> Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: Ib13dbeec4d8920e0ddbcca02b727e7277f2c8d62
Diffstat (limited to 'verif/generator/tosa_error_if.py')
-rw-r--r--verif/generator/tosa_error_if.py34
1 files changed, 32 insertions, 2 deletions
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index e557f06..916b4f9 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -649,9 +649,9 @@ class TosaErrorValidator:
or input_dtype == DType.INT16
and output_dtype != DType.INT48
or input_dtype == DType.FP16
- and output_dtype not in (DType.FP16, DType.FP32)
+ and output_dtype != DType.FP16
or input_dtype == DType.BF16
- and output_dtype != DType.FP32
+ and output_dtype != DType.BF16
or input_dtype == DType.FP32
and output_dtype != DType.FP32
or input_dtype == DType.FP8E4M3
@@ -2682,6 +2682,36 @@ class TosaErrorValidator:
):
error_result = True
+ elif op["op"] in {
+ Op.CONV2D,
+ Op.CONV3D,
+ Op.DEPTHWISE_CONV2D,
+ Op.TRANSPOSE_CONV2D,
+ }:
+ if input_dtype == DType.INT8 and accum_dtype != DType.INT32:
+ error_result = True
+ elif input_dtype == DType.INT16 and accum_dtype != DType.INT48:
+ error_result = True
+ elif (
+ input_dtype
+ in (
+ DType.FP32,
+ DType.BF16,
+ )
+ and accum_dtype != DType.FP32
+ ):
+ error_result = True
+ elif input_dtype == DType.FP16 and accum_dtype not in (
+ DType.FP16,
+ DType.FP32,
+ ):
+ error_result = True
+ elif (
+ input_dtype in (DType.FP8E4M3, DType.FP8E5M2)
+ and accum_dtype != DType.FP16
+ ):
+ error_result = True
+
info_dict = {
"error_name": error_name,
"error_result": error_result,