diff options
Diffstat (limited to 'verif/generator')
-rw-r--r-- | verif/generator/tosa_arg_gen.py | 10 | ||||
-rw-r--r-- | verif/generator/tosa_error_if.py | 35 | ||||
-rw-r--r-- | verif/generator/tosa_test_gen.py | 80 | ||||
-rw-r--r-- | verif/generator/tosa_utils.py | 45 | ||||
-rw-r--r-- | verif/generator/tosa_verif_build_tests.py | 4 |
5 files changed, 155 insertions, 19 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index 0203513..932ad55 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -776,7 +776,7 @@ class TosaTensorValuesGen: ), "Op.MUL must have 2 placeholders, 0 consts" tens = [] - if dtypeList[0] in (DType.FP16, DType.FP32): + if dtypeList[0] in (DType.FP16, DType.BF16, DType.FP32): tens.extend(testGen.buildPlaceholderTensors(shapeList[:], dtypeList[:])) else: placeholders = [] @@ -1130,6 +1130,8 @@ class TosaArgGen: accum_dtypes = [DType.INT48] elif dtype == DType.FP16: accum_dtypes = [DType.FP16, DType.FP32] + elif dtype == DType.BF16: + accum_dtypes = [DType.FP32] elif dtype == DType.FP32: accum_dtypes = [DType.FP32] elif error_name is None: @@ -1304,7 +1306,7 @@ class TosaArgGen: accum_dtypes = [DType.INT32] elif dtype == DType.FP16: accum_dtypes = [DType.FP16, DType.FP32] - elif dtype == DType.FP32: + elif dtype == DType.BF16 or dtype == DType.FP32: accum_dtypes = [DType.FP32] elif error_name is None: assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}" @@ -1417,6 +1419,8 @@ class TosaArgGen: dtypeList = [DType.INT8, DType.INT16, DType.INT32] elif inDtype == DType.FP16: dtypeList = [DType.INT8, DType.INT16, DType.INT32] + elif inDtype == DType.BF16: + dtypeList = [DType.INT8, DType.INT16, DType.INT32] elif inDtype == DType.FP32: dtypeList = [DType.INT8, DType.INT16, DType.INT32] elif error_name == ErrorIf.WrongInputType: @@ -1826,6 +1830,8 @@ class TosaArgGen: outputDTypeList = [DType.INT48] elif dtype == DType.FP16: outputDTypeList = [DType.FP16] + elif dtype == DType.BF16: + outputDTypeList = [DType.BF16] elif dtype == DType.FP32: outputDTypeList = [DType.FP32] elif error_name == ErrorIf.WrongInputType: diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py index abe1a97..a850699 100644 --- a/verif/generator/tosa_error_if.py +++ b/verif/generator/tosa_error_if.py @@ -158,6 +158,15 @@ class TosaErrorIfArgGen: DType.INT48, DType.FP32, ) + elif dtype == DType.BF16: + incorrect_types = ( + DType.INT4, + DType.INT8, + DType.INT16, + DType.INT32, + DType.INT48, + DType.FP32, + ) elif dtype == DType.FP32: incorrect_types = ( DType.INT4, @@ -299,8 +308,8 @@ class TosaErrorIfArgGen: @staticmethod def eiCastErrorIf(testGen, input_dtype): - if input_dtype in [DType.BOOL, DType.FP16, DType.FP32]: - outputDType = [DType.BOOL, DType.INT48, DType.FP16, DType.FP32] + if input_dtype in [DType.BOOL, DType.FP16, DType.BF16, DType.FP32]: + outputDType = [DType.BOOL, DType.INT48, DType.FP16, DType.BF16, DType.FP32] elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]: outputDType = [DType.INT48] else: @@ -425,6 +434,7 @@ class TosaErrorValidator: and output_dtype != DType.INT48 ) or (input_dtype == DType.FP16 and output_dtype != DType.FP16) + or (input_dtype == DType.BF16 and output_dtype != DType.BF16) or (input_dtype == DType.FP32 and output_dtype != DType.FP32) ): error_result = True @@ -442,25 +452,29 @@ class TosaErrorValidator: input_dtype == DType.FP16 and output_dtype not in (DType.FP16, DType.FP32) ) + or (input_dtype == DType.BF16 and output_dtype != DType.FP32) or (input_dtype == DType.FP32 and output_dtype != DType.FP32) ): error_result = True elif op["op"] == Op.ARGMAX: if ( - input_dtype in [DType.INT8, DType.INT16, DType.FP16, DType.FP32] + input_dtype + in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32] and output_dtype != DType.INT32 ): error_result = True elif op["op"] == Op.MUL: if ( - input_dtype not in (DType.FP16, DType.FP32) + input_dtype not in (DType.FP16, DType.BF16, DType.FP32) and output_dtype != DType.INT32 ): error_result = True elif input_dtype == DType.FP16 and output_dtype != DType.FP16: error_result = True + elif input_dtype == DType.BF16 and output_dtype != DType.BF16: + error_result = True elif input_dtype == DType.FP32 and output_dtype != DType.FP32: error_result = True @@ -489,6 +503,7 @@ class TosaErrorValidator: DType.INT32, DType.FP32, DType.FP16, + DType.BF16, ] ) or ( @@ -500,6 +515,7 @@ class TosaErrorValidator: DType.INT32, DType.FP32, DType.FP16, + DType.BF16, ] ) or ( @@ -511,6 +527,7 @@ class TosaErrorValidator: DType.INT16, DType.FP32, DType.FP16, + DType.BF16, ] ) or ( @@ -518,6 +535,10 @@ class TosaErrorValidator: and output_dtype not in [DType.INT8, DType.INT16, DType.INT32] ) or ( + input_dtype == DType.BF16 + and output_dtype not in [DType.INT8, DType.INT16, DType.INT32] + ) + or ( input_dtype == DType.FP32 and output_dtype not in [DType.INT8, DType.INT16, DType.INT32] ) @@ -537,6 +558,8 @@ class TosaErrorValidator: and output_dtype != DType.INT48 or input_dtype == DType.FP16 and output_dtype not in (DType.FP16, DType.FP32) + or input_dtype == DType.BF16 + and output_dtype != DType.FP32 or input_dtype == DType.FP32 and output_dtype != DType.FP32 ): @@ -2316,12 +2339,14 @@ class TosaInvalidValidator: not (input_dtype == DType.INT8 and output_dtype == DType.INT32) and not (input_dtype == DType.INT16 and output_dtype == DType.INT48) and not (input_dtype == DType.FP16 and output_dtype == DType.FP16) + and not (input_dtype == DType.BF16 and output_dtype == DType.BF16) and not (input_dtype == DType.FP32 and output_dtype == DType.FP32) ) elif mode == ResizeMode.NEAREST: # Invalid output data type / Invalid input datatype return (input_dtype != output_dtype) or ( - input_dtype not in [DType.INT8, DType.INT16, DType.FP16, DType.FP32] + input_dtype + not in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32] ) else: # Invalid resize mode diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 78d86cd..95e06ed 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -16,6 +16,7 @@ from generator.tosa_error_if import TosaInvalidValidator from generator.tosa_utils import DTYPE_ATTRIBUTES from generator.tosa_utils import MAX_RESIZE_DIMENSION from generator.tosa_utils import usableDTypes +from generator.tosa_utils import vect_f32_to_bf16 from tosa.DType import DType from tosa.Op import Op @@ -84,6 +85,10 @@ class TosaTestGen: ) elif dtype == DType.FP16: return np.float16(self.rng.random(size=shape)) + elif dtype == DType.BF16: + f32_tensor = np.float32(self.rng.random(size=shape)) + # Floor the last 16 bits of each f32 value + return np.float32(vect_f32_to_bf16(f32_tensor)) elif dtype == DType.FP32: return np.float32(self.rng.random(size=shape)) else: @@ -134,6 +139,9 @@ class TosaTestGen: elif dtype == DType.FP16: rand_f32 = self.rng.random() return np.float16(rand_f32) + elif dtype == DType.BF16: + rand_f32 = self.rng.random() + return vect_f32_to_bf16(rand_f32) elif dtype == DType.BOOL: return self.rng.choice([False, True]) # TOSA specific INT4 weight range from -7 to 7 @@ -324,7 +332,7 @@ class TosaTestGen: # Special for multiply: # Force the result to INT32 for INT types - if a.dtype not in (DType.FP16, DType.FP32): + if a.dtype not in (DType.FP16, DType.BF16, DType.FP32): result_tens.setDtype(DType.INT32) if error_name == ErrorIf.WrongOutputType: all_dtypes = [DType.INT8, DType.INT16, DType.INT48] @@ -1043,7 +1051,7 @@ class TosaTestGen: return None attr = ts.TosaSerializerAttribute() - if a.dtype in (DType.FP16, DType.FP32): + if a.dtype in (DType.FP16, DType.BF16, DType.FP32): attr.ClampAttribute(0, 0, min_val, max_val) else: attr.ClampAttribute(min_val, max_val, 0, 0) @@ -1859,7 +1867,7 @@ class TosaTestGen: op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr ) - if a.dtype in (DType.FP32, DType.FP16, DType.INT32): + if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32): then_op, else_op = Op.ADD, Op.SUB elif a.dtype in (DType.INT8, DType.INT16): then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT @@ -2398,7 +2406,7 @@ class TosaTestGen: # if not specified, defaults to (1, 4) # 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum) # 'types': array of datatypes to be tested - TYPE_FP = [DType.FP32, DType.FP16] + TYPE_FP = [DType.FP32, DType.FP16, DType.BF16] TYPE_INT = [DType.INT8, DType.INT16, DType.INT32] # Excludes INT4 TYPE_INT_FP = [ @@ -2406,13 +2414,20 @@ class TosaTestGen: DType.INT16, DType.INT32, DType.FP16, + DType.BF16, DType.FP32, ] # Excludes INT4 TYPE_BOOL = [DType.BOOL] - TYPE_FI32 = [DType.FP32, DType.FP16, DType.INT32] # floating-types and INT32 + TYPE_FI32 = [ + DType.FP32, + DType.FP16, + DType.BF16, + DType.INT32, + ] # floating-types and INT32 TYPE_FIB = [ DType.FP16, + DType.BF16, DType.FP32, DType.INT8, DType.INT16, @@ -2421,7 +2436,7 @@ class TosaTestGen: ] TYPE_FI16 = [DType.FP32, DType.INT16] - TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.FP32] + TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32] # List of [Input Type 1, Input Type 2, Accumulator Type] TYPE_CONV = [ @@ -2430,6 +2445,7 @@ class TosaTestGen: [DType.INT16, DType.INT8, DType.INT48], [DType.FP16, DType.FP16, DType.FP16], [DType.FP16, DType.FP16, DType.FP32], + [DType.BF16, DType.BF16, DType.FP32], [DType.FP32, DType.FP32, DType.FP32], ] @@ -3448,7 +3464,7 @@ class TosaTestGen: TosaTensorValuesGen.tvgReduceSum, TosaArgGen.agAxis, ), - "types": (DType.FP16, DType.FP32, DType.INT32), + "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32), "error_if_validators": ( TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, @@ -3635,7 +3651,14 @@ class TosaTestGen: TosaTensorValuesGen.tvgDefault, None, ), - "types": (DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.FP32), + "types": ( + DType.INT8, + DType.INT16, + DType.INT32, + DType.FP16, + DType.BF16, + DType.FP32, + ), "error_if_validators": ( TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, @@ -3676,7 +3699,7 @@ class TosaTestGen: TosaTensorValuesGen.tvgDefault, TosaArgGen.agResize, ), - "types": (DType.INT8, DType.INT16, DType.FP16, DType.FP32), + "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32), "invalid_test_validators": ( TosaInvalidValidator.ivWrongDataTypeOrModeResize, ), @@ -3712,6 +3735,7 @@ class TosaTestGen: ), "types": ( DType.FP16, + DType.BF16, DType.FP32, DType.INT8, DType.INT16, @@ -3842,6 +3866,8 @@ class OutputShaper: DType.INT16, DType.INT32, DType.INT48, + DType.FP16, + DType.BF16, DType.FP32, ] wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) @@ -3872,6 +3898,8 @@ class OutputShaper: DType.INT32, DType.INT48, DType.FP32, + DType.FP16, + DType.BF16, ] wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) outputDType = rng.choice(wrong_dtypes) @@ -3900,6 +3928,8 @@ class OutputShaper: DType.INT32, DType.INT48, DType.FP32, + DType.FP16, + DType.BF16, ] wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) outputDType = rng.choice(wrong_dtypes) @@ -3929,6 +3959,8 @@ class OutputShaper: DType.INT32, DType.INT48, DType.FP32, + DType.FP16, + DType.BF16, ] outputDType = rng.choice(wrong_dtypes) else: @@ -3955,6 +3987,8 @@ class OutputShaper: DType.INT32, DType.INT48, DType.FP32, + DType.FP16, + DType.BF16, ] wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) outputDType = rng.choice(wrong_dtypes) @@ -3987,6 +4021,8 @@ class OutputShaper: DType.INT32, DType.INT48, DType.FP32, + DType.FP16, + DType.BF16, ] wrong_dtypes = list(set(all_dtypes) - set([DType.INT32])) outputDType = rng.choice(wrong_dtypes) @@ -4189,6 +4225,7 @@ class OutputShaper: DType.INT48, DType.FP32, DType.FP16, + DType.BF16, ] wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype])) outputDType = rng.choice(wrong_dtypes) @@ -4226,6 +4263,8 @@ class OutputShaper: DType.INT16, DType.INT48, DType.FP32, + DType.FP16, + DType.BF16, ) elif a.dtype == DType.INT16: incorrect_types = ( @@ -4234,8 +4273,12 @@ class OutputShaper: DType.INT16, DType.INT32, DType.FP32, + DType.FP16, + DType.BF16, ) - elif a.dtype == DType.FP32 or a.dtype == DType.FP16: + elif ( + a.dtype == DType.FP32 or a.dtype == DType.FP16 or a.dtype == DType.BF16 + ): incorrect_types = ( DType.INT4, DType.INT8, @@ -4278,6 +4321,8 @@ class OutputShaper: DType.INT32, DType.INT48, DType.FP32, + DType.FP16, + DType.BF16, } wrong_dtypes = list(all_dtypes - set([input1.dtype])) outputDType = rng.choice(wrong_dtypes) @@ -4306,6 +4351,7 @@ class OutputShaper: DType.INT48, DType.FP32, DType.FP16, + DType.BF16, ] wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) outputDType = rng.choice(wrong_dtypes) @@ -4329,6 +4375,8 @@ class OutputShaper: DType.INT32, DType.INT48, DType.FP32, + DType.FP16, + DType.BF16, ] wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) outputDType = rng.choice(wrong_dtypes) @@ -4347,6 +4395,8 @@ class OutputShaper: DType.INT32, DType.INT48, DType.FP32, + DType.FP16, + DType.BF16, ] wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) outputDType = rng.choice(wrong_dtypes) @@ -4383,6 +4433,8 @@ class OutputShaper: DType.INT32, DType.INT48, DType.FP32, + DType.FP16, + DType.BF16, ] wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) outputDType = rng.choice(wrong_dtypes) @@ -4411,6 +4463,8 @@ class OutputShaper: DType.INT32, DType.INT48, DType.FP32, + DType.FP16, + DType.BF16, ] wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) outputDType = rng.choice(wrong_dtypes) @@ -4435,6 +4489,8 @@ class OutputShaper: DType.INT32, DType.INT48, DType.FP32, + DType.FP16, + DType.BF16, ] wrong_dtypes = list(set(all_dtypes) - set([values.dtype])) outputDType = rng.choice(wrong_dtypes) @@ -4462,6 +4518,8 @@ class OutputShaper: DType.INT32, DType.INT48, DType.FP32, + DType.FP16, + DType.BF16, ] wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype])) outputDType = rng.choice(wrong_dtypes) @@ -4483,6 +4541,8 @@ class OutputShaper: DType.INT32, DType.INT48, DType.FP32, + DType.FP16, + DType.BF16, ] wrong_dtypes.remove(output_dtype) output_dtype = rng.choice(wrong_dtypes) diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py index 104d9bb..d79ab3c 100644 --- a/verif/generator/tosa_utils.py +++ b/verif/generator/tosa_utils.py @@ -1,5 +1,9 @@ # Copyright (c) 2021-2022, ARM Limited. # SPDX-License-Identifier: Apache-2.0 +import struct +import sys + +import numpy as np from tosa.DType import DType # Maximum dimension size for output and inputs for RESIZE @@ -15,6 +19,7 @@ DTYPE_ATTRIBUTES = { DType.INT32: {"str": "i32", "width": 32}, DType.INT48: {"str": "i48", "width": 48}, DType.FP16: {"str": "f16", "width": 16}, + DType.BF16: {"str": "bf16", "width": 16}, DType.FP32: {"str": "f32", "width": 32}, } @@ -125,7 +130,11 @@ def get_wrong_output_type(op_name, rng, input_dtype): DType.FP32, DType.FP16, ) - elif input_dtype == DType.FP32 or input_dtype == DType.FP16: + elif ( + input_dtype == DType.FP32 + or input_dtype == DType.FP16 + or input_dtype == DType.BF16 + ): incorrect_types = ( DType.INT4, DType.INT8, @@ -134,3 +143,37 @@ def get_wrong_output_type(op_name, rng, input_dtype): DType.INT48, ) return rng.choice(a=incorrect_types) + + +def float32_is_valid_bfloat16(f): + """Return True if float value is valid bfloat16.""" + f32_bits = get_float32_bitstring(f) + return f32_bits[16:] == "0" * 16 + + +def get_float32_bitstring(f): + """Return a big-endian string of bits representing a 32 bit float.""" + f32_bits_as_int = struct.unpack(">L", struct.pack(">f", f))[0] + return f"{f32_bits_as_int:032b}" + + +def float32_to_bfloat16(f): + """Turns fp32 value into bfloat16 by flooring. + + Floors the least significant 16 bits of the input + fp32 value and returns this valid bfloat16 representation as fp32. + For simplicity during bit-wrangling, ignores underlying system + endianness and interprets as big-endian. + Returns a bf16-valid float following system's native byte order. + """ + f32_bits = get_float32_bitstring(f) + f32_floored_bits = f32_bits[:16] + "0" * 16 + + # Assume sys.byteorder matches system's underlying float byteorder + fp_bytes = int(f32_floored_bits, 2).to_bytes(4, byteorder=sys.byteorder) + return struct.unpack("@f", fp_bytes)[0] # native byteorder + + +vect_f32_to_bf16 = np.vectorize( + float32_to_bfloat16, otypes=(np.float32,) +) # NumPy vectorize: applies function to vector faster than looping diff --git a/verif/generator/tosa_verif_build_tests.py b/verif/generator/tosa_verif_build_tests.py index 2fafacb..ab78b1a 100644 --- a/verif/generator/tosa_verif_build_tests.py +++ b/verif/generator/tosa_verif_build_tests.py @@ -5,6 +5,7 @@ import re from generator.tosa_test_gen import TosaTestGen from serializer.tosa_serializer import dtype_str_to_val +from serializer.tosa_serializer import DTypeNames # Used for parsing a comma-separated list of integers in a string @@ -150,13 +151,14 @@ def parseArgs(argv): help="Create tests with a particular input tensor rank", ) + # Used for parsing a comma-separated list of integers in a string parser.add_argument( "--target-dtype", dest="target_dtypes", action="append", default=None, type=lambda x: dtype_str_to_val(x), - help="Create test with a particular DType (may be repeated)", + help=f"Create test with a particular DType: [{', '.join([d.lower() for d in DTypeNames[1:]])}] (may be repeated)", ) parser.add_argument( |