aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_error_if.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_error_if.py')
-rw-r--r--verif/generator/tosa_error_if.py35
1 files changed, 30 insertions, 5 deletions
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index abe1a97..a850699 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -158,6 +158,15 @@ class TosaErrorIfArgGen:
DType.INT48,
DType.FP32,
)
+ elif dtype == DType.BF16:
+ incorrect_types = (
+ DType.INT4,
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ DType.FP32,
+ )
elif dtype == DType.FP32:
incorrect_types = (
DType.INT4,
@@ -299,8 +308,8 @@ class TosaErrorIfArgGen:
@staticmethod
def eiCastErrorIf(testGen, input_dtype):
- if input_dtype in [DType.BOOL, DType.FP16, DType.FP32]:
- outputDType = [DType.BOOL, DType.INT48, DType.FP16, DType.FP32]
+ if input_dtype in [DType.BOOL, DType.FP16, DType.BF16, DType.FP32]:
+ outputDType = [DType.BOOL, DType.INT48, DType.FP16, DType.BF16, DType.FP32]
elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
outputDType = [DType.INT48]
else:
@@ -425,6 +434,7 @@ class TosaErrorValidator:
and output_dtype != DType.INT48
)
or (input_dtype == DType.FP16 and output_dtype != DType.FP16)
+ or (input_dtype == DType.BF16 and output_dtype != DType.BF16)
or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
):
error_result = True
@@ -442,25 +452,29 @@ class TosaErrorValidator:
input_dtype == DType.FP16
and output_dtype not in (DType.FP16, DType.FP32)
)
+ or (input_dtype == DType.BF16 and output_dtype != DType.FP32)
or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
):
error_result = True
elif op["op"] == Op.ARGMAX:
if (
- input_dtype in [DType.INT8, DType.INT16, DType.FP16, DType.FP32]
+ input_dtype
+ in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
and output_dtype != DType.INT32
):
error_result = True
elif op["op"] == Op.MUL:
if (
- input_dtype not in (DType.FP16, DType.FP32)
+ input_dtype not in (DType.FP16, DType.BF16, DType.FP32)
and output_dtype != DType.INT32
):
error_result = True
elif input_dtype == DType.FP16 and output_dtype != DType.FP16:
error_result = True
+ elif input_dtype == DType.BF16 and output_dtype != DType.BF16:
+ error_result = True
elif input_dtype == DType.FP32 and output_dtype != DType.FP32:
error_result = True
@@ -489,6 +503,7 @@ class TosaErrorValidator:
DType.INT32,
DType.FP32,
DType.FP16,
+ DType.BF16,
]
)
or (
@@ -500,6 +515,7 @@ class TosaErrorValidator:
DType.INT32,
DType.FP32,
DType.FP16,
+ DType.BF16,
]
)
or (
@@ -511,6 +527,7 @@ class TosaErrorValidator:
DType.INT16,
DType.FP32,
DType.FP16,
+ DType.BF16,
]
)
or (
@@ -518,6 +535,10 @@ class TosaErrorValidator:
and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
)
or (
+ input_dtype == DType.BF16
+ and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
+ )
+ or (
input_dtype == DType.FP32
and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
)
@@ -537,6 +558,8 @@ class TosaErrorValidator:
and output_dtype != DType.INT48
or input_dtype == DType.FP16
and output_dtype not in (DType.FP16, DType.FP32)
+ or input_dtype == DType.BF16
+ and output_dtype != DType.FP32
or input_dtype == DType.FP32
and output_dtype != DType.FP32
):
@@ -2316,12 +2339,14 @@ class TosaInvalidValidator:
not (input_dtype == DType.INT8 and output_dtype == DType.INT32)
and not (input_dtype == DType.INT16 and output_dtype == DType.INT48)
and not (input_dtype == DType.FP16 and output_dtype == DType.FP16)
+ and not (input_dtype == DType.BF16 and output_dtype == DType.BF16)
and not (input_dtype == DType.FP32 and output_dtype == DType.FP32)
)
elif mode == ResizeMode.NEAREST:
# Invalid output data type / Invalid input datatype
return (input_dtype != output_dtype) or (
- input_dtype not in [DType.INT8, DType.INT16, DType.FP16, DType.FP32]
+ input_dtype
+ not in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
)
else:
# Invalid resize mode