aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_error_if.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_error_if.py')
-rw-r--r--verif/generator/tosa_error_if.py34
1 files changed, 32 insertions, 2 deletions
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index e557f06..916b4f9 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -649,9 +649,9 @@ class TosaErrorValidator:
or input_dtype == DType.INT16
and output_dtype != DType.INT48
or input_dtype == DType.FP16
- and output_dtype not in (DType.FP16, DType.FP32)
+ and output_dtype != DType.FP16
or input_dtype == DType.BF16
- and output_dtype != DType.FP32
+ and output_dtype != DType.BF16
or input_dtype == DType.FP32
and output_dtype != DType.FP32
or input_dtype == DType.FP8E4M3
@@ -2682,6 +2682,36 @@ class TosaErrorValidator:
):
error_result = True
+ elif op["op"] in {
+ Op.CONV2D,
+ Op.CONV3D,
+ Op.DEPTHWISE_CONV2D,
+ Op.TRANSPOSE_CONV2D,
+ }:
+ if input_dtype == DType.INT8 and accum_dtype != DType.INT32:
+ error_result = True
+ elif input_dtype == DType.INT16 and accum_dtype != DType.INT48:
+ error_result = True
+ elif (
+ input_dtype
+ in (
+ DType.FP32,
+ DType.BF16,
+ )
+ and accum_dtype != DType.FP32
+ ):
+ error_result = True
+ elif input_dtype == DType.FP16 and accum_dtype not in (
+ DType.FP16,
+ DType.FP32,
+ ):
+ error_result = True
+ elif (
+ input_dtype in (DType.FP8E4M3, DType.FP8E5M2)
+ and accum_dtype != DType.FP16
+ ):
+ error_result = True
+
info_dict = {
"error_name": error_name,
"error_result": error_result,