diff options
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r-- | verif/generator/tosa_arg_gen.py | 60 |
1 files changed, 55 insertions, 5 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index 7ec0cfe..d0b9eb9 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -641,6 +641,8 @@ class TosaTensorValuesGen: DType.FP32: (1 << 128) - (1 << (127 - 23)), DType.FP16: (1 << 16) - (1 << (15 - 10)), DType.BF16: (1 << 128) - (1 << (127 - 7)), + DType.FP8E4M3: 448, + DType.FP8E5M2: 57344, } # Default lowest normal values for random numbers @@ -648,6 +650,8 @@ class TosaTensorValuesGen: DType.FP32: np.exp2(-126), DType.FP16: np.exp2(-14), DType.BF16: np.exp2(-126), + DType.FP8E4M3: np.exp2(-9), + DType.FP8E5M2: np.exp2(-16), } @staticmethod @@ -715,6 +719,8 @@ class TosaTensorValuesGen: DType.FP16, DType.FP32, DType.BF16, + DType.FP8E4M3, + DType.FP8E5M2, ): # Change from inclusive to exclusive range data_range = (data_range[0], data_range[1] + 1) @@ -1734,7 +1740,13 @@ class TosaArgGen: and "data_gen" in testGen.TOSA_OP_LIST[opName] and gtu.dtypeIsSupportedByCompliance(dtype) ): - if dtype in [DType.FP16, DType.FP32, DType.BF16]: + if dtype in [ + DType.FP16, + DType.FP32, + DType.BF16, + DType.FP8E4M3, + DType.FP8E5M2, + ]: dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["fp"] else: dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["int"] @@ -2140,6 +2152,8 @@ class TosaArgGen: accum_dtypes = [DType.FP32] elif dtype == DType.FP32: accum_dtypes = [DType.FP32] + elif dtype == DType.FP8E4M3 or dtype == DType.FP8E5M2: + accum_dtypes = [DType.FP16] elif error_name is None: assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}" @@ -2350,7 +2364,13 @@ class TosaArgGen: if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]: pad_const_int = testGen.getRandNumberDType(dtype) pad_const_fp = 0 - elif dtype in (DType.FP16, DType.BF16, DType.FP32): + elif dtype in ( + DType.FP16, + DType.BF16, + DType.FP32, + DType.FP8E4M3, + DType.FP8E5M2, + ): pad_const_int = 0 pad_const_fp = testGen.getRandNumberDType(dtype) else: @@ -2468,6 +2488,8 @@ class TosaArgGen: accum_dtypes = [DType.FP16, DType.FP32] elif dtype == DType.BF16 or dtype == DType.FP32: accum_dtypes = [DType.FP32] + elif dtype == DType.FP8E4M3 or dtype == DType.FP8E5M2: + accum_dtypes = [DType.FP16] elif error_name is None: assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}" else: @@ -2646,11 +2668,35 @@ class TosaArgGen: elif inDtype == DType.BOOL: dtypeList = [DType.INT8, DType.INT16, DType.INT32] elif inDtype == DType.FP16: - dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32] + dtypeList = [ + DType.INT8, + DType.INT16, + DType.INT32, + DType.FP32, + DType.FP8E4M3, + DType.FP8E5M2, + ] elif inDtype == DType.BF16: - dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32] + dtypeList = [ + DType.INT8, + DType.INT16, + DType.INT32, + DType.FP32, + DType.FP8E4M3, + DType.FP8E5M2, + ] elif inDtype == DType.FP32: - dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.BF16] + dtypeList = [ + DType.INT8, + DType.INT16, + DType.INT32, + DType.FP16, + DType.BF16, + DType.FP8E4M3, + DType.FP8E5M2, + ] + elif inDtype in [DType.FP8E4M3, DType.FP8E5M2]: + dtypeList = [DType.FP16, DType.BF16, DType.FP32] elif error_name == ErrorIf.WrongInputType: # Pick some potentially correct output type for incorrect input type dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32] @@ -3232,6 +3278,10 @@ class TosaArgGen: outputDTypeList = [DType.BF16] elif dtype == DType.FP32: outputDTypeList = [DType.FP32] + elif dtype == DType.FP8E4M3: + outputDTypeList = [DType.FP8E4M3] + elif dtype == DType.FP8E5M2: + outputDTypeList = [DType.FP8E5M2] elif error_name == ErrorIf.WrongInputType: # If an incorrect input type is used then we set a 'correct' # output type to avoid other errors |