diff options
Diffstat (limited to 'verif/generator/tosa_utils.py')
-rw-r--r-- | verif/generator/tosa_utils.py | 42 |
1 files changed, 42 insertions, 0 deletions
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py index 76e7388..31a0ff0 100644 --- a/verif/generator/tosa_utils.py +++ b/verif/generator/tosa_utils.py @@ -27,6 +27,8 @@ DTYPE_ATTRIBUTES = { DType.FP16: {"str": "f16", "width": 16, "json": "FP16"}, DType.BF16: {"str": "bf16", "width": 16, "json": "BF16"}, DType.FP32: {"str": "f32", "width": 32, "json": "FP32"}, + DType.FP8E4M3: {"str": "f8e4m3", "width": 8, "json": "FP8E4M3"}, + DType.FP8E5M2: {"str": "f8e5m2", "width": 8, "json": "FP8E5M2"}, } @@ -186,6 +188,16 @@ def get_wrong_output_type(op_name, rng, input_dtype): DType.INT32, DType.INT48, ) + elif input_dtype == DType.FP8E4M3 or input_dtype == DType.FP8E5M2: + incorrect_types = ( + DType.INT4, + DType.INT8, + DType.INT16, + DType.INT32, + DType.INT48, + DType.FP32, + DType.BF16, + ) else: # Assume all types but the input type are incorrect incorrect_types = list(usableDTypes(excludes=(input_dtype,))) @@ -209,6 +221,12 @@ def float32_is_valid_bfloat16(f): return f32_bits[16:] == "0" * 16 +def float32_is_valid_float8(f): + """Return True if float value is valid float8.""" + f32_bits = get_float32_bitstring(f) + return f32_bits[8:] == "0" * 24 + + def get_float32_bitstring(f): """Return a big-endian string of bits representing a 32 bit float.""" f32_bits_as_int = struct.unpack(">L", struct.pack(">f", f))[0] @@ -232,6 +250,30 @@ def float32_to_bfloat16(f): return struct.unpack("@f", fp_bytes)[0] # native byteorder +def float32_to_fp8e4m3(f): + """Turns fp32 value into fp8e4m3""" + f32_bits = get_float32_bitstring(f) + fp8_bits = f32_bits[0] + f32_bits[1:5] + f32_bits[9:12] + "0" * 24 + fp_bytes = int(fp8_bits, 2).to_bytes(4, byteorder=sys.byteorder) + return struct.unpack("@f", fp_bytes)[0] # native byteorder + + +def float32_to_fp8e5m2(f): + """Turns fp32 value into fp8e5m2""" + f32_bits = get_float32_bitstring(f) + fp8_bits = f32_bits[0] + f32_bits[1:6] + f32_bits[9:11] + "0" * 24 + fp_bytes = int(fp8_bits, 2).to_bytes(4, byteorder=sys.byteorder) + return struct.unpack("@f", fp_bytes)[0] + + vect_f32_to_bf16 = np.vectorize( float32_to_bfloat16, otypes=(np.float32,) ) # NumPy vectorize: applies function to vector faster than looping + +vect_f32_to_fp8e4m3 = np.vectorize( + float32_to_fp8e4m3, otypes=(np.float32,) +) # NumPy vectorize: applies function to vector faster than looping + +vect_f32_to_fp8e5m2 = np.vectorize( + float32_to_fp8e5m2, otypes=(np.float32,) +) # Numpy vectorize: applies function to vector faster than looping |