about summary refs log tree commit diff
path: root/verif/generator/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r--  verif/generator/tosa_test_gen.py  95
1 files changed, 76 insertions, 19 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 4ead982..bc931dc 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -76,7 +76,7 @@ class TosaTestGen:
return tuple(sorted(vals))
self.random_float_range = {}
- for dtype in (DType.FP32, DType.FP16, DType.BF16):
+ for dtype in (DType.FP32, DType.FP16, DType.BF16, DType.FP8E4M3, DType.FP8E5M2):
self.random_float_range[dtype] = convertFPRange(
args.tensor_fp_value_range,
TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
@@ -152,7 +152,7 @@ class TosaTestGen:
# Returns dtype value range boundaries (low, high)
# The high boundary is excluded in the range
# unless high_inclusive is True
- if dtype in (DType.FP32, DType.FP16, DType.BF16):
+ if dtype in (DType.FP32, DType.FP16, DType.BF16, DType.FP8E4M3, DType.FP8E5M2):
return self.random_float_range[dtype]
elif dtype == DType.BOOL:
rng = (0, 2)
@@ -197,7 +197,13 @@ class TosaTestGen:
return np.uint8(self.rng.integers(low=low, high=high, size=shape))
elif dtype in (DType.INT48, DType.SHAPE):
return np.int64(self.rng.integers(low=low, high=high, size=shape))
- elif dtype in (DType.FP16, DType.BF16, DType.FP32):
+ elif dtype in (
+ DType.FP16,
+ DType.BF16,
+ DType.FP32,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ):
f_tensor = self.rng.uniform(low=low, high=high, size=shape)
if dtype == DType.FP16:
@@ -207,6 +213,10 @@ class TosaTestGen:
if dtype == DType.BF16:
# Floor the last 16 bits of each f32 value
return np.float32(gtu.vect_f32_to_bf16(f32_tensor))
+ elif dtype == DType.FP8E4M3:
+ return np.float32(gtu.vect_f32_to_fp8e4m3(f32_tensor))
+ elif dtype == DType.FP8E5M2:
+ return np.float32(gtu.vect_f32_to_fp8e5m2(f32_tensor))
else:
return f32_tensor
else:
@@ -266,6 +276,12 @@ class TosaTestGen:
elif dtype == DType.BF16:
rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
return gtu.vect_f32_to_bf16(rand_f32)
+ elif dtype == DType.FP8E4M3:
+ rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
+ return gtu.vect_f32_to_fp8e4m3(rand_f32)
+ elif dtype == DType.FP8E5M2:
+ rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
+ return gtu.vect_f32_to_fp8e5m2(rand_f32)
elif dtype == DType.BOOL:
return self.rng.choice([False, True])
elif dtype == DType.INT48 or dtype == DType.SHAPE:
@@ -1408,8 +1424,11 @@ class TosaTestGen:
max_val = max_val.astype(np.float32)
attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)
- else:
+ elif a.dtype in (DType.INT8, DType.INT16):
attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)
+ else:
+ # to avoid internal error for incorrect input types
+ attr.ClampAttribute(self.ser.builder, 0, 0, 0, 0)
self.ser.addOperator(op["op"], input_list, output_list, attr)
@@ -3190,7 +3209,13 @@ class TosaTestGen:
]
TYPE_FI16 = [DType.FP32, DType.INT16]
- TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
+ TYPE_NARROW_INT_FP = [
+ DType.INT8,
+ DType.INT16,
+ DType.FP16,
+ DType.BF16,
+ DType.FP32,
+ ]
# List of [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
@@ -3201,6 +3226,8 @@ class TosaTestGen:
[DType.FP16, DType.FP16, DType.FP32],
[DType.BF16, DType.BF16, DType.FP32],
[DType.FP32, DType.FP32, DType.FP32],
+ [DType.FP8E4M3, DType.FP8E4M3, DType.FP16],
+ [DType.FP8E5M2, DType.FP8E5M2, DType.FP16],
]
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
@@ -3217,7 +3244,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agAxis,
),
- "types": TYPE_NARROW_INT_FP,
+ "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evAxisSmallerZero,
TosaErrorValidator.evAxisLargerRank,
@@ -3244,7 +3271,7 @@ class TosaTestGen:
TosaArgGen.agPooling,
),
"qgen": TosaQuantGen.qgUnary,
- "types": TYPE_NARROW_INT_FP,
+ "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
"invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
"error_if_validators": (
TosaErrorValidator.evKernelSmallerOne,
@@ -3402,7 +3429,7 @@ class TosaTestGen:
TosaArgGen.agMatMul,
),
"qgen": TosaQuantGen.qgMatmul,
- "types": TYPE_NARROW_INT_FP,
+ "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evInputZeroPointNotZero,
TosaErrorValidator.evWrongRank,
@@ -3425,7 +3452,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agPooling,
),
- "types": TYPE_NARROW_INT_FP,
+ "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
"invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
"error_if_validators": (
TosaErrorValidator.evKernelSmallerOne,
@@ -4389,7 +4416,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgConcat,
TosaArgGen.agAxis,
),
- "types": TYPE_FIB,
+ "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evAxisLargerRank,
TosaErrorValidator.evAxisSmallerZero,
@@ -4413,7 +4440,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgPad,
TosaArgGen.agPad,
),
- "types": TYPE_FIB,
+ "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evPadSmallerZero,
@@ -4437,7 +4464,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agAxis,
),
- "types": TYPE_FIB,
+ "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evAxisLargerRank,
TosaErrorValidator.evAxisSmallerZero,
@@ -4456,7 +4483,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgReshape,
TosaArgGen.agReshape,
),
- "types": TYPE_FIB,
+ "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evTensorSizeInputOutputMismatch,
TosaErrorValidator.evWrongInputType,
@@ -4477,7 +4504,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agAxis,
),
- "types": TYPE_FIB,
+ "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evAxisSmallerZero,
TosaErrorValidator.evAxisLargerRank,
@@ -4500,7 +4527,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgSlice,
TosaArgGen.agSlice,
),
- "types": TYPE_FIB,
+ "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
# TODO Turn off these error categories for now as the reference
# model cannot allocate memory space for empty tensor. We probably
@@ -4532,7 +4559,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgTile,
TosaArgGen.agTile,
),
- "types": TYPE_FIB,
+ "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
@@ -4555,7 +4582,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agTranspose,
),
- "types": TYPE_FIB,
+ "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evIndexOutsideBounds,
TosaErrorValidator.evIndexUsedTwice,
@@ -4581,7 +4608,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agNone,
),
- "types": TYPE_FIB + [DType.INT48],
+ "types": TYPE_FIB + [DType.INT48, DType.FP8E4M3, DType.FP8E5M2],
"data_gen": {
"fp": (gtu.DataGenType.PSEUDO_RANDOM,),
},
@@ -4618,6 +4645,8 @@ class TosaTestGen:
DType.FP16,
DType.BF16,
DType.FP32,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
),
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
@@ -4640,7 +4669,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgScatter,
TosaArgGen.agNone,
),
- "types": TYPE_INT_FP,
+ "types": TYPE_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
@@ -4709,6 +4738,8 @@ class TosaTestGen:
DType.INT16,
DType.INT32,
DType.BOOL,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
),
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
@@ -5141,6 +5172,8 @@ class OutputShaper:
DType.FP32,
DType.FP16,
DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
]
wrong_dtypes = list(set(all_dtypes) - set([DType.INT32]))
outputDType = rng.choice(wrong_dtypes)
@@ -5194,6 +5227,8 @@ class OutputShaper:
if error_name == ErrorIf.WrongOutputType:
if ifm.dtype == DType.FP16:
excludes = [DType.FP16, DType.FP32]
+ if ifm.dtype in [DType.FP8E4M3, DType.FP8E5M2]:
+ excludes = [DType.FP16]
else:
excludes = [out_dtype]
wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
@@ -5344,6 +5379,8 @@ class OutputShaper:
DType.FP32,
DType.FP16,
DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
]
wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -5383,6 +5420,8 @@ class OutputShaper:
DType.FP32,
DType.FP16,
DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
)
elif a.dtype == DType.INT16:
incorrect_types = (
@@ -5393,6 +5432,20 @@ class OutputShaper:
DType.FP32,
DType.FP16,
DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ )
+ elif a.dtype == DType.FP8E4M3 or a.dtype == DType.FP8E5M2:
+ incorrect_types = (
+ DType.INT4,
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ DType.FP32,
+ DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
)
elif (
a.dtype == DType.FP32 or a.dtype == DType.FP16 or a.dtype == DType.BF16
@@ -5403,6 +5456,8 @@ class OutputShaper:
DType.INT16,
DType.INT32,
DType.INT48,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
)
out_dtype = rng.choice(a=incorrect_types)
elif error_name == ErrorIf.WrongInputType:
@@ -5669,6 +5724,8 @@ class OutputShaper:
DType.FP32,
DType.FP16,
DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
]
wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
outputDType = rng.choice(wrong_dtypes)