aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_utils.py')
-rw-r--r--verif/generator/tosa_utils.py42
1 files changed, 42 insertions, 0 deletions
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py
index 76e7388..31a0ff0 100644
--- a/verif/generator/tosa_utils.py
+++ b/verif/generator/tosa_utils.py
@@ -27,6 +27,8 @@ DTYPE_ATTRIBUTES = {
DType.FP16: {"str": "f16", "width": 16, "json": "FP16"},
DType.BF16: {"str": "bf16", "width": 16, "json": "BF16"},
DType.FP32: {"str": "f32", "width": 32, "json": "FP32"},
+ DType.FP8E4M3: {"str": "f8e4m3", "width": 8, "json": "FP8E4M3"},
+ DType.FP8E5M2: {"str": "f8e5m2", "width": 8, "json": "FP8E5M2"},
}
@@ -186,6 +188,16 @@ def get_wrong_output_type(op_name, rng, input_dtype):
DType.INT32,
DType.INT48,
)
+ elif input_dtype == DType.FP8E4M3 or input_dtype == DType.FP8E5M2:
+ incorrect_types = (
+ DType.INT4,
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ DType.FP32,
+ DType.BF16,
+ )
else:
# Assume all types but the input type are incorrect
incorrect_types = list(usableDTypes(excludes=(input_dtype,)))
@@ -209,6 +221,12 @@ def float32_is_valid_bfloat16(f):
return f32_bits[16:] == "0" * 16
+def float32_is_valid_float8(f):
+ """Return True if float value is valid float8."""
+ f32_bits = get_float32_bitstring(f)
+ return f32_bits[8:] == "0" * 24
+
+
def get_float32_bitstring(f):
"""Return a big-endian string of bits representing a 32 bit float."""
f32_bits_as_int = struct.unpack(">L", struct.pack(">f", f))[0]
@@ -232,6 +250,30 @@ def float32_to_bfloat16(f):
return struct.unpack("@f", fp_bytes)[0] # native byteorder
+def float32_to_fp8e4m3(f):
+ """Turns fp32 value into fp8e4m3"""
+ f32_bits = get_float32_bitstring(f)
+ fp8_bits = f32_bits[0] + f32_bits[1:5] + f32_bits[9:12] + "0" * 24
+ fp_bytes = int(fp8_bits, 2).to_bytes(4, byteorder=sys.byteorder)
+ return struct.unpack("@f", fp_bytes)[0] # native byteorder
+
+
+def float32_to_fp8e5m2(f):
+ """Turns fp32 value into fp8e5m2"""
+ f32_bits = get_float32_bitstring(f)
+ fp8_bits = f32_bits[0] + f32_bits[1:6] + f32_bits[9:11] + "0" * 24
+ fp_bytes = int(fp8_bits, 2).to_bytes(4, byteorder=sys.byteorder)
+ return struct.unpack("@f", fp_bytes)[0]
+
+
vect_f32_to_bf16 = np.vectorize(
float32_to_bfloat16, otypes=(np.float32,)
) # NumPy vectorize: applies function to vector faster than looping
+
+vect_f32_to_fp8e4m3 = np.vectorize(
+ float32_to_fp8e4m3, otypes=(np.float32,)
+) # NumPy vectorize: applies function to vector faster than looping
+
+vect_f32_to_fp8e5m2 = np.vectorize(
+ float32_to_fp8e5m2, otypes=(np.float32,)
+) # Numpy vectorize: applies function to vector faster than looping