diff options
author | Won Jeon <won.jeon@arm.com> | 2024-02-06 18:37:00 +0000 |
---|---|---|
committer | Won Jeon <won.jeon@arm.com> | 2024-02-21 19:38:55 +0000 |
commit | 2c34b4616a10539211e7006bc43f3c71e86c30bb (patch) | |
tree | aa4043a610ecd4c6d35b876cfb013dbe7dd0ab01 /verif/generator/tosa_utils.py | |
parent | 587cc84c2b8c4b0d030b5e257c9a32461c0969b9 (diff) | |
download | reference_model-2c34b4616a10539211e7006bc43f3c71e86c30bb.tar.gz |
Add support for FP8 to reference model
Signed-off-by: Won Jeon <won.jeon@arm.com>
Change-Id: I99b70f94aff2ccd4af64875697e124eb60bc5b08
Diffstat (limited to 'verif/generator/tosa_utils.py')
-rw-r--r-- | verif/generator/tosa_utils.py | 42 |
1 files changed, 42 insertions, 0 deletions
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py index 76e7388..31a0ff0 100644 --- a/verif/generator/tosa_utils.py +++ b/verif/generator/tosa_utils.py @@ -27,6 +27,8 @@ DTYPE_ATTRIBUTES = { DType.FP16: {"str": "f16", "width": 16, "json": "FP16"}, DType.BF16: {"str": "bf16", "width": 16, "json": "BF16"}, DType.FP32: {"str": "f32", "width": 32, "json": "FP32"}, + DType.FP8E4M3: {"str": "f8e4m3", "width": 8, "json": "FP8E4M3"}, + DType.FP8E5M2: {"str": "f8e5m2", "width": 8, "json": "FP8E5M2"}, } @@ -186,6 +188,16 @@ def get_wrong_output_type(op_name, rng, input_dtype): DType.INT32, DType.INT48, ) + elif input_dtype == DType.FP8E4M3 or input_dtype == DType.FP8E5M2: + incorrect_types = ( + DType.INT4, + DType.INT8, + DType.INT16, + DType.INT32, + DType.INT48, + DType.FP32, + DType.BF16, + ) else: # Assume all types but the input type are incorrect incorrect_types = list(usableDTypes(excludes=(input_dtype,))) @@ -209,6 +221,12 @@ def float32_is_valid_bfloat16(f): return f32_bits[16:] == "0" * 16 +def float32_is_valid_float8(f): + """Return True if float value is valid float8.""" + f32_bits = get_float32_bitstring(f) + return f32_bits[8:] == "0" * 24 + + def get_float32_bitstring(f): """Return a big-endian string of bits representing a 32 bit float.""" f32_bits_as_int = struct.unpack(">L", struct.pack(">f", f))[0] @@ -232,6 +250,30 @@ def float32_to_bfloat16(f): return struct.unpack("@f", fp_bytes)[0] # native byteorder +def float32_to_fp8e4m3(f): + """Turns fp32 value into fp8e4m3""" + f32_bits = get_float32_bitstring(f) + fp8_bits = f32_bits[0] + f32_bits[1:5] + f32_bits[9:12] + "0" * 24 + fp_bytes = int(fp8_bits, 2).to_bytes(4, byteorder=sys.byteorder) + return struct.unpack("@f", fp_bytes)[0] # native byteorder + + +def float32_to_fp8e5m2(f): + """Turns fp32 value into fp8e5m2""" + f32_bits = get_float32_bitstring(f) + fp8_bits = f32_bits[0] + f32_bits[1:6] + f32_bits[9:11] + "0" * 24 + fp_bytes = int(fp8_bits, 2).to_bytes(4, byteorder=sys.byteorder) + return struct.unpack("@f", fp_bytes)[0] + + vect_f32_to_bf16 = np.vectorize( float32_to_bfloat16, otypes=(np.float32,) ) # NumPy vectorize: applies function to vector faster than looping + +vect_f32_to_fp8e4m3 = np.vectorize( + float32_to_fp8e4m3, otypes=(np.float32,) +) # NumPy vectorize: applies function to vector faster than looping + +vect_f32_to_fp8e5m2 = np.vectorize( + float32_to_fp8e5m2, otypes=(np.float32,) +) # Numpy vectorize: applies function to vector faster than looping |