diff options
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r-- | verif/generator/tosa_arg_gen.py | 49 |
1 files changed, 39 insertions, 10 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index b1f8942..a741efb 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -1349,29 +1349,58 @@ class TosaArgGen: arg_list = [] # Enumerate the output types here - for dtype in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]: + for outDtype in [ + DType.UINT8, + DType.INT8, + DType.INT16, + DType.INT32, + DType.UINT16, + ]: if ( - dtype in [DType.UINT8, DType.INT8] + outDtype in [DType.UINT8, DType.INT8, DType.UINT16] and error_name == ErrorIf.OutputZeroPointNotZero ): continue if ( + outDtype != DType.UINT16 + and error_name == ErrorIf.U16OutputZeroPointNotValid + ) or ( + inDtype != DType.UINT16 + and error_name == ErrorIf.U16InputZeroPointNotValid + ): + # ErrorIfs only valid with UINT16 + continue + if ( inDtype == DType.UINT8 - and dtype != DType.INT8 + and outDtype not in [DType.INT8, DType.INT16] + and error_name != ErrorIf.WrongOutputType + ): + # The only output dtypes for UINT8 are INT8/INT16, skip all others + continue + if ( + inDtype not in [DType.INT8, DType.INT16] + and outDtype == DType.UINT8 + and error_name != ErrorIf.WrongOutputType + ): + # The only input dtypes for UINT8 are INT8/INT16, skip all others + continue + if ( + inDtype == DType.UINT16 + and outDtype != DType.INT16 and error_name != ErrorIf.WrongOutputType ): - # The only output dtype for UINT8 is INT8, skip all other combinations + # The only output dtype for UINT16 is INT16, skip all others continue if ( - inDtype != DType.INT8 - and dtype == DType.UINT8 + inDtype != DType.INT16 + and outDtype == DType.UINT16 and error_name != ErrorIf.WrongOutputType ): - # The only input dtype for UINT8 is INT8, skip all other combinations + # The only input dtype for UINT16 is INT16, skip all others continue if ( error_name == ErrorIf.WrongOutputType - and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, dtype) + and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, outDtype) ): continue @@ -1403,12 +1432,12 @@ class TosaArgGen: arg_list.append( ( "out{}_sc{}_dr{}_pc{}".format( - DTypeNames[dtype], + DTypeNames[outDtype], int(scale32), int(double_round), int(per_channel), ), - [dtype, scale32, double_round, per_channel], + [outDtype, scale32, double_round, per_channel], ) ) |