aboutsummaryrefslogtreecommitdiff
path: root/verif/generator
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator')
-rw-r--r--verif/generator/tosa_arg_gen.py33
-rw-r--r--verif/generator/tosa_error_if.py23
2 files changed, 44 insertions, 12 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index fed91f6..05a7d2b 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -1445,19 +1445,40 @@ class TosaArgGen:
if error_name == ErrorIf.WrongOutputType:
dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype)
elif inDtype == DType.INT8:
- dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FP32]
+ dtypeList = [
+ DType.BOOL,
+ DType.INT16,
+ DType.INT32,
+ DType.FP16,
+ DType.BF16,
+ DType.FP32,
+ ]
elif inDtype == DType.INT16:
- dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FP32]
+ dtypeList = [
+ DType.BOOL,
+ DType.INT8,
+ DType.INT32,
+ DType.FP16,
+ DType.BF16,
+ DType.FP32,
+ ]
elif inDtype == DType.INT32:
- dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
+ dtypeList = [
+ DType.BOOL,
+ DType.INT8,
+ DType.INT16,
+ DType.FP16,
+ DType.BF16,
+ DType.FP32,
+ ]
elif inDtype == DType.BOOL:
dtypeList = [DType.INT8, DType.INT16, DType.INT32]
elif inDtype == DType.FP16:
- dtypeList = [DType.INT8, DType.INT16, DType.INT32]
+ dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
elif inDtype == DType.BF16:
- dtypeList = [DType.INT8, DType.INT16, DType.INT32]
+ dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
elif inDtype == DType.FP32:
- dtypeList = [DType.INT8, DType.INT16, DType.INT32]
+ dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.BF16]
elif error_name == ErrorIf.WrongInputType:
# Pick some potentially correct output type for incorrect input type
dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index 40c5d13..93f975d 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -314,12 +314,14 @@ class TosaErrorIfArgGen:
@staticmethod
def eiCastErrorIf(testGen, input_dtype):
- if input_dtype in [DType.BOOL, DType.FP16, DType.BF16, DType.FP32]:
- outputDType = [DType.BOOL, DType.INT48, DType.FP16, DType.BF16, DType.FP32]
+ if input_dtype in [DType.BOOL, DType.FP32]:
+ outputDType = [DType.BOOL, DType.INT48, DType.FP32]
+ elif input_dtype in [DType.FP16, DType.BF16]:
+ outputDType = [DType.BOOL, DType.INT48]
elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
outputDType = [DType.INT48]
else:
- assert True, f"input_dtype ({input_dtype}) not supported"
+ assert False, f"input_dtype ({input_dtype}) not supported"
return outputDType
@@ -538,15 +540,24 @@ class TosaErrorValidator:
)
or (
input_dtype == DType.FP16
- and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
+ and output_dtype
+ not in [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
)
or (
input_dtype == DType.BF16
- and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
+ and output_dtype
+ not in [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
)
or (
input_dtype == DType.FP32
- and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
+ and output_dtype
+ not in [
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.FP16,
+ DType.BF16,
+ ]
)
):
error_result = True