about summary refs log tree commit diff
path: root/verif/generator
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator')
-rw-r--r--verif/generator/tosa_arg_gen.py60
-rw-r--r--verif/generator/tosa_error_if.py72
-rw-r--r--verif/generator/tosa_test_gen.py95
-rw-r--r--verif/generator/tosa_utils.py42
4 files changed, 241 insertions, 28 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 7ec0cfe..d0b9eb9 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -641,6 +641,8 @@ class TosaTensorValuesGen:
DType.FP32: (1 << 128) - (1 << (127 - 23)),
DType.FP16: (1 << 16) - (1 << (15 - 10)),
DType.BF16: (1 << 128) - (1 << (127 - 7)),
+ DType.FP8E4M3: 448,
+ DType.FP8E5M2: 57344,
}
# Default lowest normal values for random numbers
@@ -648,6 +650,8 @@ class TosaTensorValuesGen:
DType.FP32: np.exp2(-126),
DType.FP16: np.exp2(-14),
DType.BF16: np.exp2(-126),
+ DType.FP8E4M3: np.exp2(-9),
+ DType.FP8E5M2: np.exp2(-16),
}
@staticmethod
@@ -715,6 +719,8 @@ class TosaTensorValuesGen:
DType.FP16,
DType.FP32,
DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
):
# Change from inclusive to exclusive range
data_range = (data_range[0], data_range[1] + 1)
@@ -1734,7 +1740,13 @@ class TosaArgGen:
and "data_gen" in testGen.TOSA_OP_LIST[opName]
and gtu.dtypeIsSupportedByCompliance(dtype)
):
- if dtype in [DType.FP16, DType.FP32, DType.BF16]:
+ if dtype in [
+ DType.FP16,
+ DType.FP32,
+ DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ]:
dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["fp"]
else:
dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["int"]
@@ -2140,6 +2152,8 @@ class TosaArgGen:
accum_dtypes = [DType.FP32]
elif dtype == DType.FP32:
accum_dtypes = [DType.FP32]
+ elif dtype == DType.FP8E4M3 or dtype == DType.FP8E5M2:
+ accum_dtypes = [DType.FP16]
elif error_name is None:
assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"
@@ -2350,7 +2364,13 @@ class TosaArgGen:
if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
pad_const_int = testGen.getRandNumberDType(dtype)
pad_const_fp = 0
- elif dtype in (DType.FP16, DType.BF16, DType.FP32):
+ elif dtype in (
+ DType.FP16,
+ DType.BF16,
+ DType.FP32,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ):
pad_const_int = 0
pad_const_fp = testGen.getRandNumberDType(dtype)
else:
@@ -2468,6 +2488,8 @@ class TosaArgGen:
accum_dtypes = [DType.FP16, DType.FP32]
elif dtype == DType.BF16 or dtype == DType.FP32:
accum_dtypes = [DType.FP32]
+ elif dtype == DType.FP8E4M3 or dtype == DType.FP8E5M2:
+ accum_dtypes = [DType.FP16]
elif error_name is None:
assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
else:
@@ -2646,11 +2668,35 @@ class TosaArgGen:
elif inDtype == DType.BOOL:
dtypeList = [DType.INT8, DType.INT16, DType.INT32]
elif inDtype == DType.FP16:
- dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
+ dtypeList = [
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.FP32,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ]
elif inDtype == DType.BF16:
- dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
+ dtypeList = [
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.FP32,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ]
elif inDtype == DType.FP32:
- dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.BF16]
+ dtypeList = [
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.FP16,
+ DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ]
+ elif inDtype in [DType.FP8E4M3, DType.FP8E5M2]:
+ dtypeList = [DType.FP16, DType.BF16, DType.FP32]
elif error_name == ErrorIf.WrongInputType:
# Pick some potentially correct output type for incorrect input type
dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
@@ -3232,6 +3278,10 @@ class TosaArgGen:
outputDTypeList = [DType.BF16]
elif dtype == DType.FP32:
outputDTypeList = [DType.FP32]
+ elif dtype == DType.FP8E4M3:
+ outputDTypeList = [DType.FP8E4M3]
+ elif dtype == DType.FP8E5M2:
+ outputDTypeList = [DType.FP8E5M2]
elif error_name == ErrorIf.WrongInputType:
# If an incorrect input type is used then we set a 'correct'
# output type to avoid other errors
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index 9a88acb..7a4d0d6 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -325,12 +325,32 @@ class TosaErrorIfArgGen:
@staticmethod
def eiCastErrorIf(testGen, input_dtype):
- if input_dtype in [DType.BOOL, DType.FP32]:
+ # if input_dtype in [DType.BOOL, DType.FP32]:
+ # outputDType = [DType.BOOL, DType.INT48, DType.FP32]
+ if input_dtype in [DType.BOOL]:
+ outputDType = [
+ DType.BOOL,
+ DType.INT48,
+ DType.FP32,
+ DType.FP16,
+ DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ]
+ elif input_dtype in [DType.FP32]:
outputDType = [DType.BOOL, DType.INT48, DType.FP32]
elif input_dtype in [DType.FP16, DType.BF16]:
outputDType = [DType.BOOL, DType.INT48]
elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
outputDType = [DType.INT48]
+ elif input_dtype in [DType.FP8E4M3, DType.FP8E5M2]:
+ outputDType = [
+ DType.BOOL,
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ ]
else:
assert False, f"input_dtype ({input_dtype}) not supported"
return outputDType
@@ -476,13 +496,23 @@ class TosaErrorValidator:
)
or (input_dtype == DType.BF16 and output_dtype != DType.FP32)
or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
+ or (input_dtype == DType.FP8E4M3 and output_dtype != DType.FP16)
+ or (input_dtype == DType.FP8E5M2 and output_dtype != DType.FP16)
):
error_result = True
elif op["op"] == Op.ARGMAX:
if (
input_dtype
- in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
+ in [
+ DType.INT8,
+ DType.INT16,
+ DType.FP16,
+ DType.BF16,
+ DType.FP32,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ]
and output_dtype != DType.INT32
):
error_result = True
@@ -555,12 +585,26 @@ class TosaErrorValidator:
or (
input_dtype == DType.FP16
and output_dtype
- not in [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
+ not in [
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.FP32,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ]
)
or (
input_dtype == DType.BF16
and output_dtype
- not in [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
+ not in [
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.FP32,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ]
)
or (
input_dtype == DType.FP32
@@ -571,6 +615,17 @@ class TosaErrorValidator:
DType.INT32,
DType.FP16,
DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ]
+ )
+ or (
+ input_dtype in [DType.FP8E4M3, DType.FP8E5M2]
+ and output_dtype
+ not in [
+ DType.FP16,
+ DType.BF16,
+ DType.FP32,
]
)
):
@@ -597,6 +652,10 @@ class TosaErrorValidator:
and output_dtype != DType.FP32
or input_dtype == DType.FP32
and output_dtype != DType.FP32
+ or input_dtype == DType.FP8E4M3
+ and output_dtype != DType.FP16
+ or input_dtype == DType.FP8E5M2
+ and output_dtype != DType.FP16
):
error_result = True
# invalid input types are ignored, to avoid reporting multiple errors
@@ -2615,6 +2674,11 @@ class TosaErrorValidator:
DType.FP32,
):
error_result = True
+ elif (
+ input_dtype in (DType.FP8E4M3, DType.FP8E5M2)
+ and accum_dtype != DType.FP16
+ ):
+ error_result = True
info_dict = {
"error_name": error_name,
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 4ead982..bc931dc 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -76,7 +76,7 @@ class TosaTestGen:
return tuple(sorted(vals))
self.random_float_range = {}
- for dtype in (DType.FP32, DType.FP16, DType.BF16):
+ for dtype in (DType.FP32, DType.FP16, DType.BF16, DType.FP8E4M3, DType.FP8E5M2):
self.random_float_range[dtype] = convertFPRange(
args.tensor_fp_value_range,
TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
@@ -152,7 +152,7 @@ class TosaTestGen:
# Returns dtype value range boundaries (low, high)
# The high boundary is excluded in the range
# unless high_inclusive is True
- if dtype in (DType.FP32, DType.FP16, DType.BF16):
+ if dtype in (DType.FP32, DType.FP16, DType.BF16, DType.FP8E4M3, DType.FP8E5M2):
return self.random_float_range[dtype]
elif dtype == DType.BOOL:
rng = (0, 2)
@@ -197,7 +197,13 @@ class TosaTestGen:
return np.uint8(self.rng.integers(low=low, high=high, size=shape))
elif dtype in (DType.INT48, DType.SHAPE):
return np.int64(self.rng.integers(low=low, high=high, size=shape))
- elif dtype in (DType.FP16, DType.BF16, DType.FP32):
+ elif dtype in (
+ DType.FP16,
+ DType.BF16,
+ DType.FP32,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ):
f_tensor = self.rng.uniform(low=low, high=high, size=shape)
if dtype == DType.FP16:
@@ -207,6 +213,10 @@ class TosaTestGen:
if dtype == DType.BF16:
# Floor the last 16 bits of each f32 value
return np.float32(gtu.vect_f32_to_bf16(f32_tensor))
+ elif dtype == DType.FP8E4M3:
+ return np.float32(gtu.vect_f32_to_fp8e4m3(f32_tensor))
+ elif dtype == DType.FP8E5M2:
+ return np.float32(gtu.vect_f32_to_fp8e5m2(f32_tensor))
else:
return f32_tensor
else:
@@ -266,6 +276,12 @@ class TosaTestGen:
elif dtype == DType.BF16:
rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
return gtu.vect_f32_to_bf16(rand_f32)
+ elif dtype == DType.FP8E4M3:
+ rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
+ return gtu.vect_f32_to_fp8e4m3(rand_f32)
+ elif dtype == DType.FP8E5M2:
+ rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
+ return gtu.vect_f32_to_fp8e5m2(rand_f32)
elif dtype == DType.BOOL:
return self.rng.choice([False, True])
elif dtype == DType.INT48 or dtype == DType.SHAPE:
@@ -1408,8 +1424,11 @@ class TosaTestGen:
max_val = max_val.astype(np.float32)
attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)
- else:
+ elif a.dtype in (DType.INT8, DType.INT16):
attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)
+ else:
+ # to avoid internal error for incorrect input types
+ attr.ClampAttribute(self.ser.builder, 0, 0, 0, 0)
self.ser.addOperator(op["op"], input_list, output_list, attr)
@@ -3190,7 +3209,13 @@ class TosaTestGen:
]
TYPE_FI16 = [DType.FP32, DType.INT16]
- TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
+ TYPE_NARROW_INT_FP = [
+ DType.INT8,
+ DType.INT16,
+ DType.FP16,
+ DType.BF16,
+ DType.FP32,
+ ]
# List of [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
@@ -3201,6 +3226,8 @@ class TosaTestGen:
[DType.FP16, DType.FP16, DType.FP32],
[DType.BF16, DType.BF16, DType.FP32],
[DType.FP32, DType.FP32, DType.FP32],
+ [DType.FP8E4M3, DType.FP8E4M3, DType.FP16],
+ [DType.FP8E5M2, DType.FP8E5M2, DType.FP16],
]
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
@@ -3217,7 +3244,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agAxis,
),
- "types": TYPE_NARROW_INT_FP,
+ "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evAxisSmallerZero,
TosaErrorValidator.evAxisLargerRank,
@@ -3244,7 +3271,7 @@ class TosaTestGen:
TosaArgGen.agPooling,
),
"qgen": TosaQuantGen.qgUnary,
- "types": TYPE_NARROW_INT_FP,
+ "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
"invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
"error_if_validators": (
TosaErrorValidator.evKernelSmallerOne,
@@ -3402,7 +3429,7 @@ class TosaTestGen:
TosaArgGen.agMatMul,
),
"qgen": TosaQuantGen.qgMatmul,
- "types": TYPE_NARROW_INT_FP,
+ "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evInputZeroPointNotZero,
TosaErrorValidator.evWrongRank,
@@ -3425,7 +3452,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agPooling,
),
- "types": TYPE_NARROW_INT_FP,
+ "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
"invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
"error_if_validators": (
TosaErrorValidator.evKernelSmallerOne,
@@ -4389,7 +4416,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgConcat,
TosaArgGen.agAxis,
),
- "types": TYPE_FIB,
+ "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evAxisLargerRank,
TosaErrorValidator.evAxisSmallerZero,
@@ -4413,7 +4440,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgPad,
TosaArgGen.agPad,
),
- "types": TYPE_FIB,
+ "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evPadSmallerZero,
@@ -4437,7 +4464,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agAxis,
),
- "types": TYPE_FIB,
+ "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evAxisLargerRank,
TosaErrorValidator.evAxisSmallerZero,
@@ -4456,7 +4483,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgReshape,
TosaArgGen.agReshape,
),
- "types": TYPE_FIB,
+ "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evTensorSizeInputOutputMismatch,
TosaErrorValidator.evWrongInputType,
@@ -4477,7 +4504,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agAxis,
),
- "types": TYPE_FIB,
+ "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evAxisSmallerZero,
TosaErrorValidator.evAxisLargerRank,
@@ -4500,7 +4527,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgSlice,
TosaArgGen.agSlice,
),
- "types": TYPE_FIB,
+ "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
# TODO Turn off these error categories for now as the reference
# model cannot allocate memory space for empty tensor. We probably
@@ -4532,7 +4559,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgTile,
TosaArgGen.agTile,
),
- "types": TYPE_FIB,
+ "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
@@ -4555,7 +4582,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agTranspose,
),
- "types": TYPE_FIB,
+ "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evIndexOutsideBounds,
TosaErrorValidator.evIndexUsedTwice,
@@ -4581,7 +4608,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agNone,
),
- "types": TYPE_FIB + [DType.INT48],
+ "types": TYPE_FIB + [DType.INT48, DType.FP8E4M3, DType.FP8E5M2],
"data_gen": {
"fp": (gtu.DataGenType.PSEUDO_RANDOM,),
},
@@ -4618,6 +4645,8 @@ class TosaTestGen:
DType.FP16,
DType.BF16,
DType.FP32,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
),
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
@@ -4640,7 +4669,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgScatter,
TosaArgGen.agNone,
),
- "types": TYPE_INT_FP,
+ "types": TYPE_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
@@ -4709,6 +4738,8 @@ class TosaTestGen:
DType.INT16,
DType.INT32,
DType.BOOL,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
),
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
@@ -5141,6 +5172,8 @@ class OutputShaper:
DType.FP32,
DType.FP16,
DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
]
wrong_dtypes = list(set(all_dtypes) - set([DType.INT32]))
outputDType = rng.choice(wrong_dtypes)
@@ -5194,6 +5227,8 @@ class OutputShaper:
if error_name == ErrorIf.WrongOutputType:
if ifm.dtype == DType.FP16:
excludes = [DType.FP16, DType.FP32]
+ if ifm.dtype in [DType.FP8E4M3, DType.FP8E5M2]:
+ excludes = [DType.FP16]
else:
excludes = [out_dtype]
wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
@@ -5344,6 +5379,8 @@ class OutputShaper:
DType.FP32,
DType.FP16,
DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
]
wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -5383,6 +5420,8 @@ class OutputShaper:
DType.FP32,
DType.FP16,
DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
)
elif a.dtype == DType.INT16:
incorrect_types = (
@@ -5393,6 +5432,20 @@ class OutputShaper:
DType.FP32,
DType.FP16,
DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ )
+ elif a.dtype == DType.FP8E4M3 or a.dtype == DType.FP8E5M2:
+ incorrect_types = (
+ DType.INT4,
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ DType.FP32,
+ DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
)
elif (
a.dtype == DType.FP32 or a.dtype == DType.FP16 or a.dtype == DType.BF16
@@ -5403,6 +5456,8 @@ class OutputShaper:
DType.INT16,
DType.INT32,
DType.INT48,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
)
out_dtype = rng.choice(a=incorrect_types)
elif error_name == ErrorIf.WrongInputType:
@@ -5669,6 +5724,8 @@ class OutputShaper:
DType.FP32,
DType.FP16,
DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
]
wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
outputDType = rng.choice(wrong_dtypes)
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py
index 76e7388..31a0ff0 100644
--- a/verif/generator/tosa_utils.py
+++ b/verif/generator/tosa_utils.py
@@ -27,6 +27,8 @@ DTYPE_ATTRIBUTES = {
DType.FP16: {"str": "f16", "width": 16, "json": "FP16"},
DType.BF16: {"str": "bf16", "width": 16, "json": "BF16"},
DType.FP32: {"str": "f32", "width": 32, "json": "FP32"},
+ DType.FP8E4M3: {"str": "f8e4m3", "width": 8, "json": "FP8E4M3"},
+ DType.FP8E5M2: {"str": "f8e5m2", "width": 8, "json": "FP8E5M2"},
}
@@ -186,6 +188,16 @@ def get_wrong_output_type(op_name, rng, input_dtype):
DType.INT32,
DType.INT48,
)
+ elif input_dtype == DType.FP8E4M3 or input_dtype == DType.FP8E5M2:
+ incorrect_types = (
+ DType.INT4,
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ DType.FP32,
+ DType.BF16,
+ )
else:
# Assume all types but the input type are incorrect
incorrect_types = list(usableDTypes(excludes=(input_dtype,)))
@@ -209,6 +221,12 @@ def float32_is_valid_bfloat16(f):
return f32_bits[16:] == "0" * 16
+def float32_is_valid_float8(f):
+ """Return True if float value is valid float8."""
+ f32_bits = get_float32_bitstring(f)
+ return f32_bits[8:] == "0" * 24
+
+
def get_float32_bitstring(f):
"""Return a big-endian string of bits representing a 32 bit float."""
f32_bits_as_int = struct.unpack(">L", struct.pack(">f", f))[0]
@@ -232,6 +250,30 @@ def float32_to_bfloat16(f):
return struct.unpack("@f", fp_bytes)[0] # native byteorder
+def float32_to_fp8e4m3(f):
+ """Turns fp32 value into fp8e4m3"""
+ f32_bits = get_float32_bitstring(f)
+ fp8_bits = f32_bits[0] + f32_bits[1:5] + f32_bits[9:12] + "0" * 24
+ fp_bytes = int(fp8_bits, 2).to_bytes(4, byteorder=sys.byteorder)
+ return struct.unpack("@f", fp_bytes)[0] # native byteorder
+
+
+def float32_to_fp8e5m2(f):
+ """Turns fp32 value into fp8e5m2"""
+ f32_bits = get_float32_bitstring(f)
+ fp8_bits = f32_bits[0] + f32_bits[1:6] + f32_bits[9:11] + "0" * 24
+ fp_bytes = int(fp8_bits, 2).to_bytes(4, byteorder=sys.byteorder)
+ return struct.unpack("@f", fp_bytes)[0]
+
+
vect_f32_to_bf16 = np.vectorize(
float32_to_bfloat16, otypes=(np.float32,)
) # NumPy vectorize: applies function to vector faster than looping
+
+vect_f32_to_fp8e4m3 = np.vectorize(
+ float32_to_fp8e4m3, otypes=(np.float32,)
+) # NumPy vectorize: applies function to vector faster than looping
+
+vect_f32_to_fp8e5m2 = np.vectorize(
+ float32_to_fp8e5m2, otypes=(np.float32,)
+) # NumPy vectorize: applies function to vector faster than looping