aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_error_if.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_error_if.py')
-rw-r--r--verif/generator/tosa_error_if.py74
1 files changed, 63 insertions, 11 deletions
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index f9a00f9..a766803 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -120,6 +120,7 @@ class TosaErrorIfArgGen:
DType.INT32,
DType.INT48,
DType.FLOAT,
+ DType.FP16,
)
elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
incorrect_types = (
@@ -128,6 +129,7 @@ class TosaErrorIfArgGen:
DType.INT32,
DType.INT48,
DType.FLOAT,
+ DType.FP16,
)
elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
incorrect_types = (
@@ -136,6 +138,7 @@ class TosaErrorIfArgGen:
DType.INT16,
DType.INT48,
DType.FLOAT,
+ DType.FP16,
)
elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
incorrect_types = (
@@ -144,6 +147,16 @@ class TosaErrorIfArgGen:
DType.INT16,
DType.INT32,
DType.FLOAT,
+ DType.FP16,
+ )
+ elif dtype == DType.FP16:
+ incorrect_types = (
+ DType.INT4,
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ DType.FLOAT,
)
elif dtype == DType.FLOAT:
incorrect_types = (
@@ -152,6 +165,7 @@ class TosaErrorIfArgGen:
DType.INT16,
DType.INT32,
DType.INT48,
+ DType.FP16,
)
outputDType = testGen.rng.choice(a=incorrect_types)
@@ -285,8 +299,8 @@ class TosaErrorIfArgGen:
@staticmethod
def eiCastErrorIf(testGen, input_dtype):
- if input_dtype in [DType.BOOL, DType.FLOAT]:
- outputDType = [DType.BOOL, DType.INT48, DType.FLOAT]
+ if input_dtype in [DType.BOOL, DType.FP16, DType.FLOAT]:
+ outputDType = [DType.BOOL, DType.INT48, DType.FP16, DType.FLOAT]
elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
outputDType = [DType.INT48]
else:
@@ -400,6 +414,7 @@ class TosaErrorValidator:
and input_dtype == DType.INT16
and output_dtype != DType.INT48
)
+ or (input_dtype == DType.FP16 and output_dtype != DType.FP16)
or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
):
error_result = True
@@ -413,19 +428,28 @@ class TosaErrorValidator:
if (
(input_dtype == DType.INT8 and output_dtype != DType.INT32)
or (input_dtype == DType.INT16 and output_dtype != DType.INT48)
+ or (
+ input_dtype == DType.FP16
+ and output_dtype not in (DType.FP16, DType.FLOAT)
+ )
or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
):
error_result = True
elif op["op"] == Op.ARGMAX:
if (
- input_dtype in [DType.INT8, DType.INT16, DType.FLOAT]
+ input_dtype in [DType.INT8, DType.INT16, DType.FP16, DType.FLOAT]
and output_dtype != DType.INT32
):
error_result = True
elif op["op"] == Op.MUL:
- if input_dtype != DType.FLOAT and output_dtype != DType.INT32:
+ if (
+ input_dtype not in (DType.FP16, DType.FLOAT)
+ and output_dtype != DType.INT32
+ ):
+ error_result = True
+ elif input_dtype == DType.FP16 and output_dtype != DType.FP16:
error_result = True
elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT:
error_result = True
@@ -449,17 +473,39 @@ class TosaErrorValidator:
or (
input_dtype == DType.INT8
and output_dtype
- not in [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
+ not in [
+ DType.BOOL,
+ DType.INT16,
+ DType.INT32,
+ DType.FLOAT,
+ DType.FP16,
+ ]
)
or (
input_dtype == DType.INT16
and output_dtype
- not in [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
+ not in [
+ DType.BOOL,
+ DType.INT8,
+ DType.INT32,
+ DType.FLOAT,
+ DType.FP16,
+ ]
)
or (
input_dtype == DType.INT32
and output_dtype
- not in [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
+ not in [
+ DType.BOOL,
+ DType.INT8,
+ DType.INT16,
+ DType.FLOAT,
+ DType.FP16,
+ ]
+ )
+ or (
+ input_dtype == DType.FP16
+ and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
)
or (
input_dtype == DType.FLOAT
@@ -479,6 +525,8 @@ class TosaErrorValidator:
and output_dtype != DType.INT32
or input_dtype == DType.INT16
and output_dtype != DType.INT48
+ or input_dtype == DType.FP16
+ and output_dtype not in (DType.FP16, DType.FLOAT)
or input_dtype == DType.FLOAT
and output_dtype != DType.FLOAT
):
@@ -2257,12 +2305,13 @@ class TosaInvalidValidator:
return (
not (input_dtype == DType.INT8 and output_dtype == DType.INT32)
and not (input_dtype == DType.INT16 and output_dtype == DType.INT48)
+ and not (input_dtype == DType.FP16 and output_dtype == DType.FP16)
and not (input_dtype == DType.FLOAT and output_dtype == DType.FLOAT)
)
elif mode == ResizeMode.NEAREST:
# Invalid output data type / Invalid input datatype
return (input_dtype != output_dtype) or (
- input_dtype not in [DType.INT8, DType.INT16, DType.FLOAT]
+ input_dtype not in [DType.INT8, DType.INT16, DType.FP16, DType.FLOAT]
)
else:
# Invalid resize mode
@@ -2276,8 +2325,11 @@ class TosaInvalidValidator:
input_shape = inputShapes[0]
args = kwargs["args"]
- strides = args[0]
- padding = args[1]
+
+ # MaxPool2D has no accum_dtype arg
+ stride_idx, pad_idx = (0, 1) if opName == "max_pool2d" else (1, 2)
+ strides = args[stride_idx]
+ padding = args[pad_idx]
if opName.endswith("pool2d"):
# avg_pool2d, max_pool2d
@@ -2365,7 +2417,7 @@ class TosaInvalidValidator:
@staticmethod
def ivNonPositiveOutputShape(**kwargs):
args = kwargs["args"]
- output_shape = args[2]
+ output_shape = args[3]
if output_shape[1] <= 0 or output_shape[2] <= 0:
# Negative output shape
return True