aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_arg_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r--verif/generator/tosa_arg_gen.py49
1 files changed, 39 insertions, 10 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index b1f8942..a741efb 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -1349,29 +1349,58 @@ class TosaArgGen:
arg_list = []
# Enumerate the output types here
- for dtype in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
+ for outDtype in [
+ DType.UINT8,
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.UINT16,
+ ]:
if (
- dtype in [DType.UINT8, DType.INT8]
+ outDtype in [DType.UINT8, DType.INT8, DType.UINT16]
and error_name == ErrorIf.OutputZeroPointNotZero
):
continue
if (
+ outDtype != DType.UINT16
+ and error_name == ErrorIf.U16OutputZeroPointNotValid
+ ) or (
+ inDtype != DType.UINT16
+ and error_name == ErrorIf.U16InputZeroPointNotValid
+ ):
+ # ErrorIfs only valid with UINT16
+ continue
+ if (
inDtype == DType.UINT8
- and dtype != DType.INT8
+ and outDtype not in [DType.INT8, DType.INT16]
+ and error_name != ErrorIf.WrongOutputType
+ ):
+ # The only output dtypes for UINT8 are INT8/INT16, skip all others
+ continue
+ if (
+ inDtype not in [DType.INT8, DType.INT16]
+ and outDtype == DType.UINT8
+ and error_name != ErrorIf.WrongOutputType
+ ):
+ # The only input dtypes for UINT8 are INT8/INT16, skip all others
+ continue
+ if (
+ inDtype == DType.UINT16
+ and outDtype != DType.INT16
and error_name != ErrorIf.WrongOutputType
):
- # The only output dtype for UINT8 is INT8, skip all other combinations
+ # The only output dtype for UINT16 is INT16, skip all others
continue
if (
- inDtype != DType.INT8
- and dtype == DType.UINT8
+ inDtype != DType.INT16
+ and outDtype == DType.UINT16
and error_name != ErrorIf.WrongOutputType
):
- # The only input dtype for UINT8 is INT8, skip all other combinations
+ # The only input dtype for UINT16 is INT16, skip all others
continue
if (
error_name == ErrorIf.WrongOutputType
- and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, dtype)
+ and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, outDtype)
):
continue
@@ -1403,12 +1432,12 @@ class TosaArgGen:
arg_list.append(
(
"out{}_sc{}_dr{}_pc{}".format(
- DTypeNames[dtype],
+ DTypeNames[outDtype],
int(scale32),
int(double_round),
int(per_channel),
),
- [dtype, scale32, double_round, per_channel],
+ [outDtype, scale32, double_round, per_channel],
)
)