diff options
Diffstat (limited to 'verif/generator/tosa_error_if.py')
-rw-r--r-- | verif/generator/tosa_error_if.py | 35 |
1 files changed, 30 insertions, 5 deletions
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py index abe1a97..a850699 100644 --- a/verif/generator/tosa_error_if.py +++ b/verif/generator/tosa_error_if.py @@ -158,6 +158,15 @@ class TosaErrorIfArgGen: DType.INT48, DType.FP32, ) + elif dtype == DType.BF16: + incorrect_types = ( + DType.INT4, + DType.INT8, + DType.INT16, + DType.INT32, + DType.INT48, + DType.FP32, + ) elif dtype == DType.FP32: incorrect_types = ( DType.INT4, @@ -299,8 +308,8 @@ class TosaErrorIfArgGen: @staticmethod def eiCastErrorIf(testGen, input_dtype): - if input_dtype in [DType.BOOL, DType.FP16, DType.FP32]: - outputDType = [DType.BOOL, DType.INT48, DType.FP16, DType.FP32] + if input_dtype in [DType.BOOL, DType.FP16, DType.BF16, DType.FP32]: + outputDType = [DType.BOOL, DType.INT48, DType.FP16, DType.BF16, DType.FP32] elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]: outputDType = [DType.INT48] else: @@ -425,6 +434,7 @@ class TosaErrorValidator: and output_dtype != DType.INT48 ) or (input_dtype == DType.FP16 and output_dtype != DType.FP16) + or (input_dtype == DType.BF16 and output_dtype != DType.BF16) or (input_dtype == DType.FP32 and output_dtype != DType.FP32) ): error_result = True @@ -442,25 +452,29 @@ class TosaErrorValidator: input_dtype == DType.FP16 and output_dtype not in (DType.FP16, DType.FP32) ) + or (input_dtype == DType.BF16 and output_dtype != DType.FP32) or (input_dtype == DType.FP32 and output_dtype != DType.FP32) ): error_result = True elif op["op"] == Op.ARGMAX: if ( - input_dtype in [DType.INT8, DType.INT16, DType.FP16, DType.FP32] + input_dtype + in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32] and output_dtype != DType.INT32 ): error_result = True elif op["op"] == Op.MUL: if ( - input_dtype not in (DType.FP16, DType.FP32) + input_dtype not in (DType.FP16, DType.BF16, DType.FP32) and output_dtype != DType.INT32 ): error_result = True elif input_dtype == DType.FP16 and output_dtype != DType.FP16: error_result = True + elif input_dtype == DType.BF16 and output_dtype != DType.BF16: + error_result = True elif input_dtype == DType.FP32 and output_dtype != DType.FP32: error_result = True @@ -489,6 +503,7 @@ class TosaErrorValidator: DType.INT32, DType.FP32, DType.FP16, + DType.BF16, ] ) or ( @@ -500,6 +515,7 @@ class TosaErrorValidator: DType.INT32, DType.FP32, DType.FP16, + DType.BF16, ] ) or ( @@ -511,6 +527,7 @@ class TosaErrorValidator: DType.INT16, DType.FP32, DType.FP16, + DType.BF16, ] ) or ( @@ -518,6 +535,10 @@ class TosaErrorValidator: and output_dtype not in [DType.INT8, DType.INT16, DType.INT32] ) or ( + input_dtype == DType.BF16 + and output_dtype not in [DType.INT8, DType.INT16, DType.INT32] + ) + or ( input_dtype == DType.FP32 and output_dtype not in [DType.INT8, DType.INT16, DType.INT32] ) @@ -537,6 +558,8 @@ class TosaErrorValidator: and output_dtype != DType.INT48 or input_dtype == DType.FP16 and output_dtype not in (DType.FP16, DType.FP32) + or input_dtype == DType.BF16 + and output_dtype != DType.FP32 or input_dtype == DType.FP32 and output_dtype != DType.FP32 ): @@ -2316,12 +2339,14 @@ class TosaInvalidValidator: not (input_dtype == DType.INT8 and output_dtype == DType.INT32) and not (input_dtype == DType.INT16 and output_dtype == DType.INT48) and not (input_dtype == DType.FP16 and output_dtype == DType.FP16) + and not (input_dtype == DType.BF16 and output_dtype == DType.BF16) and not (input_dtype == DType.FP32 and output_dtype == DType.FP32) ) elif mode == ResizeMode.NEAREST: # Invalid output data type / Invalid input datatype return (input_dtype != output_dtype) or ( - input_dtype not in [DType.INT8, DType.INT16, DType.FP16, DType.FP32] + input_dtype + not in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32] ) else: # Invalid resize mode |