From 24dbc420aae556649f50e645bd94489dab2cc75a Mon Sep 17 00:00:00 2001
From: James Ward
Date: Wed, 19 Oct 2022 12:20:31 +0100
Subject: Add BF16 support to reference model

* Upgrade Eigen to 3.4.0 (for bfloat16 support) and add work-
  arounds for reduce.any() and reduce.all() bugs (introduced
  between 3.3.7 and 3.4.0)
* Truncation to bfloat16 now performed in eval() methods

Signed-off-by: James Ward
Signed-off-by: Jeremy Johnson
Change-Id: If5f5c988d76d3d30790acf3b97081726b89205fe
---
 verif/generator/tosa_test_gen.py | 80 +++++++++++++++++++++++++++++++++++-----
 1 file changed, 70 insertions(+), 10 deletions(-)

(limited to 'verif/generator/tosa_test_gen.py')

diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 78d86cd..95e06ed 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -16,6 +16,7 @@
 from generator.tosa_error_if import TosaInvalidValidator
 from generator.tosa_utils import DTYPE_ATTRIBUTES
 from generator.tosa_utils import MAX_RESIZE_DIMENSION
 from generator.tosa_utils import usableDTypes
+from generator.tosa_utils import vect_f32_to_bf16
 from tosa.DType import DType
 from tosa.Op import Op
@@ -84,6 +85,10 @@ class TosaTestGen:
             )
         elif dtype == DType.FP16:
             return np.float16(self.rng.random(size=shape))
+        elif dtype == DType.BF16:
+            f32_tensor = np.float32(self.rng.random(size=shape))
+            # Floor the last 16 bits of each f32 value
+            return np.float32(vect_f32_to_bf16(f32_tensor))
         elif dtype == DType.FP32:
             return np.float32(self.rng.random(size=shape))
         else:
@@ -134,6 +139,9 @@ class TosaTestGen:
         elif dtype == DType.FP16:
             rand_f32 = self.rng.random()
             return np.float16(rand_f32)
+        elif dtype == DType.BF16:
+            rand_f32 = self.rng.random()
+            return vect_f32_to_bf16(rand_f32)
         elif dtype == DType.BOOL:
             return self.rng.choice([False, True])
         # TOSA specific INT4 weight range from -7 to 7
@@ -324,7 +332,7 @@ class TosaTestGen:
         # Special for multiply:
         # Force the result to INT32 for INT types
-        if a.dtype not in (DType.FP16, DType.FP32):
+        if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
             result_tens.setDtype(DType.INT32)

         if error_name == ErrorIf.WrongOutputType:
             all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
@@ -1043,7 +1051,7 @@ class TosaTestGen:
             return None

         attr = ts.TosaSerializerAttribute()
-        if a.dtype in (DType.FP16, DType.FP32):
+        if a.dtype in (DType.FP16, DType.BF16, DType.FP32):
            attr.ClampAttribute(0, 0, min_val, max_val)
         else:
            attr.ClampAttribute(min_val, max_val, 0, 0)
@@ -1859,7 +1867,7 @@ class TosaTestGen:
             op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
         )

-        if a.dtype in (DType.FP32, DType.FP16, DType.INT32):
+        if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
             then_op, else_op = Op.ADD, Op.SUB
         elif a.dtype in (DType.INT8, DType.INT16):
             then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
@@ -2398,7 +2406,7 @@ class TosaTestGen:
     # if not specified, defaults to (1, 4)
     # 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
     # 'types': array of datatypes to be tested
-    TYPE_FP = [DType.FP32, DType.FP16]
+    TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]
     TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
     TYPE_INT_FP = [
         DType.INT8,
@@ -2406,13 +2414,20 @@ class TosaTestGen:
         DType.INT16,
         DType.INT32,
         DType.FP16,
+        DType.BF16,
         DType.FP32,
     ]  # Excludes INT4
     TYPE_BOOL = [DType.BOOL]
-    TYPE_FI32 = [DType.FP32, DType.FP16, DType.INT32]  # floating-types and INT32
+    TYPE_FI32 = [
+        DType.FP32,
+        DType.FP16,
+        DType.BF16,
+        DType.INT32,
+    ]  # floating-types and INT32
     TYPE_FIB = [
         DType.FP16,
+        DType.BF16,
         DType.FP32,
         DType.INT8,
         DType.INT16,
@@ -2421,7 +2436,7 @@ class TosaTestGen:
     ]
     TYPE_FI16 = [DType.FP32, DType.INT16]

-    TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.FP32]
+    TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]

     # List of [Input Type 1, Input Type 2, Accumulator Type]
     TYPE_CONV = [
@@ -2430,6 +2445,7 @@ class TosaTestGen:
         [DType.INT16, DType.INT8, DType.INT48],
         [DType.FP16, DType.FP16, DType.FP16],
         [DType.FP16, DType.FP16, DType.FP32],
+        [DType.BF16, DType.BF16, DType.FP32],
         [DType.FP32, DType.FP32, DType.FP32],
     ]
@@ -3448,7 +3464,7 @@ class TosaTestGen:
                 TosaTensorValuesGen.tvgReduceSum,
                 TosaArgGen.agAxis,
             ),
-            "types": (DType.FP16, DType.FP32, DType.INT32),
+            "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
             "error_if_validators": (
                 TosaErrorValidator.evAxisLargerRank,
                 TosaErrorValidator.evAxisSmallerZero,
@@ -3635,7 +3651,14 @@ class TosaTestGen:
                 TosaTensorValuesGen.tvgDefault,
                 None,
             ),
-            "types": (DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.FP32),
+            "types": (
+                DType.INT8,
+                DType.INT16,
+                DType.INT32,
+                DType.FP16,
+                DType.BF16,
+                DType.FP32,
+            ),
             "error_if_validators": (
                 TosaErrorValidator.evWrongInputType,
                 TosaErrorValidator.evWrongOutputType,
@@ -3676,7 +3699,7 @@ class TosaTestGen:
                 TosaTensorValuesGen.tvgDefault,
                 TosaArgGen.agResize,
             ),
-            "types": (DType.INT8, DType.INT16, DType.FP16, DType.FP32),
+            "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
             "invalid_test_validators": (
                 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
             ),
@@ -3712,6 +3735,7 @@ class TosaTestGen:
             ),
             "types": (
                 DType.FP16,
+                DType.BF16,
                 DType.FP32,
                 DType.INT8,
                 DType.INT16,
@@ -3842,6 +3866,8 @@ class OutputShaper:
                 DType.INT16,
                 DType.INT32,
                 DType.INT48,
+                DType.FP16,
+                DType.BF16,
                 DType.FP32,
             ]
             wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
@@ -3872,6 +3898,8 @@ class OutputShaper:
                 DType.INT32,
                 DType.INT48,
                 DType.FP32,
+                DType.FP16,
+                DType.BF16,
             ]
             wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
             outputDType = rng.choice(wrong_dtypes)
@@ -3900,6 +3928,8 @@ class OutputShaper:
                 DType.INT32,
                 DType.INT48,
                 DType.FP32,
+                DType.FP16,
+                DType.BF16,
             ]
             wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
             outputDType = rng.choice(wrong_dtypes)
@@ -3929,6 +3959,8 @@ class OutputShaper:
                 DType.INT32,
                 DType.INT48,
                 DType.FP32,
+                DType.FP16,
+                DType.BF16,
             ]
             outputDType = rng.choice(wrong_dtypes)
         else:
@@ -3955,6 +3987,8 @@ class OutputShaper:
                 DType.INT32,
                 DType.INT48,
                 DType.FP32,
+                DType.FP16,
+                DType.BF16,
             ]
             wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
             outputDType = rng.choice(wrong_dtypes)
@@ -3987,6 +4021,8 @@ class OutputShaper:
                 DType.INT32,
                 DType.INT48,
                 DType.FP32,
+                DType.FP16,
+                DType.BF16,
             ]
             wrong_dtypes = list(set(all_dtypes) - set([DType.INT32]))
             outputDType = rng.choice(wrong_dtypes)
@@ -4189,6 +4225,7 @@ class OutputShaper:
                 DType.INT48,
                 DType.FP32,
                 DType.FP16,
+                DType.BF16,
             ]
             wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype]))
             outputDType = rng.choice(wrong_dtypes)
@@ -4226,6 +4263,8 @@ class OutputShaper:
                 DType.INT16,
                 DType.INT48,
                 DType.FP32,
+                DType.FP16,
+                DType.BF16,
             )
         elif a.dtype == DType.INT16:
             incorrect_types = (
@@ -4234,8 +4273,12 @@ class OutputShaper:
                 DType.INT16,
                 DType.INT32,
                 DType.FP32,
+                DType.FP16,
+                DType.BF16,
             )
-        elif a.dtype == DType.FP32 or a.dtype == DType.FP16:
+        elif (
+            a.dtype == DType.FP32 or a.dtype == DType.FP16 or a.dtype == DType.BF16
+        ):
             incorrect_types = (
                 DType.INT4,
                 DType.INT8,
@@ -4278,6 +4321,8 @@ class OutputShaper:
                 DType.INT32,
                 DType.INT48,
                 DType.FP32,
+                DType.FP16,
+                DType.BF16,
             }
             wrong_dtypes = list(all_dtypes - set([input1.dtype]))
             outputDType = rng.choice(wrong_dtypes)
@@ -4306,6 +4351,7 @@ class OutputShaper:
                 DType.INT48,
                 DType.FP32,
                 DType.FP16,
+                DType.BF16,
             ]
             wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
             outputDType = rng.choice(wrong_dtypes)
@@ -4329,6 +4375,8 @@ class OutputShaper:
                 DType.INT32,
                 DType.INT48,
                 DType.FP32,
+                DType.FP16,
+                DType.BF16,
             ]
             wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
             outputDType = rng.choice(wrong_dtypes)
@@ -4347,6 +4395,8 @@ class OutputShaper:
                 DType.INT32,
                 DType.INT48,
                 DType.FP32,
+                DType.FP16,
+                DType.BF16,
             ]
             wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
             outputDType = rng.choice(wrong_dtypes)
@@ -4383,6 +4433,8 @@ class OutputShaper:
                 DType.INT32,
                 DType.INT48,
                 DType.FP32,
+                DType.FP16,
+                DType.BF16,
             ]
             wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
             outputDType = rng.choice(wrong_dtypes)
@@ -4411,6 +4463,8 @@ class OutputShaper:
                 DType.INT32,
                 DType.INT48,
                 DType.FP32,
+                DType.FP16,
+                DType.BF16,
             ]
             wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
             outputDType = rng.choice(wrong_dtypes)
@@ -4435,6 +4489,8 @@ class OutputShaper:
                 DType.INT32,
                 DType.INT48,
                 DType.FP32,
+                DType.FP16,
+                DType.BF16,
             ]
             wrong_dtypes = list(set(all_dtypes) - set([values.dtype]))
             outputDType = rng.choice(wrong_dtypes)
@@ -4462,6 +4518,8 @@ class OutputShaper:
                 DType.INT32,
                 DType.INT48,
                 DType.FP32,
+                DType.FP16,
+                DType.BF16,
             ]
             wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
             outputDType = rng.choice(wrong_dtypes)
@@ -4483,6 +4541,8 @@ class OutputShaper:
                 DType.INT32,
                 DType.INT48,
                 DType.FP32,
+                DType.FP16,
+                DType.BF16,
             ]
             wrong_dtypes.remove(output_dtype)
             output_dtype = rng.choice(wrong_dtypes)
--
cgit v1.2.1
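
Note on vect_f32_to_bf16: the helper is imported from generator/tosa_utils.py,
which sits outside this filtered view, so its definition does not appear in the
patch. Going by the "Floor the last 16 bits of each f32 value" comment in the
getRandTensor hunk above, a minimal sketch of the truncation it performs (an
illustration under that assumption, not the repository's actual implementation)
could be:

    import numpy as np

    def f32_to_bf16_trunc(x):
        # Illustrative stand-in for generator.tosa_utils.vect_f32_to_bf16,
        # not the real code. View each float32 bit pattern as uint32 and
        # zero the low 16 bits; the surviving high 16 bits are exactly a
        # bfloat16 bit pattern (round-toward-zero on the mantissa), still
        # stored in a float32 container.
        u32 = np.asarray(x, dtype=np.float32).view(np.uint32)
        return (u32 & np.uint32(0xFFFF0000)).view(np.float32)

Because the result is wrapped back in np.float32 (see the getRandTensor hunk),
the generated BF16 test tensors hold only bfloat16-representable values while
remaining serializable as ordinary f32 data.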