diff options
Diffstat (limited to 'verif/generator/tosa_error_if.py')
-rw-r--r-- | verif/generator/tosa_error_if.py | 74 |
1 files changed, 63 insertions, 11 deletions
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py index f9a00f9..a766803 100644 --- a/verif/generator/tosa_error_if.py +++ b/verif/generator/tosa_error_if.py @@ -120,6 +120,7 @@ class TosaErrorIfArgGen: DType.INT32, DType.INT48, DType.FLOAT, + DType.FP16, ) elif mode == ResizeMode.NEAREST and dtype == DType.INT16: incorrect_types = ( @@ -128,6 +129,7 @@ class TosaErrorIfArgGen: DType.INT32, DType.INT48, DType.FLOAT, + DType.FP16, ) elif mode == ResizeMode.BILINEAR and dtype == DType.INT8: incorrect_types = ( @@ -136,6 +138,7 @@ class TosaErrorIfArgGen: DType.INT16, DType.INT48, DType.FLOAT, + DType.FP16, ) elif mode == ResizeMode.BILINEAR and dtype == DType.INT16: incorrect_types = ( @@ -144,6 +147,16 @@ class TosaErrorIfArgGen: DType.INT16, DType.INT32, DType.FLOAT, + DType.FP16, + ) + elif dtype == DType.FP16: + incorrect_types = ( + DType.INT4, + DType.INT8, + DType.INT16, + DType.INT32, + DType.INT48, + DType.FLOAT, ) elif dtype == DType.FLOAT: incorrect_types = ( @@ -152,6 +165,7 @@ class TosaErrorIfArgGen: DType.INT16, DType.INT32, DType.INT48, + DType.FP16, ) outputDType = testGen.rng.choice(a=incorrect_types) @@ -285,8 +299,8 @@ class TosaErrorIfArgGen: @staticmethod def eiCastErrorIf(testGen, input_dtype): - if input_dtype in [DType.BOOL, DType.FLOAT]: - outputDType = [DType.BOOL, DType.INT48, DType.FLOAT] + if input_dtype in [DType.BOOL, DType.FP16, DType.FLOAT]: + outputDType = [DType.BOOL, DType.INT48, DType.FP16, DType.FLOAT] elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]: outputDType = [DType.INT48] else: @@ -400,6 +414,7 @@ class TosaErrorValidator: and input_dtype == DType.INT16 and output_dtype != DType.INT48 ) + or (input_dtype == DType.FP16 and output_dtype != DType.FP16) or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT) ): error_result = True @@ -413,19 +428,28 @@ class TosaErrorValidator: if ( (input_dtype == DType.INT8 and output_dtype != DType.INT32) or (input_dtype == DType.INT16 and output_dtype != DType.INT48) + or ( + input_dtype == DType.FP16 + and output_dtype not in (DType.FP16, DType.FLOAT) + ) or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT) ): error_result = True elif op["op"] == Op.ARGMAX: if ( - input_dtype in [DType.INT8, DType.INT16, DType.FLOAT] + input_dtype in [DType.INT8, DType.INT16, DType.FP16, DType.FLOAT] and output_dtype != DType.INT32 ): error_result = True elif op["op"] == Op.MUL: - if input_dtype != DType.FLOAT and output_dtype != DType.INT32: + if ( + input_dtype not in (DType.FP16, DType.FLOAT) + and output_dtype != DType.INT32 + ): + error_result = True + elif input_dtype == DType.FP16 and output_dtype != DType.FP16: error_result = True elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT: error_result = True @@ -449,17 +473,39 @@ class TosaErrorValidator: or ( input_dtype == DType.INT8 and output_dtype - not in [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT] + not in [ + DType.BOOL, + DType.INT16, + DType.INT32, + DType.FLOAT, + DType.FP16, + ] ) or ( input_dtype == DType.INT16 and output_dtype - not in [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT] + not in [ + DType.BOOL, + DType.INT8, + DType.INT32, + DType.FLOAT, + DType.FP16, + ] ) or ( input_dtype == DType.INT32 and output_dtype - not in [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT] + not in [ + DType.BOOL, + DType.INT8, + DType.INT16, + DType.FLOAT, + DType.FP16, + ] + ) + or ( + input_dtype == DType.FP16 + and output_dtype not in [DType.INT8, DType.INT16, DType.INT32] ) or ( input_dtype == DType.FLOAT @@ -479,6 +525,8 @@ class TosaErrorValidator: and output_dtype != DType.INT32 or input_dtype == DType.INT16 and output_dtype != DType.INT48 + or input_dtype == DType.FP16 + and output_dtype not in (DType.FP16, DType.FLOAT) or input_dtype == DType.FLOAT and output_dtype != DType.FLOAT ): @@ -2257,12 +2305,13 @@ class TosaInvalidValidator: return ( not (input_dtype == DType.INT8 and output_dtype == DType.INT32) and not (input_dtype == DType.INT16 and output_dtype == DType.INT48) + and not (input_dtype == DType.FP16 and output_dtype == DType.FP16) and not (input_dtype == DType.FLOAT and output_dtype == DType.FLOAT) ) elif mode == ResizeMode.NEAREST: # Invalid output data type / Invalid input datatype return (input_dtype != output_dtype) or ( - input_dtype not in [DType.INT8, DType.INT16, DType.FLOAT] + input_dtype not in [DType.INT8, DType.INT16, DType.FP16, DType.FLOAT] ) else: # Invalid resize mode @@ -2276,8 +2325,11 @@ class TosaInvalidValidator: input_shape = inputShapes[0] args = kwargs["args"] - strides = args[0] - padding = args[1] + + # MaxPool2D has no accum_dtype arg + stride_idx, pad_idx = (0, 1) if opName == "max_pool2d" else (1, 2) + strides = args[stride_idx] + padding = args[pad_idx] if opName.endswith("pool2d"): # avg_pool2d, max_pool2d @@ -2365,7 +2417,7 @@ class TosaInvalidValidator: @staticmethod def ivNonPositiveOutputShape(**kwargs): args = kwargs["args"] - output_shape = args[2] + output_shape = args[3] if output_shape[1] <= 0 or output_shape[2] <= 0: # Negative output shape return True |