Diffstat (limited to 'verif/generator')
-rw-r--r--  verif/generator/tosa_arg_gen.py   | 60
-rw-r--r--  verif/generator/tosa_error_if.py  | 72
-rw-r--r--  verif/generator/tosa_test_gen.py  | 95
-rw-r--r--  verif/generator/tosa_utils.py     | 42
4 files changed, 241 insertions, 28 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 7ec0cfe..d0b9eb9 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -641,6 +641,8 @@ class TosaTensorValuesGen:
         DType.FP32: (1 << 128) - (1 << (127 - 23)),
         DType.FP16: (1 << 16) - (1 << (15 - 10)),
         DType.BF16: (1 << 128) - (1 << (127 - 7)),
+        DType.FP8E4M3: 448,
+        DType.FP8E5M2: 57344,
     }

     # Default lowest normal values for random numbers
@@ -648,6 +650,8 @@ class TosaTensorValuesGen:
         DType.FP32: np.exp2(-126),
         DType.FP16: np.exp2(-14),
         DType.BF16: np.exp2(-126),
+        DType.FP8E4M3: np.exp2(-9),
+        DType.FP8E5M2: np.exp2(-16),
     }

     @staticmethod
@@ -715,6 +719,8 @@ class TosaTensorValuesGen:
                 DType.FP16,
                 DType.FP32,
                 DType.BF16,
+                DType.FP8E4M3,
+                DType.FP8E5M2,
             ):
                 # Change from inclusive to exclusive range
                 data_range = (data_range[0], data_range[1] + 1)
@@ -1734,7 +1740,13 @@ class TosaArgGen:
             and "data_gen" in testGen.TOSA_OP_LIST[opName]
             and gtu.dtypeIsSupportedByCompliance(dtype)
         ):
-            if dtype in [DType.FP16, DType.FP32, DType.BF16]:
+            if dtype in [
+                DType.FP16,
+                DType.FP32,
+                DType.BF16,
+                DType.FP8E4M3,
+                DType.FP8E5M2,
+            ]:
                 dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["fp"]
             else:
                 dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["int"]
@@ -2140,6 +2152,8 @@ class TosaArgGen:
             accum_dtypes = [DType.FP32]
         elif dtype == DType.FP32:
             accum_dtypes = [DType.FP32]
+        elif dtype == DType.FP8E4M3 or dtype == DType.FP8E5M2:
+            accum_dtypes = [DType.FP16]
         elif error_name is None:
             assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"
@@ -2350,7 +2364,13 @@ class TosaArgGen:
         if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
             pad_const_int = testGen.getRandNumberDType(dtype)
             pad_const_fp = 0
-        elif dtype in (DType.FP16, DType.BF16, DType.FP32):
+        elif dtype in (
+            DType.FP16,
+            DType.BF16,
+            DType.FP32,
+            DType.FP8E4M3,
+            DType.FP8E5M2,
+        ):
             pad_const_int = 0
             pad_const_fp = testGen.getRandNumberDType(dtype)
         else:
@@ -2468,6 +2488,8 @@ class TosaArgGen:
             accum_dtypes = [DType.FP16, DType.FP32]
         elif dtype == DType.BF16 or dtype == DType.FP32:
             accum_dtypes = [DType.FP32]
+        elif dtype == DType.FP8E4M3 or dtype == DType.FP8E5M2:
+            accum_dtypes = [DType.FP16]
         elif error_name is None:
             assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
         else:
@@ -2646,11 +2668,35 @@ class TosaArgGen:
         elif inDtype == DType.BOOL:
             dtypeList = [DType.INT8, DType.INT16, DType.INT32]
         elif inDtype == DType.FP16:
-            dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
+            dtypeList = [
+                DType.INT8,
+                DType.INT16,
+                DType.INT32,
+                DType.FP32,
+                DType.FP8E4M3,
+                DType.FP8E5M2,
+            ]
         elif inDtype == DType.BF16:
-            dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
+            dtypeList = [
+                DType.INT8,
+                DType.INT16,
+                DType.INT32,
+                DType.FP32,
+                DType.FP8E4M3,
+                DType.FP8E5M2,
+            ]
         elif inDtype == DType.FP32:
-            dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.BF16]
+            dtypeList = [
+                DType.INT8,
+                DType.INT16,
+                DType.INT32,
+                DType.FP16,
+                DType.BF16,
+                DType.FP8E4M3,
+                DType.FP8E5M2,
+            ]
+        elif inDtype in [DType.FP8E4M3, DType.FP8E5M2]:
+            dtypeList = [DType.FP16, DType.BF16, DType.FP32]
         elif error_name == ErrorIf.WrongInputType:
             # Pick some potentially correct output type for incorrect input type
             dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
@@ -3232,6 +3278,10 @@ class TosaArgGen:
             outputDTypeList = [DType.BF16]
         elif dtype == DType.FP32:
             outputDTypeList = [DType.FP32]
+        elif dtype == DType.FP8E4M3:
+            outputDTypeList = [DType.FP8E4M3]
+        elif dtype == DType.FP8E5M2:
+            outputDTypeList = [DType.FP8E5M2]
         elif error_name == ErrorIf.WrongInputType:
             # If an incorrect input type is used then we set a 'correct'
             # output type to avoid other errors
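The FP8 constants fed into TVG_FLOAT_HIGH_VALUE and TVG_FLOAT_LOW_VALUE above can be sanity-checked from the format definitions. A minimal cross-check (a sketch, not part of the patch), assuming the usual FP8 parameters: E4M3 has 4 exponent bits (bias 7) and 3 mantissa bits, E5M2 has 5 exponent bits (bias 15) and 2 mantissa bits. Note that the "lowest normal" entries are actually the smallest positive subnormals of each format; the true minimum normals would be 2^-6 and 2^-14.

    import numpy as np

    # FP8E4M3: the all-ones exponent with all-ones mantissa is NaN, so the
    # largest finite value uses mantissa 110 -> 1.75 * 2^(15 - 7) = 448
    assert (1 + 6 / 8) * 2.0**8 == 448
    # FP8E5M2: the all-ones exponent encodes inf/NaN, so the largest
    # finite value is 1.75 * 2^(30 - 15) = 57344
    assert (1 + 3 / 4) * 2.0**15 == 57344

    # Smallest positive subnormals (what TVG_FLOAT_LOW_VALUE stores here):
    # E4M3: 2^(1-7) * 2^-3 = 2^-9,  E5M2: 2^(1-15) * 2^-2 = 2^-16
    assert np.exp2(-6) * np.exp2(-3) == np.exp2(-9)
    assert np.exp2(-14) * np.exp2(-2) == np.exp2(-16)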
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index 9a88acb..7a4d0d6 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -325,12 +325,32 @@ class TosaErrorIfArgGen:

     @staticmethod
     def eiCastErrorIf(testGen, input_dtype):
-        if input_dtype in [DType.BOOL, DType.FP32]:
+        # if input_dtype in [DType.BOOL, DType.FP32]:
+        #     outputDType = [DType.BOOL, DType.INT48, DType.FP32]
+        if input_dtype in [DType.BOOL]:
+            outputDType = [
+                DType.BOOL,
+                DType.INT48,
+                DType.FP32,
+                DType.FP16,
+                DType.BF16,
+                DType.FP8E4M3,
+                DType.FP8E5M2,
+            ]
+        elif input_dtype in [DType.FP32]:
             outputDType = [DType.BOOL, DType.INT48, DType.FP32]
         elif input_dtype in [DType.FP16, DType.BF16]:
             outputDType = [DType.BOOL, DType.INT48]
         elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
             outputDType = [DType.INT48]
+        elif input_dtype in [DType.FP8E4M3, DType.FP8E5M2]:
+            outputDType = [
+                DType.BOOL,
+                DType.INT8,
+                DType.INT16,
+                DType.INT32,
+                DType.INT48,
+            ]
         else:
             assert False, f"input_dtype ({input_dtype}) not supported"
         return outputDType
@@ -476,13 +496,23 @@ class TosaErrorValidator:
                 )
                 or (input_dtype == DType.BF16 and output_dtype != DType.FP32)
                 or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
+                or (input_dtype == DType.FP8E4M3 and output_dtype != DType.FP16)
+                or (input_dtype == DType.FP8E5M2 and output_dtype != DType.FP16)
             ):
                 error_result = True

         elif op["op"] == Op.ARGMAX:
             if (
                 input_dtype
-                in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
+                in [
+                    DType.INT8,
+                    DType.INT16,
+                    DType.FP16,
+                    DType.BF16,
+                    DType.FP32,
+                    DType.FP8E4M3,
+                    DType.FP8E5M2,
+                ]
                 and output_dtype != DType.INT32
             ):
                 error_result = True
@@ -555,12 +585,26 @@ class TosaErrorValidator:
                 or (
                     input_dtype == DType.FP16
                     and output_dtype
-                    not in [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
+                    not in [
+                        DType.INT8,
+                        DType.INT16,
+                        DType.INT32,
+                        DType.FP32,
+                        DType.FP8E4M3,
+                        DType.FP8E5M2,
+                    ]
                 )
                 or (
                     input_dtype == DType.BF16
                     and output_dtype
-                    not in [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
+                    not in [
+                        DType.INT8,
+                        DType.INT16,
+                        DType.INT32,
+                        DType.FP32,
+                        DType.FP8E4M3,
+                        DType.FP8E5M2,
+                    ]
                 )
                 or (
                     input_dtype == DType.FP32
@@ -571,6 +615,17 @@ class TosaErrorValidator:
                         DType.INT32,
                         DType.FP16,
                         DType.BF16,
+                        DType.FP8E4M3,
+                        DType.FP8E5M2,
+                    ]
+                )
+                or (
+                    input_dtype in [DType.FP8E4M3, DType.FP8E5M2]
+                    and output_dtype
+                    not in [
+                        DType.FP16,
+                        DType.BF16,
+                        DType.FP32,
                     ]
                 )
             ):
                 error_result = True
@@ -597,6 +652,10 @@ class TosaErrorValidator:
                 and output_dtype != DType.FP32
                 or input_dtype == DType.FP32
                 and output_dtype != DType.FP32
+                or input_dtype == DType.FP8E4M3
+                and output_dtype != DType.FP16
+                or input_dtype == DType.FP8E5M2
+                and output_dtype != DType.FP16
             ):
                 error_result = True
             # invalid input types are ignored, to avoid reporting multiple errors
@@ -2615,6 +2674,11 @@ class TosaErrorValidator:
                 DType.FP32,
             ):
                 error_result = True
+            elif (
+                input_dtype in (DType.FP8E4M3, DType.FP8E5M2)
+                and accum_dtype != DType.FP16
+            ):
+                error_result = True

         info_dict = {
             "error_name": error_name,
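Read together with the tosa_arg_gen.py CAST dtype lists, the eiCastErrorIf and evWrongOutputType changes encode one cast table. A condensed sketch of the floating-point rows of that table (the dict and helper are illustrative, not code from this change):

    # Valid CAST destinations per input dtype, as the validator above now
    # enforces (dtype names as in verif/generator).
    VALID_CAST_DSTS = {
        "FP8E4M3": {"FP16", "BF16", "FP32"},
        "FP8E5M2": {"FP16", "BF16", "FP32"},
        "FP16": {"INT8", "INT16", "INT32", "FP32", "FP8E4M3", "FP8E5M2"},
        "BF16": {"INT8", "INT16", "INT32", "FP32", "FP8E4M3", "FP8E5M2"},
        "FP32": {"INT8", "INT16", "INT32", "FP16", "BF16", "FP8E4M3", "FP8E5M2"},
    }

    def cast_output_is_wrong(src, dst):
        # Mirrors evWrongOutputType for CAST: any destination outside the
        # valid set should trigger the ERROR_IF path.
        return dst not in VALID_CAST_DSTS.get(src, set())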
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 4ead982..bc931dc 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -76,7 +76,7 @@ class TosaTestGen:
             return tuple(sorted(vals))

         self.random_float_range = {}
-        for dtype in (DType.FP32, DType.FP16, DType.BF16):
+        for dtype in (DType.FP32, DType.FP16, DType.BF16, DType.FP8E4M3, DType.FP8E5M2):
             self.random_float_range[dtype] = convertFPRange(
                 args.tensor_fp_value_range,
                 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
@@ -152,7 +152,7 @@ class TosaTestGen:
         # Returns dtype value range boundaries (low, high)
         # The high boundary is excluded in the range
         # unless high_inclusive is True
-        if dtype in (DType.FP32, DType.FP16, DType.BF16):
+        if dtype in (DType.FP32, DType.FP16, DType.BF16, DType.FP8E4M3, DType.FP8E5M2):
             return self.random_float_range[dtype]
         elif dtype == DType.BOOL:
             rng = (0, 2)
@@ -197,7 +197,13 @@ class TosaTestGen:
             return np.uint8(self.rng.integers(low=low, high=high, size=shape))
         elif dtype in (DType.INT48, DType.SHAPE):
             return np.int64(self.rng.integers(low=low, high=high, size=shape))
-        elif dtype in (DType.FP16, DType.BF16, DType.FP32):
+        elif dtype in (
+            DType.FP16,
+            DType.BF16,
+            DType.FP32,
+            DType.FP8E4M3,
+            DType.FP8E5M2,
+        ):
             f_tensor = self.rng.uniform(low=low, high=high, size=shape)

             if dtype == DType.FP16:
@@ -207,6 +213,10 @@ class TosaTestGen:
                 if dtype == DType.BF16:
                     # Floor the last 16 bits of each f32 value
                     return np.float32(gtu.vect_f32_to_bf16(f32_tensor))
+                elif dtype == DType.FP8E4M3:
+                    return np.float32(gtu.vect_f32_to_fp8e4m3(f32_tensor))
+                elif dtype == DType.FP8E5M2:
+                    return np.float32(gtu.vect_f32_to_fp8e5m2(f32_tensor))
                 else:
                     return f32_tensor
         else:
@@ -266,6 +276,12 @@ class TosaTestGen:
         elif dtype == DType.BF16:
             rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
             return gtu.vect_f32_to_bf16(rand_f32)
+        elif dtype == DType.FP8E4M3:
+            rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
+            return gtu.vect_f32_to_fp8e4m3(rand_f32)
+        elif dtype == DType.FP8E5M2:
+            rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
+            return gtu.vect_f32_to_fp8e5m2(rand_f32)
         elif dtype == DType.BOOL:
             return self.rng.choice([False, True])
         elif dtype == DType.INT48 or dtype == DType.SHAPE:
@@ -1408,8 +1424,11 @@ class TosaTestGen:
                 max_val = max_val.astype(np.float32)

             attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)
-        else:
+        elif a.dtype in (DType.INT8, DType.INT16):
             attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)
+        else:
+            # to avoid internal error for incorrect input types
+            attr.ClampAttribute(self.ser.builder, 0, 0, 0, 0)

         self.ser.addOperator(op["op"], input_list, output_list, attr)
@@ -3190,7 +3209,13 @@ class TosaTestGen:
     ]

     TYPE_FI16 = [DType.FP32, DType.INT16]

-    TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
+    TYPE_NARROW_INT_FP = [
+        DType.INT8,
+        DType.INT16,
+        DType.FP16,
+        DType.BF16,
+        DType.FP32,
+    ]

     # List of [Input Type 1, Input Type 2, Accumulator Type]
     TYPE_CONV = [
@@ -3201,6 +3226,8 @@ class TosaTestGen:
         [DType.FP16, DType.FP16, DType.FP32],
         [DType.BF16, DType.BF16, DType.FP32],
         [DType.FP32, DType.FP32, DType.FP32],
+        [DType.FP8E4M3, DType.FP8E4M3, DType.FP16],
+        [DType.FP8E5M2, DType.FP8E5M2, DType.FP16],
     ]

     DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
@@ -3217,7 +3244,7 @@ class TosaTestGen:
                 TosaTensorValuesGen.tvgLazyGenDefault,
                 TosaArgGen.agAxis,
             ),
-            "types": TYPE_NARROW_INT_FP,
+            "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
             "error_if_validators": (
                 TosaErrorValidator.evAxisSmallerZero,
                 TosaErrorValidator.evAxisLargerRank,
@@ -3244,7 +3271,7 @@ class TosaTestGen:
                 TosaArgGen.agPooling,
             ),
             "qgen": TosaQuantGen.qgUnary,
-            "types": TYPE_NARROW_INT_FP,
+            "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
             "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
             "error_if_validators": (
                 TosaErrorValidator.evKernelSmallerOne,
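The agMatMul and agPooling hunks in tosa_arg_gen.py, the new TYPE_CONV rows, and the evWrongAccumulatorType check all apply the same rule: FP8 operands accumulate in FP16. A sketch of that mapping (the helper name is hypothetical; it condenses only what the hunks above show):

    def fp_accum_dtypes(input_dtype):
        # Accumulator choices as set up above for MATMUL/pooling/conv.
        if input_dtype == "FP16":
            return ["FP16", "FP32"]  # pooling offers both options
        if input_dtype in ("BF16", "FP32"):
            return ["FP32"]
        if input_dtype in ("FP8E4M3", "FP8E5M2"):
            return ["FP16"]  # new in this change: FP8 accumulates in FP16
        raise ValueError(f"no FP accumulator rule for {input_dtype}")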
"types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2], "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,), "error_if_validators": ( TosaErrorValidator.evKernelSmallerOne, @@ -3402,7 +3429,7 @@ class TosaTestGen: TosaArgGen.agMatMul, ), "qgen": TosaQuantGen.qgMatmul, - "types": TYPE_NARROW_INT_FP, + "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2], "error_if_validators": ( TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evWrongRank, @@ -3425,7 +3452,7 @@ class TosaTestGen: TosaTensorValuesGen.tvgLazyGenDefault, TosaArgGen.agPooling, ), - "types": TYPE_NARROW_INT_FP, + "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2], "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,), "error_if_validators": ( TosaErrorValidator.evKernelSmallerOne, @@ -4389,7 +4416,7 @@ class TosaTestGen: TosaTensorValuesGen.tvgConcat, TosaArgGen.agAxis, ), - "types": TYPE_FIB, + "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2], "error_if_validators": ( TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, @@ -4413,7 +4440,7 @@ class TosaTestGen: TosaTensorValuesGen.tvgPad, TosaArgGen.agPad, ), - "types": TYPE_FIB, + "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2], "error_if_validators": ( TosaErrorValidator.evWrongInputType, TosaErrorValidator.evPadSmallerZero, @@ -4437,7 +4464,7 @@ class TosaTestGen: TosaTensorValuesGen.tvgLazyGenDefault, TosaArgGen.agAxis, ), - "types": TYPE_FIB, + "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2], "error_if_validators": ( TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, @@ -4456,7 +4483,7 @@ class TosaTestGen: TosaTensorValuesGen.tvgReshape, TosaArgGen.agReshape, ), - "types": TYPE_FIB, + "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2], "error_if_validators": ( TosaErrorValidator.evTensorSizeInputOutputMismatch, TosaErrorValidator.evWrongInputType, @@ -4477,7 +4504,7 @@ class TosaTestGen: TosaTensorValuesGen.tvgLazyGenDefault, TosaArgGen.agAxis, ), - "types": TYPE_FIB, + "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2], "error_if_validators": ( TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evAxisLargerRank, @@ -4500,7 +4527,7 @@ class TosaTestGen: TosaTensorValuesGen.tvgSlice, TosaArgGen.agSlice, ), - "types": TYPE_FIB, + "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2], "error_if_validators": ( # TODO Turn off these error categories for now as the reference # model cannot allocate memory space for empty tensor. 
@@ -4500,7 +4527,7 @@ class TosaTestGen:
                 TosaTensorValuesGen.tvgSlice,
                 TosaArgGen.agSlice,
             ),
-            "types": TYPE_FIB,
+            "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
             "error_if_validators": (
                 # TODO Turn off these error categories for now as the reference
                 # model cannot allocate memory space for empty tensor. We probably
@@ -4532,7 +4559,7 @@ class TosaTestGen:
                 TosaTensorValuesGen.tvgTile,
                 TosaArgGen.agTile,
             ),
-            "types": TYPE_FIB,
+            "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
             "error_if_validators": (
                 TosaErrorValidator.evWrongInputType,
                 TosaErrorValidator.evWrongOutputType,
@@ -4555,7 +4582,7 @@ class TosaTestGen:
                 TosaTensorValuesGen.tvgLazyGenDefault,
                 TosaArgGen.agTranspose,
             ),
-            "types": TYPE_FIB,
+            "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
             "error_if_validators": (
                 TosaErrorValidator.evIndexOutsideBounds,
                 TosaErrorValidator.evIndexUsedTwice,
@@ -4581,7 +4608,7 @@ class TosaTestGen:
                 TosaTensorValuesGen.tvgLazyGenDefault,
                 TosaArgGen.agNone,
             ),
-            "types": TYPE_FIB + [DType.INT48],
+            "types": TYPE_FIB + [DType.INT48, DType.FP8E4M3, DType.FP8E5M2],
             "data_gen": {
                 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
             },
@@ -4618,6 +4645,8 @@ class TosaTestGen:
             DType.FP16,
             DType.BF16,
             DType.FP32,
+            DType.FP8E4M3,
+            DType.FP8E5M2,
         ),
         "error_if_validators": (
             TosaErrorValidator.evWrongInputType,
@@ -4640,7 +4669,7 @@ class TosaTestGen:
                 TosaTensorValuesGen.tvgScatter,
                 TosaArgGen.agNone,
             ),
-            "types": TYPE_INT_FP,
+            "types": TYPE_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
             "error_if_validators": (
                 TosaErrorValidator.evWrongInputType,
                 TosaErrorValidator.evWrongOutputType,
@@ -4709,6 +4738,8 @@ class TosaTestGen:
             DType.INT16,
             DType.INT32,
             DType.BOOL,
+            DType.FP8E4M3,
+            DType.FP8E5M2,
         ),
         "error_if_validators": (
             TosaErrorValidator.evWrongInputType,
@@ -5141,6 +5172,8 @@ class OutputShaper:
             DType.FP32,
             DType.FP16,
             DType.BF16,
+            DType.FP8E4M3,
+            DType.FP8E5M2,
         ]
         wrong_dtypes = list(set(all_dtypes) - set([DType.INT32]))
         outputDType = rng.choice(wrong_dtypes)
@@ -5194,6 +5227,8 @@ class OutputShaper:
         if error_name == ErrorIf.WrongOutputType:
             if ifm.dtype == DType.FP16:
                 excludes = [DType.FP16, DType.FP32]
+            if ifm.dtype in [DType.FP8E4M3, DType.FP8E5M2]:
+                excludes = [DType.FP16]
             else:
                 excludes = [out_dtype]
             wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
@@ -5344,6 +5379,8 @@ class OutputShaper:
             DType.FP32,
             DType.FP16,
             DType.BF16,
+            DType.FP8E4M3,
+            DType.FP8E5M2,
         ]
         wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype]))
         outputDType = rng.choice(wrong_dtypes)
@@ -5383,6 +5420,8 @@ class OutputShaper:
                 DType.FP32,
                 DType.FP16,
                 DType.BF16,
+                DType.FP8E4M3,
+                DType.FP8E5M2,
             )
         elif a.dtype == DType.INT16:
             incorrect_types = (
@@ -5393,6 +5432,20 @@ class OutputShaper:
                 DType.FP32,
                 DType.FP16,
                 DType.BF16,
+                DType.FP8E4M3,
+                DType.FP8E5M2,
+            )
+        elif a.dtype == DType.FP8E4M3 or a.dtype == DType.FP8E5M2:
+            incorrect_types = (
+                DType.INT4,
+                DType.INT8,
+                DType.INT16,
+                DType.INT32,
+                DType.INT48,
+                DType.FP32,
+                DType.BF16,
+                DType.FP8E4M3,
+                DType.FP8E5M2,
             )
         elif (
             a.dtype == DType.FP32 or a.dtype == DType.FP16 or a.dtype == DType.BF16
@@ -5403,6 +5456,8 @@ class OutputShaper:
                 DType.INT16,
                 DType.INT32,
                 DType.INT48,
+                DType.FP8E4M3,
+                DType.FP8E5M2,
             )
         out_dtype = rng.choice(a=incorrect_types)
     elif error_name == ErrorIf.WrongInputType:
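The OutputShaper hunks above repeat one pattern for ErrorIf.WrongOutputType: enumerate the dtype universe (now including both FP8 types), subtract the valid output type, and pick any survivor so the ERROR_IF path is exercised. Reduced to its core (illustrative only, not code from this change):

    import random

    ALL_DTYPES = ["INT8", "INT16", "INT32", "INT48", "FP32", "FP16",
                  "BF16", "FP8E4M3", "FP8E5M2"]

    def pick_wrong_output_dtype(rng, valid_dtype):
        # Set difference, then a random choice - as in the hunks above.
        wrong_dtypes = sorted(set(ALL_DTYPES) - {valid_dtype})
        return rng.choice(wrong_dtypes)

    rng = random.Random(0)
    assert pick_wrong_output_dtype(rng, "FP8E4M3") != "FP8E4M3"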
"json": "FP32"}, + DType.FP8E4M3: {"str": "f8e4m3", "width": 8, "json": "FP8E4M3"}, + DType.FP8E5M2: {"str": "f8e5m2", "width": 8, "json": "FP8E5M2"}, } @@ -186,6 +188,16 @@ def get_wrong_output_type(op_name, rng, input_dtype): DType.INT32, DType.INT48, ) + elif input_dtype == DType.FP8E4M3 or input_dtype == DType.FP8E5M2: + incorrect_types = ( + DType.INT4, + DType.INT8, + DType.INT16, + DType.INT32, + DType.INT48, + DType.FP32, + DType.BF16, + ) else: # Assume all types but the input type are incorrect incorrect_types = list(usableDTypes(excludes=(input_dtype,))) @@ -209,6 +221,12 @@ def float32_is_valid_bfloat16(f): return f32_bits[16:] == "0" * 16 +def float32_is_valid_float8(f): + """Return True if float value is valid float8.""" + f32_bits = get_float32_bitstring(f) + return f32_bits[8:] == "0" * 24 + + def get_float32_bitstring(f): """Return a big-endian string of bits representing a 32 bit float.""" f32_bits_as_int = struct.unpack(">L", struct.pack(">f", f))[0] @@ -232,6 +250,30 @@ def float32_to_bfloat16(f): return struct.unpack("@f", fp_bytes)[0] # native byteorder +def float32_to_fp8e4m3(f): + """Turns fp32 value into fp8e4m3""" + f32_bits = get_float32_bitstring(f) + fp8_bits = f32_bits[0] + f32_bits[1:5] + f32_bits[9:12] + "0" * 24 + fp_bytes = int(fp8_bits, 2).to_bytes(4, byteorder=sys.byteorder) + return struct.unpack("@f", fp_bytes)[0] # native byteorder + + +def float32_to_fp8e5m2(f): + """Turns fp32 value into fp8e5m2""" + f32_bits = get_float32_bitstring(f) + fp8_bits = f32_bits[0] + f32_bits[1:6] + f32_bits[9:11] + "0" * 24 + fp_bytes = int(fp8_bits, 2).to_bytes(4, byteorder=sys.byteorder) + return struct.unpack("@f", fp_bytes)[0] + + vect_f32_to_bf16 = np.vectorize( float32_to_bfloat16, otypes=(np.float32,) ) # NumPy vectorize: applies function to vector faster than looping + +vect_f32_to_fp8e4m3 = np.vectorize( + float32_to_fp8e4m3, otypes=(np.float32,) +) # NumPy vectorize: applies function to vector faster than looping + +vect_f32_to_fp8e5m2 = np.vectorize( + float32_to_fp8e5m2, otypes=(np.float32,) +) # Numpy vectorize: applies function to vector faster than looping |