# Copyright (c) 2021-2022, ARM Limited. # SPDX-License-Identifier: Apache-2.0 import struct import sys import numpy as np from tosa.DType import DType # Maximum dimension size for output and inputs for RESIZE MAX_RESIZE_DIMENSION = 16384 DTYPE_ATTRIBUTES = { DType.BOOL: {"str": "b", "width": 1}, DType.INT4: {"str": "i4", "width": 4}, DType.INT8: {"str": "i8", "width": 8}, DType.UINT8: {"str": "u8", "width": 8}, DType.INT16: {"str": "i16", "width": 16}, DType.UINT16: {"str": "u16", "width": 16}, DType.INT32: {"str": "i32", "width": 32}, DType.INT48: {"str": "i48", "width": 48}, DType.FP16: {"str": "f16", "width": 16}, DType.BF16: {"str": "bf16", "width": 16}, DType.FP32: {"str": "f32", "width": 32}, } def valueToName(item, value): """Get the name of an attribute with the given value. This convenience function is needed to print meaningful names for the values of the tosa.Op.Op and tosa.DType.DType classes. This would not be necessary if they were subclasses of Enum, or IntEnum, which, sadly, they are not. Args: item: The class, or object, to find the value in value: The value to find Example, to get the name of a DType value: name = valueToName(DType, DType.INT8) # returns 'INT8' name = valueToName(DType, 4) # returns 'INT8' Returns: The name of the first attribute found with a matching value, Raises: ValueError if the value is not found """ for attr in dir(item): if getattr(item, attr) == value: return attr raise ValueError(f"value ({value}) not found") def allDTypes(*, excludes=None): """Get a set of all DType values, optionally excluding some values. This convenience function is needed to provide a sequence of DType values. This would be much easier if DType was a subclass of Enum, or IntEnum, as we could then iterate over the values directly, instead of using dir() to find the attributes and then check if they are what we want. Args: excludes: iterable of DTYPE values (e.g. [DType.INT8, DType.BOOL]) Returns: A set of DType values """ excludes = () if not excludes else excludes return { getattr(DType, t) for t in dir(DType) if not callable(getattr(DType, t)) and not t.startswith("__") and getattr(DType, t) not in excludes } def usableDTypes(*, excludes=None): """Get a set of usable DType values, optionally excluding some values. Excludes uncommon types (DType.UNKNOWN, DType.UINT16, DType.UINT8) in addition to the excludes specified by the caller, as the serializer lib does not support them. If you wish to include 'UNKNOWN', 'UINT8' or 'UINT16' use allDTypes instead. Args: excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL]) Returns: A set of DType values """ omit = {DType.UNKNOWN, DType.UINT8, DType.UINT16} omit.update(excludes if excludes else ()) return allDTypes(excludes=omit) def product(shape): value = 1 for n in shape: value *= n return value def get_accum_dtype_from_tgTypes(dtypes): # Get accumulate data-type from the test generator's defined types assert isinstance(dtypes, list) or isinstance(dtypes, tuple) return dtypes[-1] def get_wrong_output_type(op_name, rng, input_dtype): if op_name == "fully_connected" or op_name == "matmul": if input_dtype == DType.INT8: incorrect_types = ( DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FP32, DType.FP16, ) elif input_dtype == DType.INT16: incorrect_types = ( DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FP32, DType.FP16, ) elif ( input_dtype == DType.FP32 or input_dtype == DType.FP16 or input_dtype == DType.BF16 ): incorrect_types = ( DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48, ) return rng.choice(a=incorrect_types) def float32_is_valid_bfloat16(f): """Return True if float value is valid bfloat16.""" f32_bits = get_float32_bitstring(f) return f32_bits[16:] == "0" * 16 def get_float32_bitstring(f): """Return a big-endian string of bits representing a 32 bit float.""" f32_bits_as_int = struct.unpack(">L", struct.pack(">f", f))[0] return f"{f32_bits_as_int:032b}" def float32_to_bfloat16(f): """Turns fp32 value into bfloat16 by flooring. Floors the least significant 16 bits of the input fp32 value and returns this valid bfloat16 representation as fp32. For simplicity during bit-wrangling, ignores underlying system endianness and interprets as big-endian. Returns a bf16-valid float following system's native byte order. """ f32_bits = get_float32_bitstring(f) f32_floored_bits = f32_bits[:16] + "0" * 16 # Assume sys.byteorder matches system's underlying float byteorder fp_bytes = int(f32_floored_bits, 2).to_bytes(4, byteorder=sys.byteorder) return struct.unpack("@f", fp_bytes)[0] # native byteorder vect_f32_to_bf16 = np.vectorize( float32_to_bfloat16, otypes=(np.float32,) ) # NumPy vectorize: applies function to vector faster than looping