aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_error_if.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_error_if.py')
-rw-r--r--verif/generator/tosa_error_if.py72
1 files changed, 68 insertions, 4 deletions
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index 9a88acb..7a4d0d6 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -325,12 +325,32 @@ class TosaErrorIfArgGen:
@staticmethod
def eiCastErrorIf(testGen, input_dtype):
- if input_dtype in [DType.BOOL, DType.FP32]:
+ # if input_dtype in [DType.BOOL, DType.FP32]:
+ # outputDType = [DType.BOOL, DType.INT48, DType.FP32]
+ if input_dtype in [DType.BOOL]:
+ outputDType = [
+ DType.BOOL,
+ DType.INT48,
+ DType.FP32,
+ DType.FP16,
+ DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ]
+ elif input_dtype in [DType.FP32]:
outputDType = [DType.BOOL, DType.INT48, DType.FP32]
elif input_dtype in [DType.FP16, DType.BF16]:
outputDType = [DType.BOOL, DType.INT48]
elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
outputDType = [DType.INT48]
+ elif input_dtype in [DType.FP8E4M3, DType.FP8E5M2]:
+ outputDType = [
+ DType.BOOL,
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ ]
else:
assert False, f"input_dtype ({input_dtype}) not supported"
return outputDType
@@ -476,13 +496,23 @@ class TosaErrorValidator:
)
or (input_dtype == DType.BF16 and output_dtype != DType.FP32)
or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
+ or (input_dtype == DType.FP8E4M3 and output_dtype != DType.FP16)
+ or (input_dtype == DType.FP8E5M2 and output_dtype != DType.FP16)
):
error_result = True
elif op["op"] == Op.ARGMAX:
if (
input_dtype
- in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
+ in [
+ DType.INT8,
+ DType.INT16,
+ DType.FP16,
+ DType.BF16,
+ DType.FP32,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ]
and output_dtype != DType.INT32
):
error_result = True
@@ -555,12 +585,26 @@ class TosaErrorValidator:
or (
input_dtype == DType.FP16
and output_dtype
- not in [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
+ not in [
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.FP32,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ]
)
or (
input_dtype == DType.BF16
and output_dtype
- not in [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
+ not in [
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.FP32,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ]
)
or (
input_dtype == DType.FP32
@@ -571,6 +615,17 @@ class TosaErrorValidator:
DType.INT32,
DType.FP16,
DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ]
+ )
+ or (
+ input_dtype in [DType.FP8E4M3, DType.FP8E5M2]
+ and output_dtype
+ not in [
+ DType.FP16,
+ DType.BF16,
+ DType.FP32,
]
)
):
@@ -597,6 +652,10 @@ class TosaErrorValidator:
and output_dtype != DType.FP32
or input_dtype == DType.FP32
and output_dtype != DType.FP32
+ or input_dtype == DType.FP8E4M3
+ and output_dtype != DType.FP16
+ or input_dtype == DType.FP8E5M2
+ and output_dtype != DType.FP16
):
error_result = True
# invalid input types are ignored, to avoid reporting multiple errors
@@ -2615,6 +2674,11 @@ class TosaErrorValidator:
DType.FP32,
):
error_result = True
+ elif (
+ input_dtype in (DType.FP8E4M3, DType.FP8E5M2)
+ and accum_dtype != DType.FP16
+ ):
+ error_result = True
info_dict = {
"error_name": error_name,