aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_error_if.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_error_if.py')
-rw-r--r--verif/generator/tosa_error_if.py56
1 files changed, 33 insertions, 23 deletions
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index a766803..abe1a97 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -119,7 +119,7 @@ class TosaErrorIfArgGen:
DType.INT16,
DType.INT32,
DType.INT48,
- DType.FLOAT,
+ DType.FP32,
DType.FP16,
)
elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
@@ -128,7 +128,7 @@ class TosaErrorIfArgGen:
DType.INT8,
DType.INT32,
DType.INT48,
- DType.FLOAT,
+ DType.FP32,
DType.FP16,
)
elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
@@ -137,7 +137,7 @@ class TosaErrorIfArgGen:
DType.INT8,
DType.INT16,
DType.INT48,
- DType.FLOAT,
+ DType.FP32,
DType.FP16,
)
elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
@@ -146,7 +146,7 @@ class TosaErrorIfArgGen:
DType.INT8,
DType.INT16,
DType.INT32,
- DType.FLOAT,
+ DType.FP32,
DType.FP16,
)
elif dtype == DType.FP16:
@@ -156,9 +156,9 @@ class TosaErrorIfArgGen:
DType.INT16,
DType.INT32,
DType.INT48,
- DType.FLOAT,
+ DType.FP32,
)
- elif dtype == DType.FLOAT:
+ elif dtype == DType.FP32:
incorrect_types = (
DType.INT4,
DType.INT8,
@@ -299,8 +299,8 @@ class TosaErrorIfArgGen:
@staticmethod
def eiCastErrorIf(testGen, input_dtype):
- if input_dtype in [DType.BOOL, DType.FP16, DType.FLOAT]:
- outputDType = [DType.BOOL, DType.INT48, DType.FP16, DType.FLOAT]
+ if input_dtype in [DType.BOOL, DType.FP16, DType.FP32]:
+ outputDType = [DType.BOOL, DType.INT48, DType.FP16, DType.FP32]
elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
outputDType = [DType.INT48]
else:
@@ -366,6 +366,16 @@ class TosaErrorValidator:
}
wrong_input_dtypes = list(usableDTypes(excludes=allowed_input_dtypes))
+ # Turn the wrong dtypes into required list of types
+ if op["op"] in [
+ Op.FULLY_CONNECTED,
+ Op.CONV2D,
+ Op.CONV3D,
+ Op.DEPTHWISE_CONV2D,
+ Op.TRANSPOSE_CONV2D,
+ ]:
+ wrong_input_dtypes = [[t, t, t] for t in wrong_input_dtypes]
+
if op["op"] == Op.CLAMP:
wrong_input_dtypes.remove(DType.INT48)
@@ -415,7 +425,7 @@ class TosaErrorValidator:
and output_dtype != DType.INT48
)
or (input_dtype == DType.FP16 and output_dtype != DType.FP16)
- or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
+ or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
):
error_result = True
@@ -430,28 +440,28 @@ class TosaErrorValidator:
or (input_dtype == DType.INT16 and output_dtype != DType.INT48)
or (
input_dtype == DType.FP16
- and output_dtype not in (DType.FP16, DType.FLOAT)
+ and output_dtype not in (DType.FP16, DType.FP32)
)
- or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
+ or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
):
error_result = True
elif op["op"] == Op.ARGMAX:
if (
- input_dtype in [DType.INT8, DType.INT16, DType.FP16, DType.FLOAT]
+ input_dtype in [DType.INT8, DType.INT16, DType.FP16, DType.FP32]
and output_dtype != DType.INT32
):
error_result = True
elif op["op"] == Op.MUL:
if (
- input_dtype not in (DType.FP16, DType.FLOAT)
+ input_dtype not in (DType.FP16, DType.FP32)
and output_dtype != DType.INT32
):
error_result = True
elif input_dtype == DType.FP16 and output_dtype != DType.FP16:
error_result = True
- elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT:
+ elif input_dtype == DType.FP32 and output_dtype != DType.FP32:
error_result = True
elif op["op"] == Op.TABLE:
@@ -477,7 +487,7 @@ class TosaErrorValidator:
DType.BOOL,
DType.INT16,
DType.INT32,
- DType.FLOAT,
+ DType.FP32,
DType.FP16,
]
)
@@ -488,7 +498,7 @@ class TosaErrorValidator:
DType.BOOL,
DType.INT8,
DType.INT32,
- DType.FLOAT,
+ DType.FP32,
DType.FP16,
]
)
@@ -499,7 +509,7 @@ class TosaErrorValidator:
DType.BOOL,
DType.INT8,
DType.INT16,
- DType.FLOAT,
+ DType.FP32,
DType.FP16,
]
)
@@ -508,7 +518,7 @@ class TosaErrorValidator:
and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
)
or (
- input_dtype == DType.FLOAT
+ input_dtype == DType.FP32
and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
)
):
@@ -526,9 +536,9 @@ class TosaErrorValidator:
or input_dtype == DType.INT16
and output_dtype != DType.INT48
or input_dtype == DType.FP16
- and output_dtype not in (DType.FP16, DType.FLOAT)
- or input_dtype == DType.FLOAT
- and output_dtype != DType.FLOAT
+ and output_dtype not in (DType.FP16, DType.FP32)
+ or input_dtype == DType.FP32
+ and output_dtype != DType.FP32
):
error_result = True
# invalid input types are ignored, to avoid reporting multiple errors
@@ -2306,12 +2316,12 @@ class TosaInvalidValidator:
not (input_dtype == DType.INT8 and output_dtype == DType.INT32)
and not (input_dtype == DType.INT16 and output_dtype == DType.INT48)
and not (input_dtype == DType.FP16 and output_dtype == DType.FP16)
- and not (input_dtype == DType.FLOAT and output_dtype == DType.FLOAT)
+ and not (input_dtype == DType.FP32 and output_dtype == DType.FP32)
)
elif mode == ResizeMode.NEAREST:
# Invalid output data type / Invalid input datatype
return (input_dtype != output_dtype) or (
- input_dtype not in [DType.INT8, DType.INT16, DType.FP16, DType.FLOAT]
+ input_dtype not in [DType.INT8, DType.INT16, DType.FP16, DType.FP32]
)
else:
# Invalid resize mode