Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r-- | verif/generator/tosa_test_gen.py | 95
1 file changed, 76 insertions, 19 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 4ead982..bc931dc 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -76,7 +76,7 @@ class TosaTestGen:
             return tuple(sorted(vals))
 
         self.random_float_range = {}
-        for dtype in (DType.FP32, DType.FP16, DType.BF16):
+        for dtype in (DType.FP32, DType.FP16, DType.BF16, DType.FP8E4M3, DType.FP8E5M2):
             self.random_float_range[dtype] = convertFPRange(
                 args.tensor_fp_value_range,
                 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
@@ -152,7 +152,7 @@ class TosaTestGen:
         # Returns dtype value range boundaries (low, high)
         # The high boundary is excluded in the range
         # unless high_inclusive is True
-        if dtype in (DType.FP32, DType.FP16, DType.BF16):
+        if dtype in (DType.FP32, DType.FP16, DType.BF16, DType.FP8E4M3, DType.FP8E5M2):
             return self.random_float_range[dtype]
         elif dtype == DType.BOOL:
             rng = (0, 2)
@@ -197,7 +197,13 @@ class TosaTestGen:
             return np.uint8(self.rng.integers(low=low, high=high, size=shape))
         elif dtype in (DType.INT48, DType.SHAPE):
             return np.int64(self.rng.integers(low=low, high=high, size=shape))
-        elif dtype in (DType.FP16, DType.BF16, DType.FP32):
+        elif dtype in (
+            DType.FP16,
+            DType.BF16,
+            DType.FP32,
+            DType.FP8E4M3,
+            DType.FP8E5M2,
+        ):
             f_tensor = self.rng.uniform(low=low, high=high, size=shape)
 
             if dtype == DType.FP16:
@@ -207,6 +213,10 @@ class TosaTestGen:
             if dtype == DType.BF16:
                 # Floor the last 16 bits of each f32 value
                 return np.float32(gtu.vect_f32_to_bf16(f32_tensor))
+            elif dtype == DType.FP8E4M3:
+                return np.float32(gtu.vect_f32_to_fp8e4m3(f32_tensor))
+            elif dtype == DType.FP8E5M2:
+                return np.float32(gtu.vect_f32_to_fp8e5m2(f32_tensor))
             else:
                 return f32_tensor
         else:
@@ -266,6 +276,12 @@ class TosaTestGen:
         elif dtype == DType.BF16:
             rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
             return gtu.vect_f32_to_bf16(rand_f32)
+        elif dtype == DType.FP8E4M3:
+            rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
+            return gtu.vect_f32_to_fp8e4m3(rand_f32)
+        elif dtype == DType.FP8E5M2:
+            rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
+            return gtu.vect_f32_to_fp8e5m2(rand_f32)
         elif dtype == DType.BOOL:
             return self.rng.choice([False, True])
         elif dtype == DType.INT48 or dtype == DType.SHAPE:
@@ -1408,8 +1424,11 @@ class TosaTestGen:
                 max_val = max_val.astype(np.float32)
                 attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)
-            else:
+            elif a.dtype in (DType.INT8, DType.INT16):
                 attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)
+            else:
+                # to avoid internal error for incorrect input types
+                attr.ClampAttribute(self.ser.builder, 0, 0, 0, 0)
 
         self.ser.addOperator(op["op"], input_list, output_list, attr)
@@ -3190,7 +3209,13 @@ class TosaTestGen:
     ]
 
     TYPE_FI16 = [DType.FP32, DType.INT16]
 
-    TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
+    TYPE_NARROW_INT_FP = [
+        DType.INT8,
+        DType.INT16,
+        DType.FP16,
+        DType.BF16,
+        DType.FP32,
+    ]
 
     # List of [Input Type 1, Input Type 2, Accumulator Type]
     TYPE_CONV = [
@@ -3201,6 +3226,8 @@ class TosaTestGen:
         [DType.FP16, DType.FP16, DType.FP32],
         [DType.BF16, DType.BF16, DType.FP32],
         [DType.FP32, DType.FP32, DType.FP32],
+        [DType.FP8E4M3, DType.FP8E4M3, DType.FP16],
+        [DType.FP8E5M2, DType.FP8E5M2, DType.FP16],
     ]
 
     DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
@@ -3217,7 +3244,7 @@ class TosaTestGen:
                 TosaTensorValuesGen.tvgLazyGenDefault,
                 TosaArgGen.agAxis,
             ),
-            "types": TYPE_NARROW_INT_FP,
+            "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
            "error_if_validators": (
                TosaErrorValidator.evAxisSmallerZero,
                TosaErrorValidator.evAxisLargerRank,
@@ -3244,7 +3271,7 @@ class TosaTestGen:
                 TosaArgGen.agPooling,
             ),
             "qgen": TosaQuantGen.qgUnary,
-            "types": TYPE_NARROW_INT_FP,
+            "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
             "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
             "error_if_validators": (
                 TosaErrorValidator.evKernelSmallerOne,
@@ -3402,7 +3429,7 @@ class TosaTestGen:
                 TosaArgGen.agMatMul,
             ),
             "qgen": TosaQuantGen.qgMatmul,
-            "types": TYPE_NARROW_INT_FP,
+            "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
             "error_if_validators": (
                 TosaErrorValidator.evInputZeroPointNotZero,
                 TosaErrorValidator.evWrongRank,
@@ -3425,7 +3452,7 @@ class TosaTestGen:
                 TosaTensorValuesGen.tvgLazyGenDefault,
                 TosaArgGen.agPooling,
             ),
-            "types": TYPE_NARROW_INT_FP,
+            "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
             "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
             "error_if_validators": (
                 TosaErrorValidator.evKernelSmallerOne,
@@ -4389,7 +4416,7 @@ class TosaTestGen:
                 TosaTensorValuesGen.tvgConcat,
                 TosaArgGen.agAxis,
             ),
-            "types": TYPE_FIB,
+            "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
             "error_if_validators": (
                 TosaErrorValidator.evAxisLargerRank,
                 TosaErrorValidator.evAxisSmallerZero,
@@ -4413,7 +4440,7 @@ class TosaTestGen:
                 TosaTensorValuesGen.tvgPad,
                 TosaArgGen.agPad,
             ),
-            "types": TYPE_FIB,
+            "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
             "error_if_validators": (
                 TosaErrorValidator.evWrongInputType,
                 TosaErrorValidator.evPadSmallerZero,
@@ -4437,7 +4464,7 @@ class TosaTestGen:
                 TosaTensorValuesGen.tvgLazyGenDefault,
                 TosaArgGen.agAxis,
             ),
-            "types": TYPE_FIB,
+            "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
             "error_if_validators": (
                 TosaErrorValidator.evAxisLargerRank,
                 TosaErrorValidator.evAxisSmallerZero,
@@ -4456,7 +4483,7 @@ class TosaTestGen:
                 TosaTensorValuesGen.tvgReshape,
                 TosaArgGen.agReshape,
             ),
-            "types": TYPE_FIB,
+            "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
             "error_if_validators": (
                 TosaErrorValidator.evTensorSizeInputOutputMismatch,
                 TosaErrorValidator.evWrongInputType,
@@ -4477,7 +4504,7 @@ class TosaTestGen:
                 TosaTensorValuesGen.tvgLazyGenDefault,
                 TosaArgGen.agAxis,
             ),
-            "types": TYPE_FIB,
+            "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
             "error_if_validators": (
                 TosaErrorValidator.evAxisSmallerZero,
                 TosaErrorValidator.evAxisLargerRank,
@@ -4500,7 +4527,7 @@ class TosaTestGen:
                 TosaTensorValuesGen.tvgSlice,
                 TosaArgGen.agSlice,
             ),
-            "types": TYPE_FIB,
+            "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
             "error_if_validators": (
                 # TODO Turn off these error categories for now as the reference
                 # model cannot allocate memory space for empty tensor. We probably
@@ -4532,7 +4559,7 @@ class TosaTestGen:
                 TosaTensorValuesGen.tvgTile,
                 TosaArgGen.agTile,
             ),
-            "types": TYPE_FIB,
+            "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
             "error_if_validators": (
                 TosaErrorValidator.evWrongInputType,
                 TosaErrorValidator.evWrongOutputType,
@@ -4555,7 +4582,7 @@ class TosaTestGen:
                 TosaTensorValuesGen.tvgLazyGenDefault,
                 TosaArgGen.agTranspose,
             ),
-            "types": TYPE_FIB,
+            "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
             "error_if_validators": (
                 TosaErrorValidator.evIndexOutsideBounds,
                 TosaErrorValidator.evIndexUsedTwice,
@@ -4581,7 +4608,7 @@ class TosaTestGen:
                 TosaTensorValuesGen.tvgLazyGenDefault,
                 TosaArgGen.agNone,
             ),
-            "types": TYPE_FIB + [DType.INT48],
+            "types": TYPE_FIB + [DType.INT48, DType.FP8E4M3, DType.FP8E5M2],
             "data_gen": {
                 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
             },
@@ -4618,6 +4645,8 @@ class TosaTestGen:
                 DType.FP16,
                 DType.BF16,
                 DType.FP32,
+                DType.FP8E4M3,
+                DType.FP8E5M2,
             ),
             "error_if_validators": (
                 TosaErrorValidator.evWrongInputType,
@@ -4640,7 +4669,7 @@ class TosaTestGen:
                 TosaTensorValuesGen.tvgScatter,
                 TosaArgGen.agNone,
             ),
-            "types": TYPE_INT_FP,
+            "types": TYPE_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
             "error_if_validators": (
                 TosaErrorValidator.evWrongInputType,
                 TosaErrorValidator.evWrongOutputType,
@@ -4709,6 +4738,8 @@ class TosaTestGen:
                 DType.INT16,
                 DType.INT32,
                 DType.BOOL,
+                DType.FP8E4M3,
+                DType.FP8E5M2,
             ),
             "error_if_validators": (
                 TosaErrorValidator.evWrongInputType,
@@ -5141,6 +5172,8 @@ class OutputShaper:
                 DType.FP32,
                 DType.FP16,
                 DType.BF16,
+                DType.FP8E4M3,
+                DType.FP8E5M2,
             ]
             wrong_dtypes = list(set(all_dtypes) - set([DType.INT32]))
             outputDType = rng.choice(wrong_dtypes)
@@ -5194,6 +5227,8 @@ class OutputShaper:
         if error_name == ErrorIf.WrongOutputType:
             if ifm.dtype == DType.FP16:
                 excludes = [DType.FP16, DType.FP32]
+            if ifm.dtype in [DType.FP8E4M3, DType.FP8E5M2]:
+                excludes = [DType.FP16]
             else:
                 excludes = [out_dtype]
             wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
@@ -5344,6 +5379,8 @@ class OutputShaper:
                 DType.FP32,
                 DType.FP16,
                 DType.BF16,
+                DType.FP8E4M3,
+                DType.FP8E5M2,
             ]
             wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype]))
             outputDType = rng.choice(wrong_dtypes)
@@ -5383,6 +5420,8 @@ class OutputShaper:
                     DType.FP32,
                     DType.FP16,
                     DType.BF16,
+                    DType.FP8E4M3,
+                    DType.FP8E5M2,
                 )
             elif a.dtype == DType.INT16:
                 incorrect_types = (
@@ -5393,6 +5432,20 @@ class OutputShaper:
                     DType.FP32,
                     DType.FP16,
                     DType.BF16,
+                    DType.FP8E4M3,
+                    DType.FP8E5M2,
+                )
+            elif a.dtype == DType.FP8E4M3 or a.dtype == DType.FP8E5M2:
+                incorrect_types = (
+                    DType.INT4,
+                    DType.INT8,
+                    DType.INT16,
+                    DType.INT32,
+                    DType.INT48,
+                    DType.FP32,
+                    DType.BF16,
+                    DType.FP8E4M3,
+                    DType.FP8E5M2,
                 )
             elif (
                 a.dtype == DType.FP32 or a.dtype == DType.FP16 or a.dtype == DType.BF16
@@ -5403,6 +5456,8 @@ class OutputShaper:
                     DType.INT16,
                     DType.INT32,
                     DType.INT48,
+                    DType.FP8E4M3,
+                    DType.FP8E5M2,
                 )
             out_dtype = rng.choice(a=incorrect_types)
         elif error_name == ErrorIf.WrongInputType:
@@ -5669,6 +5724,8 @@ class OutputShaper:
                 DType.FP32,
                 DType.FP16,
                 DType.BF16,
+                DType.FP8E4M3,
+                DType.FP8E5M2,
             ]
             wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
             outputDType = rng.choice(wrong_dtypes)
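
Note on the new gtu.vect_f32_to_fp8e4m3 / gtu.vect_f32_to_fp8e5m2 calls above: the generator appears to produce FP8 test data by drawing float32 values and rounding them onto the FP8 grid before handing them back as float32, mirroring the existing BF16 path. The snippet below is a minimal sketch of that round-trip, assuming the third-party ml_dtypes package is available; the actual helpers in the generator's tosa_utils module (gtu) may be implemented differently, and the function names here are illustrative only.

    # Illustrative sketch only (not the gtu implementation): round-trip float32
    # values through FP8 and return them as float32, similar in spirit to how the
    # generator quantises random FP8E4M3 / FP8E5M2 tensor data above.
    import ml_dtypes  # assumption: provides float8_e4m3fn and float8_e5m2 numpy dtypes
    import numpy as np


    def f32_to_fp8e4m3_as_f32(x: np.ndarray) -> np.ndarray:
        # Round to the nearest representable FP8 E4M3 value, then widen back to f32.
        return x.astype(ml_dtypes.float8_e4m3fn).astype(np.float32)


    def f32_to_fp8e5m2_as_f32(x: np.ndarray) -> np.ndarray:
        # Same round-trip for E5M2 (coarser mantissa, wider exponent range).
        return x.astype(ml_dtypes.float8_e5m2).astype(np.float32)


    if __name__ == "__main__":
        rng = np.random.default_rng(0)
        t = rng.uniform(low=-64.0, high=64.0, size=(2, 3)).astype(np.float32)
        print(f32_to_fp8e4m3_as_f32(t))
        print(f32_to_fp8e5m2_as_f32(t))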