diff options
Diffstat (limited to 'verif/generator/tosa_error_if.py')
-rw-r--r-- | verif/generator/tosa_error_if.py | 72 |
1 files changed, 68 insertions, 4 deletions
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py index 9a88acb..7a4d0d6 100644 --- a/verif/generator/tosa_error_if.py +++ b/verif/generator/tosa_error_if.py @@ -325,12 +325,32 @@ class TosaErrorIfArgGen: @staticmethod def eiCastErrorIf(testGen, input_dtype): - if input_dtype in [DType.BOOL, DType.FP32]: + # if input_dtype in [DType.BOOL, DType.FP32]: + # outputDType = [DType.BOOL, DType.INT48, DType.FP32] + if input_dtype in [DType.BOOL]: + outputDType = [ + DType.BOOL, + DType.INT48, + DType.FP32, + DType.FP16, + DType.BF16, + DType.FP8E4M3, + DType.FP8E5M2, + ] + elif input_dtype in [DType.FP32]: outputDType = [DType.BOOL, DType.INT48, DType.FP32] elif input_dtype in [DType.FP16, DType.BF16]: outputDType = [DType.BOOL, DType.INT48] elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]: outputDType = [DType.INT48] + elif input_dtype in [DType.FP8E4M3, DType.FP8E5M2]: + outputDType = [ + DType.BOOL, + DType.INT8, + DType.INT16, + DType.INT32, + DType.INT48, + ] else: assert False, f"input_dtype ({input_dtype}) not supported" return outputDType @@ -476,13 +496,23 @@ class TosaErrorValidator: ) or (input_dtype == DType.BF16 and output_dtype != DType.FP32) or (input_dtype == DType.FP32 and output_dtype != DType.FP32) + or (input_dtype == DType.FP8E4M3 and output_dtype != DType.FP16) + or (input_dtype == DType.FP8E5M2 and output_dtype != DType.FP16) ): error_result = True elif op["op"] == Op.ARGMAX: if ( input_dtype - in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32] + in [ + DType.INT8, + DType.INT16, + DType.FP16, + DType.BF16, + DType.FP32, + DType.FP8E4M3, + DType.FP8E5M2, + ] and output_dtype != DType.INT32 ): error_result = True @@ -555,12 +585,26 @@ class TosaErrorValidator: or ( input_dtype == DType.FP16 and output_dtype - not in [DType.INT8, DType.INT16, DType.INT32, DType.FP32] + not in [ + DType.INT8, + DType.INT16, + DType.INT32, + DType.FP32, + DType.FP8E4M3, + DType.FP8E5M2, + ] ) or ( input_dtype == DType.BF16 and output_dtype - not in [DType.INT8, DType.INT16, DType.INT32, DType.FP32] + not in [ + DType.INT8, + DType.INT16, + DType.INT32, + DType.FP32, + DType.FP8E4M3, + DType.FP8E5M2, + ] ) or ( input_dtype == DType.FP32 @@ -571,6 +615,17 @@ class TosaErrorValidator: DType.INT32, DType.FP16, DType.BF16, + DType.FP8E4M3, + DType.FP8E5M2, + ] + ) + or ( + input_dtype in [DType.FP8E4M3, DType.FP8E5M2] + and output_dtype + not in [ + DType.FP16, + DType.BF16, + DType.FP32, ] ) ): @@ -597,6 +652,10 @@ class TosaErrorValidator: and output_dtype != DType.FP32 or input_dtype == DType.FP32 and output_dtype != DType.FP32 + or input_dtype == DType.FP8E4M3 + and output_dtype != DType.FP16 + or input_dtype == DType.FP8E5M2 + and output_dtype != DType.FP16 ): error_result = True # invalid input types are ignored, to avoid reporting multiple errors @@ -2615,6 +2674,11 @@ class TosaErrorValidator: DType.FP32, ): error_result = True + elif ( + input_dtype in (DType.FP8E4M3, DType.FP8E5M2) + and accum_dtype != DType.FP16 + ): + error_result = True info_dict = { "error_name": error_name, |