# Copyright (c) 2021-2024, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import itertools
import math
import warnings

import generator.tosa_utils as gtu
import numpy as np
from generator.tosa_error_if import ErrorIf
from generator.tosa_error_if import TosaErrorIfArgGen
from serializer.tosa_serializer import DTypeNames
from tosa.DType import DType
from tosa.Op import Op
from tosa.ResizeMode import ResizeMode

# DTypeNames, DType, Op and ResizeMode are convenience variables to the
# flatc-generated types that should be enums, but aren't


class TosaQuantGen:
    """QuantizedInfo random generator helper functions.

    Specify with 'qgen': in the operator definition.
    """

    def __init__(self):
        pass

    @staticmethod
    def getZeroPoint(testGen, dtype, error_name=None):
        if dtype == DType.INT8:
            if testGen.args.zeropoint is not None:
                return min(127, max(-128, testGen.args.zeropoint))
            return testGen.randInt(-128, 128)
        elif dtype == DType.UINT8:
            if testGen.args.zeropoint is not None:
                return min(255, max(0, testGen.args.zeropoint))
            return testGen.randInt(0, 256)
        elif error_name in [
            ErrorIf.InputZeroPointNotZero,
            ErrorIf.WeightZeroPointNotZero,
            ErrorIf.OutputZeroPointNotZero,
        ]:
            zero_point = testGen.randInt(-128, 128)
            if zero_point == 0:
                zero_point = 1
            return zero_point
        return 0

    @staticmethod
    def qgUnary(testGen, op, dtype, error_name=None):
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
                TosaQuantGen.getZeroPoint(testGen, dtype),
            ]
        elif error_name == ErrorIf.OutputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype),
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype),
                TosaQuantGen.getZeroPoint(testGen, dtype),
            ]
        return qinfo

    @staticmethod
    def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
        if isinstance(dtype_or_dtypeList, list):
            # a list of [input, weights, accumulator] dtypes
            dtypeList = dtype_or_dtypeList
        else:
            # an int, [input, weights, accumulator] dtypes are the same
            dtypeList = [dtype_or_dtypeList] * 3

        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtypeList[0], error_name),
                TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
            ]
        elif error_name == ErrorIf.WeightZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
                TosaQuantGen.getZeroPoint(testGen, dtypeList[1], error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
                TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
            ]
        return qinfo

    @staticmethod
    def qgMatmul(testGen, op, dtype, error_name=None):
        if error_name == ErrorIf.InputZeroPointNotZero:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
                TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
            ]
        else:
            qinfo = [
                TosaQuantGen.getZeroPoint(testGen, dtype),
                TosaQuantGen.getZeroPoint(testGen, dtype),
            ]
        return qinfo

    @staticmethod
    def computeMultiplierAndShift(scaleFp, scale32):
        # Derived from computeMultiplierAndShiftTosaScale32
        # Provide a floating-point scaling factor and the scale32 parameter
        # to compute the multiplier and shift
        if scale32:
            scaleBits = 31
        else:
            scaleBits = 15
        m, shift = math.frexp(scaleFp)
        if scaleFp < 0.0:
            m = -m
        multiplier = round(m * (1 << scaleBits))
        assert multiplier <= (1 << scaleBits)
        if multiplier == (1 << scaleBits):
            multiplier = multiplier // 2
            shift = shift + 1
        shift = (-shift) + scaleBits
        # print('scalefp {} scaleBits {} m {} mult {} shift {}'.format(
        #     scaleFp, scaleBits, m, multiplier, shift))
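        # Worked example (illustrative): scaleFp=0.5 with scale32=True gives
        # m=0.5 and a frexp shift of 0, so multiplier=1<<30 and the final
        # shift is 31, i.e. the scale is recovered as multiplier * 2**-31 == 0.5.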
        # Adjust multiplier such that shift is in allowed value range.
        if shift == 0:
            multiplier = multiplier // 4
            shift = shift + 2
        elif shift == 1:
            multiplier = multiplier // 2
            shift = shift + 1
        elif shift == 63:
            multiplier = multiplier * 2
            shift = shift - 1

        assert multiplier <= (1 << scaleBits)
        assert shift >= 2 and shift <= 62

        return multiplier, shift


class TosaTensorGen:
    """Tensor generators create a shape list for the placeholder and const tensor
    data operands for the operator.

    The actual random data is generated separately for each test.
    """

    def __init__(self):
        pass

    @staticmethod
    def tgBasic(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

            # Generates an input rank mismatch for operators with more than one input
            if error_name == ErrorIf.RankMismatch:
                if rank == 1 and i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
                elif i != 1:
                    shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))

        return shape_list

    @staticmethod
    def tgNHWC(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        shape = testGen.makeShape(rank)
        shape = testGen.constrictBatchSize(shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name and error_name != ErrorIf.MaxDimExceeded:
            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)

        shape_list = []
        for i in range(pl + const):
            shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgGather(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]

        assert pl == 2
        assert const == 0

        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_shape = testGen.makeShape(rank)
        values_shape = testGen.constrictBatchSize(values_shape)

        N = values_shape[0]
        W = testGen.makeDimension()
        indices_shape = [N, W]

        shape_list = [values_shape, indices_shape]
        return shape_list

    @staticmethod
    def tgScatter(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]

        assert pl == 3
        assert const == 0

        if error_name != ErrorIf.WrongRank:
            assert rank == 3

        values_in_shape = testGen.makeShape(rank)
        values_in_shape = testGen.constrictBatchSize(values_in_shape)

        N = values_in_shape[0]
        K = values_in_shape[1]
        C = values_in_shape[2]

        # Make sure W is not greater than K, as we can only write each output index
        # once (having a W greater than K means that you have to repeat a K index)
        W_min = min(testGen.args.tensor_shape_range[0], K)
        W_max = min(testGen.args.tensor_shape_range[1], K)
        W = testGen.randInt(W_min, W_max) if W_min < W_max else W_min

        input_shape = [N, W, C]

        shape_list = []
        shape_list.append(values_in_shape)
        shape_list.append([N, W])  # indices
        shape_list.append(input_shape)

        return shape_list
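    # Example (illustrative): values_in of shape [2, 6, 3] gives K=6, so W is
    # capped at 6 - e.g. W=4 yields an indices shape of [2, 4] and an input
    # shape of [2, 4, 3].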
    @staticmethod
    def tgBroadcastFuzz(testGen, op, rank, error_name=None):
        shape = testGen.makeShape(rank)
        pl, const = op["operands"]

        shape_list = []

        # Choose one of the inputs to broadcast
        # Note: Simplifies OutputShaper code if we don't change first shape for errors
        bcast_idx = testGen.randInt(0 if error_name is None else 1, pl + const)
        fuzz_idx = testGen.randInt(0, rank)

        for i in range(pl + const):
            shape_bcast = shape.copy()

            # To test broadcasting, the chosen fuzz index dimension should not be 1
            if shape_bcast[fuzz_idx] == 1:
                shape_bcast[fuzz_idx] += 1

            # If the chosen input, pick a random index to broadcast
            if i == bcast_idx:
                if error_name == ErrorIf.RankMismatch:
                    # Add one rank to the shape (or more for rank of 1)
                    extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1
                    shape_bcast = np.concatenate(
                        (shape_bcast, testGen.makeShape(extra_ranks))
                    )
                    if rank != 1:
                        # Either keep the extra rank, or remove it
                        new_len = testGen.rng.choice([-2, len(shape_bcast)])
                        shape_bcast = shape_bcast[:new_len]
                elif error_name == ErrorIf.BroadcastShapesMismatch:
                    shape_bcast[fuzz_idx] += 2
                else:
                    shape_bcast[fuzz_idx] = 1

            shape_list.append(shape_bcast)

        return shape_list

    @staticmethod
    def tgConv2D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension()

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgConv3D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 5

        # IFM dimensions are NDHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter depth/height/width from the operator parameters
        filter_dhw = op["filter"]

        # Generate a random OFM channel
        ofm_channel = testGen.makeDimension()

        # The filter dimensions are ODHWI
        filter_shape = np.asarray(
            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
        )

        # The bias is OC
        bias_shape = np.asarray([ofm_channel])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgTransposeConv2D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        filter_hw = op["filter"]

        # Generate a random OFM depth
        ofm_depth = testGen.makeDimension()

        # The filter dimensions are OHWI
        filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])

        # The bias is OC
        bias_shape = np.asarray([ofm_depth])

        return [ifm_shape, filter_shape, bias_shape]

    @staticmethod
    def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 4
        assert pl == 1 and const == 2

        # IFM dimensions are NHWC
        ifm_shape = testGen.makeShape(rank)
        ifm_shape = testGen.constrictBatchSize(ifm_shape)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
                ifm_shape, max_dim=24, max_items=10000
            )

        # Get the filter height/width from the operator parameters
        # Filter is KH, KW, C, M
filter_hw = op["filter"] # Generate a random OFM depth, but don't let it get too big because # the output depth is M * C filter_m = ( testGen.makeDimension() % (testGen.args.tensor_shape_range[1] // 4) ) + 1 # The filter dimensions are HWCM filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m]) # The bias is M * C bias_shape = np.asarray([ifm_shape[3] * filter_m]) return [ifm_shape, filter_shape, bias_shape] @staticmethod def tgFFT2d(testGen, op, rank, error_name=None): pl, const = op["operands"] if error_name != ErrorIf.WrongRank: assert rank == 3 assert pl == 2 and const == 0 # IFM dimensions are NHW ifm_shape = testGen.makeShape(rank) # Select nearest lower power of two from input height and width ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2)) ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2)) # Constrict the overall size of the shape when creating ERROR_IF tests if error_name: ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape) # Generate an invalid kernel that is not a power of two if error_name == ErrorIf.KernelNotPowerOfTwo: inc_h = 2 if ifm_shape[1] == 1 else 1 inc_w = 2 if ifm_shape[2] == 1 else 1 inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)] selected_inc = testGen.rng.choice(inc_choices) ifm_shape[1] += selected_inc[0] ifm_shape[2] += selected_inc[1] ifm_shape = testGen.constrictBatchSize(ifm_shape) ifm_shapes = [ifm_shape.copy(), ifm_shape.copy()] if error_name == ErrorIf.FFTInputShapeMismatch: modify_shape = testGen.rng.choice([0, 1]) # Only modify kernel (H, W) modify_dim = testGen.rng.choice([1, 2]) ifm_shapes[modify_shape][modify_dim] *= 2 return [ifm_shapes[0], ifm_shapes[1]] @staticmethod def tgRFFT2d(testGen, op, rank, error_name=None): pl, const = op["operands"] if error_name != ErrorIf.WrongRank: assert rank == 3 assert pl == 1 and const == 0 # IFM dimensions are NHW ifm_shape = testGen.makeShape(rank) # Select nearest lower power of two from input height and width ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2)) ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2)) # Constrict the overall size of the shape when creating ERROR_IF tests if error_name: ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape) # Generate an invalid kernel that is not a power of two if error_name == ErrorIf.KernelNotPowerOfTwo: # We must increment by 2 if current size is 1 inc_h = 2 if ifm_shape[1] == 1 else 1 inc_w = 2 if ifm_shape[2] == 1 else 1 inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)] selected_inc = testGen.rng.choice(inc_choices) ifm_shape[1] += selected_inc[0] ifm_shape[2] += selected_inc[1] ifm_shape = testGen.constrictBatchSize(ifm_shape) return [ifm_shape] @staticmethod def tgFullyConnected(testGen, op, rank, error_name=None): pl, const = op["operands"] if error_name != ErrorIf.WrongRank: assert rank == 2 input_shape = testGen.makeShape(rank) # Constrict the overall size of the shape when creating ERROR_IF tests if error_name: input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape) filter_oc = testGen.rng.integers( low=testGen.args.tensor_shape_range[0], high=testGen.args.tensor_shape_range[1], size=1, )[0] filter_shape = np.asarray([filter_oc, input_shape[1]]) bias_shape = np.asarray([filter_oc]) return [input_shape, filter_shape, bias_shape] @staticmethod def tgMatmul(testGen, op, rank, error_name=None): pl, const = op["operands"] if error_name != ErrorIf.WrongRank: assert rank == 3 assert pl == 2 and const == 0 a_shape = testGen.makeShape(rank) # Constrict the overall size of the shape when creating 
    @staticmethod
    def tgFullyConnected(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 2

        input_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)

        filter_oc = testGen.rng.integers(
            low=testGen.args.tensor_shape_range[0],
            high=testGen.args.tensor_shape_range[1],
            size=1,
        )[0]
        filter_shape = np.asarray([filter_oc, input_shape[1]])

        bias_shape = np.asarray([filter_oc])

        return [input_shape, filter_shape, bias_shape]

    @staticmethod
    def tgMatmul(testGen, op, rank, error_name=None):
        pl, const = op["operands"]

        if error_name != ErrorIf.WrongRank:
            assert rank == 3
        assert pl == 2 and const == 0

        a_shape = testGen.makeShape(rank)

        # Constrict the overall size of the shape when creating ERROR_IF tests
        if error_name:
            a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)

        # Get a random number for b_oc even if target shape is defined
        b_oc = np.int32(
            testGen.rng.integers(
                low=testGen.args.tensor_shape_range[0],
                high=testGen.args.tensor_shape_range[1],
                size=1,
            )
        )[0]
        # If N or H is large let b_oc be 1 to reduce output tensor size
        if max(a_shape) > 1000:
            b_oc = 1

        b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
        return [a_shape, b_shape]

    @staticmethod
    def tgConcat(testGen, opName, rank, error_name=None):
        pl, const = opName["operands"]
        shape = testGen.makeShape(rank)

        # Create extra tensors to concat.
        # Take into account value of pl when getting maximum number of concats
        num_tensors = testGen.randInt(0, 4)
        shape_list = []
        for i in range(pl + const + num_tensors):
            if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
                remove = testGen.rng.choice([True, False])
                wrongShape = shape.copy()

                if remove and len(shape) > 1:
                    wrongShape = wrongShape[1:]
                else:
                    wrongShape = list(wrongShape)
                    wrongShape.append(testGen.rng.integers(1, 10))

                shape_list.append(wrongShape)
            else:
                shape_list.append(shape.copy())

        return shape_list

    @staticmethod
    def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
        if error_name in [
            ErrorIf.AxisSmallerZero,
            ErrorIf.AxisLargerRank,
            ErrorIf.ConcatInputRankMismatch,
        ]:
            return shapeList

        # Split concat shape along axis to allow for multiple const inputs
        # without making too many large tensors
        if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
            # If axis can't be split we still need to invalidate other dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                for shape in shapeList[1:]:
                    # Negative test shapeLists are created individually for each test,
                    # so no need to copy the shape before altering it.
                    shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            return shapeList

        # Create copy of shape we are going to split (so we don't alter shapeList)
        shape = shapeList[0].copy()
        # Add original shape as first input
        new_shapeList = [shape.copy()]
        length_on_axis = shape[axis]
        remaining_length = length_on_axis
        for i in range(len(shapeList) - 2):
            # Calculate split on axis and remaining value
            split_shape_val = int(shape[axis] / 2)
            remaining_length = remaining_length - split_shape_val

            # Append new shape, and set remaining shape
            shape[axis] = split_shape_val
            new_shapeList.append(shape.copy())

            # invalidate dimensions
            if error_name == ErrorIf.ConcatInputDimMismatch:
                shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
            else:
                shape[axis] = remaining_length

            if i == len(shapeList) - 3:
                new_shapeList.append(shape.copy())

        return new_shapeList


class TosaTensorValuesGen:
    """Tensor Value generators create the random data for each tensor in each test."""

    def __init__(self):
        pass

    class TVGInfo:
        """Enhanced tensor values information including data gen dict."""

        def __init__(self, tensorList, dataGenDict):
            self.tensorList = tensorList
            self.dataGenDict = dataGenDict

    @staticmethod
    def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
        pCount, cCount = op["operands"]

        tens = []
        tens.extend(
            testGen.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
        )
        tens.extend(testGen.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))

        return tens

    # Default high value for random numbers
    TVG_FLOAT_HIGH_VALUE = {
        DType.FP32: (1 << 128) - (1 << (127 - 23)),
        DType.FP16: (1 << 16) - (1 << (15 - 10)),
        DType.BF16: (1 << 128) - (1 << (127 - 7)),
    }

    # Default lowest normal values for random numbers
    TVG_FLOAT_LOW_VALUE = {
        DType.FP32: np.exp2(-126),
        DType.FP16: np.exp2(-14),
        DType.BF16: np.exp2(-126),
    }

    @staticmethod
    def _get_data_range(testGen, dtype, highValueLookup, lowValueLookup=None):
        # Return a tuple of (low,high) data range values for the given data
        # type using a combination of per operator table limits, data limits
        # and user supplied ranges for FP numbers
        if dtype in highValueLookup:
            type_range = testGen.getDTypeRange(dtype, high_inclusive=True)
            high_val = highValueLookup[dtype]
            if lowValueLookup is not None and dtype in lowValueLookup:
                low_val = lowValueLookup[dtype]
            else:
                low_val = -high_val
            # Set the values to something that won't produce infinity whilst
            # respecting the default ranges if more/less than the low/high
            # values
            data_range = (
                max(low_val, type_range[0]),
                min(high_val, type_range[1]),
            )
            if data_range[0] > data_range[1]:
                # Invalid data range from low to high created due to user
                # constraints revert to using internal ranges as they are
                # known to work
                msg = f"Using safe data range ({low_val} to {high_val}) instead of supplied ({type_range[0]} to {type_range[1]})"
                warnings.warn(msg)
                data_range = (low_val, high_val)
            return data_range
        return None
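    # Example (illustrative): for DType.FP16 with a high lookup of 32752.0 and
    # no low lookup, the returned range is (max(-32752.0, fp16_lo),
    # min(32752.0, fp16_hi)) - the table limit clamped to any user-supplied
    # FP16 range.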
    @staticmethod
    def tvgLazyGenDefault(
        testGen, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        # Variable inputs versus constants
        pCount, cCount = testGen.TOSA_OP_LIST[opName]["operands"]
        if "p_count" in argsDict:
            # Override for operators like CONCAT
            pCount = argsDict["p_count"]
            cCount = argsDict["c_count"]
        assert pCount + cCount == len(
            shapeList
        ), "Placeholders & Constant tensors must match shapes list"

        tens_ser_list = []

        if (
            error_name is not None
            or not gtu.dtypeIsSupportedByCompliance(dtypeList[0])
            or "data_gen" not in testGen.TOSA_OP_LIST[opName]
        ):
            # Fall back to internal data gen when dealing with unsupported types or ops
            data_range = argsDict["data_range"] if "data_range" in argsDict else None
            for idx, info in enumerate(zip(shapeList, dtypeList)):
                roundMode = False
                shape, dtype = info
                if "data_range_list" in argsDict:
                    data_range = argsDict["data_range_list"][idx]["range"]
                    roundMode = (
                        "round" in argsDict["data_range_list"][idx]
                        and argsDict["data_range_list"][idx]["round"] is True
                    )
                if data_range is not None and dtype not in (
                    DType.FP16,
                    DType.FP32,
                    DType.BF16,
                ):
                    # Change from inclusive to exclusive range
                    data_range = (data_range[0], data_range[1] + 1)
                # Ignore lazy data gen option and create data array using any range limits

                if "fixed_data" in argsDict and argsDict["fixed_data"][idx] is not None:
                    arr = np.int64(argsDict["fixed_data"][idx])
                else:
                    arr = testGen.getRandTensor(shape, dtype, data_range)
                if roundMode:
                    arr = np.round(arr)
                if idx < pCount:
                    tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
                else:
                    tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)

        # Create data generator meta-data
        dg_type = argsDict["dg_type"]
        tens_data = {
            "version": "0.1",
            "tensors": {},
        }
        dg_tens_meta = tens_data["tensors"]
        for idx, shape in enumerate(shapeList):
            tens_meta = {}
            if "fixed_data" in argsDict and argsDict["fixed_data"][idx] is not None:
                tens_meta["generator"] = gtu.DataGenType(
                    gtu.DataGenType.FIXED_DATA
                ).name
            else:
                tens_meta["generator"] = gtu.DataGenType(dg_type).name

            tens_meta["data_type"] = gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["json"]
            tens_meta["shape"] = [int(i) for i in shape]
            tens_meta["input_pos"] = idx
            tens_meta["op"] = gtu.getOpNameFromOpListName(opName).upper()

            if idx < pCount:
                tens_meta["input_type"] = "VARIABLE"
            else:
                tens_meta["input_type"] = "CONSTANT"

            if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
                info = {}
                if (
                    tens_meta["generator"]
                    == gtu.DataGenType(gtu.DataGenType.FIXED_DATA).name
                ):
                    info["data"] = [int(i) for i in argsDict["fixed_data"][idx]]
                    tens_meta["fixed_data_info"] = info
                else:
                    # TODO - generate seed for this generator based on test
                    info["rng_seed"] = 42

                    data_range = None
                    if "data_range_list" in argsDict:
                        data_range = argsDict["data_range_list"][idx]["range"]
                        if "round" in argsDict["data_range_list"][idx]:
                            info["round"] = argsDict["data_range_list"][idx]["round"]
                    elif "data_range" in argsDict:
                        data_range = argsDict["data_range"]

                    if data_range is None:
                        data_range = testGen.getDTypeRange(
                            dtypeList[idx], high_inclusive=True
                        )
                    info["range"] = [str(v) for v in data_range]
                    tens_meta["pseudo_random_info"] = info
            elif dg_type == gtu.DataGenType.DOT_PRODUCT:
                info = {}
                info["s"] = argsDict["s"]
                info["ks"] = int(argsDict["ks"])
                if "acc_type" in argsDict:
                    # Convert type number into JSON name
                    info["acc_type"] = gtu.DTYPE_ATTRIBUTES[argsDict["acc_type"]][
                        "json"
                    ]
                if "kernel" in argsDict:
                    info["kernel"] = [int(k) for k in argsDict["kernel"]]
                if "axis" in argsDict:
                    info["axis"] = int(argsDict["axis"])
                tens_meta["dot_product_info"] = info
            else:
                # TODO - other data gen type
                assert False, "TODO: support other data gen types"

            # Using the finished generate config meta data - generate the data if
            # needed and assign a tensor name from the serializer

            # Need to generate data when not lazy or for the bias tensor as we need
            # to work out if the bias data is non-zero for compliance
            if not testGen.args.lazy_data_gen or (
                idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT
            ):
                # Give this tensor a temporary name until we get one from the serializer
                temp_name = f"placeholder_{idx}"
                dg_tens_meta[temp_name] = tens_meta
                # Create data now using the temporary name to access meta details
                data = testGen.dgl.get_tensor_data(temp_name, tens_data)

                if tens_meta["data_type"] == "SHAPE":
                    # Tensor type SHAPE and Numpy file type must be the same
                    data = np.int64(data)
                # Remove the item as we will give it the correct name later
                del dg_tens_meta[temp_name]

            if idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT:
                # The KS value used by compliance verification is altered when the
                # bias data is non-zero
                if max(abs(data)) > 0.0:
                    argsDict["ksb"] = argsDict["ks"] + 1

            if testGen.args.lazy_data_gen:
                data = None

            if tens_meta["input_type"] == "VARIABLE":
                tens = testGen.ser.addPlaceholder(shape, dtypeList[idx], data)
            else:
                tens = testGen.ser.addConst(shape, dtypeList[idx], data)

            tens_ser_list.append(tens)
            # Add the meta data to the list using the serializer tensor name
            dg_tens_meta[tens.name] = tens_meta

        return TosaTensorValuesGen.TVGInfo(tens_ser_list, tens_data)

    @staticmethod
    def tvgNegate(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
        if dtypeList[0] == DType.INT32 and error_name is None:
            # Integer test
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 1 and cCount == 0
            ), "Op.NEGATE must have 1 placeholder, 0 consts"
            # Must create tensors with values within accumulator (int32) negatable
            # range
            max_val = (1 << 31) - 1
            min_val = -max_val
            arr = np.int32(
                testGen.rng.integers(low=min_val, high=(max_val + 1), size=shapeList[0])
            )
            tens_ser_list = []
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], arr)
            )
            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
        else:
            # ERROR_IF or floating point test
            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, opName, dtypeList, shapeList, argsDict, error_name
            )

    # Set the ADD/SUB data range to half the largest value to avoid infinities
    TVG_FLOAT_HIGH_VALUE_ADDSUB = {
        DType.FP32: (TVG_FLOAT_HIGH_VALUE[DType.FP32] / 2),
        DType.FP16: (TVG_FLOAT_HIGH_VALUE[DType.FP16] / 2),
        DType.BF16: (TVG_FLOAT_HIGH_VALUE[DType.BF16] / 2),
    }

    @staticmethod
    def tvgAddSub(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
        if dtypeList[0] in (DType.INT32, DType.SHAPE) and error_name is None:
            # Make sure the integer operation does not cause value saturation - where
            # the number wraps due to limited number of bits to store the answer
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
            tens_ser_list = []
            add = op["op"] in (Op.ADD, Op.ADD_SHAPE)
            data_range = testGen.args.tensor_shape_range
            a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0], data_range)
            b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1], data_range)
            if add:
                res_arr = np.add(a_arr, b_arr, dtype=np.int64)
            else:
                res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)

            # Work out the saturation limits
            max_i32 = (1 << 31) - 1
            min_i32 = -(1 << 31)
            max_arr = np.full(shapeList[1], max_i32)
            min_arr = np.full(shapeList[1], min_i32)

            # Find how much values exceed the maximum/minimums
            sat_max_arr = np.maximum(res_arr - max_arr, 0)
            sat_min_arr = np.minimum(res_arr - min_arr, 0)

            if not add:
                # Swap saturation values and negate values as we need to perform
                # opposite operations
                sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr

            # Create new array of unsaturated values by clipping values as needed
            b_unsat_arr = b_arr
            if (sat_max_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
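                # (Illustrative) e.g. a=2**31-10, b=20 overflows by 11, so 11
                # is subtracted from b, giving a pair that sums exactly to the
                # int32 maximum.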
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert (
                            dim == 1
                        ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)

            if (sat_min_arr != 0).any():
                # Clip values that cause saturation
                b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
                # Reduce axes in unsaturated tensor to match original tensor
                for axis, dim in enumerate(b_arr.shape):
                    if dim != b_unsat_arr.shape[axis]:
                        assert (
                            dim == 1
                        ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
                        b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)

            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
            )

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
        else:
            # ERROR_IF or floating point test
            data_range = TosaTensorValuesGen._get_data_range(
                testGen, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_ADDSUB
            )
            if data_range:
                argsDict["data_range"] = data_range

            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, opName, dtypeList, shapeList, argsDict, error_name
            )

    @staticmethod
    def tvgCondIfWhileLoop(
        testGen, op, dtypeList, shapeList, testArgs, error_name=None
    ):
        if dtypeList[0] in (
            DType.INT32,
            DType.INT16,
            DType.INT8,
        ):
            # Limit input tensors with cond_if_binary or while_loop to stop
            # saturation of add/sub ops with int32 and keep all logical shift
            # values between 0 to 31 for int16 or int8
            pCount, cCount = op["operands"]
            pRemain = pCount
            placeholders = []
            for idx, shape in enumerate(shapeList[:]):
                if dtypeList[0] == DType.INT32:
                    arr = testGen.getRandTensor(shapeList[idx], DType.INT16)
                else:
                    arr = np.int32(
                        testGen.rng.integers(low=0, high=32, size=shapeList[idx])
                    )
                if pRemain > 0:
                    placeholders.append(
                        testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)
                    )
                    pRemain -= 1
                else:
                    placeholders.append(
                        testGen.ser.addConst(shape, dtypeList[idx], arr)
                    )

            return placeholders
        else:
            return TosaTensorValuesGen.tvgDefault(
                testGen, op, dtypeList, shapeList, testArgs, error_name
            )

    @staticmethod
    def tvgArithmeticRightShift(
        testGen, op, dtypeList, shapeList, testArgs, error_name=None
    ):
        pCount, cCount = op["operands"]
        # Force value of operand[1] to be within [0, num_bits]
        assert (
            pCount == 2 and cCount == 0
        ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"

        placeholders = []
        for idx, shape in enumerate(shapeList[:]):
            if idx == 1:
                if dtypeList[idx] == DType.INT8:
                    arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
                elif dtypeList[idx] == DType.INT16:
                    arr = np.int32(testGen.rng.integers(low=0, high=16, size=shape))
                elif dtypeList[idx] == DType.INT32:
                    arr = np.int32(testGen.rng.integers(low=0, high=32, size=shape))
                elif error_name == ErrorIf.WrongInputType:
                    arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
                else:
                    raise Exception("OpArithmeticRightShift: invalid input dtype")
            else:
                arr = testGen.getRandTensor(shape, dtypeList[idx])
            placeholders.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr))

        return placeholders

    @staticmethod
    def tvgReshape(testGen, op, dtypeList, shapeList, argsDict, error_name=None):
        dtypeList[1] = DType.SHAPE
        shapeList[1] = [len(argsDict["new_shape"])]
        # Create a new list for the pre-generated data in argsDict["fixed_data"]
        argsDict["fixed_data"] = [None, argsDict["new_shape"]]

        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, op, dtypeList, shapeList, argsDict, error_name
        )
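    # Note: the [None, <values>] fixed_data pattern above means random data is
    # generated for input 0, while the given shape values are serialized
    # verbatim for the DType.SHAPE operand.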
    @staticmethod
    def tvgPad(testGen, op, dtypeList, shapeList, argsDict, error_name=None):
        # argsDict["pad"] is 2D array, need to flatten it to get list of values
        pad_values = argsDict["pad"].flatten()
        dtypeList[1] = DType.SHAPE
        shapeList[1] = [len(pad_values)]
        # Create a new list for the pre-generated data in argsDict["fixed_data"]
        argsDict["fixed_data"] = [None, pad_values]

        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, op, dtypeList, shapeList, argsDict, error_name
        )

    @staticmethod
    def tvgSlice(testGen, op, dtypeList, shapeList, argsDict, error_name=None):
        dtypeList[1] = DType.SHAPE
        shapeList[1] = [len(argsDict["start"])]
        dtypeList[2] = DType.SHAPE
        shapeList[2] = [len(argsDict["size"])]
        # Create a new list for the pre-generated data in argsDict["fixed_data"]
        argsDict["fixed_data"] = [None, argsDict["start"], argsDict["size"]]

        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, op, dtypeList, shapeList, argsDict, error_name
        )

    @staticmethod
    def tvgTile(testGen, op, dtypeList, shapeList, argsDict, error_name=None):
        dtypeList[1] = DType.SHAPE
        shapeList[1] = [len(argsDict["multiples"])]
        argsDict["fixed_data"] = [None, argsDict["multiples"]]

        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, op, dtypeList, shapeList, argsDict, error_name
        )

    @staticmethod
    def tvgSelect(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
        # Set datatype of condition tensor to boolean
        dtypeList[0] = DType.BOOL

        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, opName, dtypeList, shapeList, argsDict, error_name
        )

    @staticmethod
    def tvgIntDiv(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
        if error_name is None:
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.INTDIV must have 2 placeholders, 0 consts"

            tens_ser_list = []

            # Two invalid cases for Op.INTDIV:
            # 1. divisor == 0
            # 2. dividend == -(1<<31) and divisor == -1
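            # ((-(1 << 31)) // -1 == 1 << 31, which is not representable in
            # int32, so such pairs are regenerated.)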
            while True:
                dividend_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
                divisor_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])

                if (divisor_arr == 0).any():
                    continue

                if (dividend_arr == -(2**31)).any() and (divisor_arr == -1).any():
                    continue

                break

            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
            )
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
            )

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
        else:
            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, opName, dtypeList, shapeList, argsDict, error_name
            )

    # Set the MUL data range to the square root of the largest value
    # to avoid infinities
    TVG_FLOAT_HIGH_VALUE_MUL = {
        DType.FP32: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
        DType.FP16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
        DType.BF16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
    }

    @staticmethod
    def tvgMul(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
        if error_name is not None or dtypeList[0] in (
            DType.FP16,
            DType.BF16,
            DType.FP32,
        ):
            # ERROR_IF or floating point test
            data_range = TosaTensorValuesGen._get_data_range(
                testGen, dtypeList[0], TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_MUL
            )
            if data_range:
                argsDict["data_range"] = data_range

            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, opName, dtypeList, shapeList, argsDict, error_name
            )
        else:
            # Integer test
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.MUL must have 2 placeholders, 0 consts"

            tens_ser_list = []

            # Make sure the multiply result is within the int32 range
            if dtypeList[0] == DType.SHAPE:
                shift = 0
            else:
                shift = argsDict["shift"]
            if dtypeList[0] == DType.INT8:
                num_bits = 8
            elif dtypeList[0] == DType.INT16:
                num_bits = 16
            elif dtypeList[0] in (DType.INT32, DType.SHAPE):
                num_bits = 32
            elif error_name == ErrorIf.WrongInputType:
                num_bits = 8
            else:
                raise Exception("OpMul: invalid input dtype")

            for idx, shape in enumerate(shapeList[:]):
                if dtypeList[idx] == DType.SHAPE:
                    low = testGen.args.tensor_shape_range[0]
                    high = testGen.args.tensor_shape_range[1]
                else:
                    low = -(2 ** (num_bits - 1))
                    high = (2 ** (num_bits - 1)) - 1

                a_arr = np.int32(
                    testGen.rng.integers(low=low, high=high, size=shapeList[0])
                )
                b_arr = np.int32(
                    testGen.rng.integers(low=low, high=high, size=shapeList[1])
                )

            i = 0
            while True:
                a_arr_64 = a_arr.astype(np.int64)
                b_arr_64 = b_arr.astype(np.int64)

                if shift > 0:
                    rounding = 1 << (shift - 1)
                    result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
                else:
                    result_arr = a_arr_64 * b_arr_64

                if (result_arr > -(2**31)).all() and (
                    result_arr <= ((2**31) - 1)
                ).all():
                    break

                i = i + 1
                a_arr = a_arr // 2
                b_arr = b_arr // 2

            if dtypeList[0] == DType.SHAPE:
                tens_ser_list.append(
                    testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr_64)
                )
                tens_ser_list.append(
                    testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr_64)
                )
            else:
                tens_ser_list.append(
                    testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
                )
                tens_ser_list.append(
                    testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
                )

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
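    # Note on the halving loop above (illustrative): a=b=65536 overflows int32
    # at shift=0 (the product is 2**32), but after one halving a=b=32768 gives
    # 2**30, which fits.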
    @staticmethod
    def tvgConcat(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
        count = len(shapeList) - testGen.args.num_const_inputs_concat
        if count < 1:
            count = 1
        if testGen.args.num_const_inputs_concat == 0:
            count = len(shapeList)

        op = testGen.TOSA_OP_LIST[opName]
        if op["op"] == Op.CONCAT_SHAPE:
            # Set the axis to 0
            shapeList = TosaTensorGen.tgConcatConstInput(
                testGen, shapeList, 0, error_name
            )
        else:
            shapeList = TosaTensorGen.tgConcatConstInput(
                testGen, shapeList, argsDict["axis"], error_name
            )

        # Override default pCount/cCount for operator
        argsDict["p_count"] = count
        argsDict["c_count"] = len(shapeList) - count

        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, opName, dtypeList, shapeList, argsDict, error_name
        )

    @staticmethod
    def tvgLogicalShift(
        testGen, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        op = testGen.TOSA_OP_LIST[opName]
        pCount, cCount = op["operands"]
        assert (
            pCount == 2 and cCount == 0
        ), "Op.LOGICAL_LEFT_SHIFT or Op.LOGICAL_RIGHT_SHIFT must have 2 placeholders, 0 consts"
        values_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
        shift_arr = np.int32(testGen.rng.integers(low=0, high=32, size=shapeList[1]))
        tens_ser_list = []
        tens_ser_list.append(
            testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
        )
        tens_ser_list.append(
            testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr)
        )

        return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)

    @staticmethod
    def tvgEqual(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
        if error_name is None and not gtu.dtypeIsSupportedByCompliance(dtypeList[0]):
            # Integer
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.EQUAL must have 2 placeholders, 0 consts"

            a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
            b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])

            # Using random numbers means that it will be very unlikely that
            # there are any matching (equal) values, therefore force that
            # there are twice the number of matching values as the tensor rank
            for num in range(0, len(shapeList[0]) * 2):
                a_index = []
                b_index = []
                # Choose an index in each axis for the whole shape
                for axis in range(0, len(shapeList[0])):
                    # Index can be up to the largest dimension in both shapes
                    index = np.int32(
                        testGen.rng.integers(
                            0, max(shapeList[0][axis], shapeList[1][axis])
                        )
                    )
                    # Reduce the index down to a shape's dim for broadcasting
                    a_index.append(min(shapeList[0][axis] - 1, index))
                    b_index.append(min(shapeList[1][axis] - 1, index))

                a_arr[tuple(a_index)] = b_arr[tuple(b_index)]

            tens_ser_list = []
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
            )
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
            )
            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
        else:
            # ERROR_IF or floating point test
            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, opName, dtypeList, shapeList, argsDict, error_name
            )

    @staticmethod
    def tvgReduceSum(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
        dtype = dtypeList[0]
        if dtype == DType.INT32:
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 1 and cCount == 0
            ), "Op.REDUCE_SUM must have 1 placeholder, 0 consts"
            # Limit values so that the sum cannot exceed the range of an int32 during
            # summation of any axis
            range_val = int((1 << 31) / max(shapeList[0]))
            values_arr = np.int32(
                testGen.rng.integers(low=-range_val, high=range_val, size=shapeList[0])
            )
            tens_ser_list = []
            tens_ser_list.append(
                testGen.ser.addPlaceholder(shapeList[0], dtype, values_arr)
            )
            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
        else:
            # ERROR_IF or dot product floating point test
            if (
                error_name is None
                and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
            ):
                # Limit ranges for (non error & non compliance) tests by using
                # values that can be summed on any axis to not hit infinity
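                # e.g. an FP32 input of shape [4, 1000] divides the high value
                # by 1000, so summing up to 1000 elements on an axis cannot
                # overflow to infinity.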
                highval_lookup = {
                    dtype: TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype]
                    / max(shapeList[0])
                }
                data_range = TosaTensorValuesGen._get_data_range(
                    testGen, dtype, highval_lookup
                )
                assert data_range is not None
                argsDict["data_range"] = data_range

            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, opName, dtypeList, shapeList, argsDict, error_name
            )

    @staticmethod
    def tvgReduceProduct(
        testGen, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        dtype = dtypeList[0]
        if error_name is None:
            # Limit ranges for (non error) tests by using
            # values that can be multiplied on any axis to not hit infinity
            highval_lookup = {
                dtype: math.pow(
                    TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
                    1 / max(shapeList[0]),
                )
            }
            data_range = TosaTensorValuesGen._get_data_range(
                testGen, dtype, highval_lookup
            )
            assert data_range is not None
            argsDict["data_range"] = data_range

        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, opName, dtypeList, shapeList, argsDict, error_name
        )

    # Set the POW exponent high data range
    TVG_FLOAT_HIGH_VALUE_POW_EXP = {
        DType.FP32: 10.0,
        DType.FP16: 10.0,
        DType.BF16: 10.0,
    }
    # POW highest base value (within a safe margin of error) that can be raised
    # to a +ve exponent without becoming Infinity
    TVG_FLOAT_HIGH_VALUE_POW_BASE = {
        DType.FP32: math.floor(
            math.pow(
                TVG_FLOAT_HIGH_VALUE[DType.FP32],
                1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
            )
        ),
        DType.FP16: math.floor(
            math.pow(
                TVG_FLOAT_HIGH_VALUE[DType.FP16],
                1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
            )
        ),
        DType.BF16: math.floor(
            math.pow(
                TVG_FLOAT_HIGH_VALUE[DType.BF16],
                1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
            )
        ),
    }
    # POW lowest base value (within a safe margin of error) that can be raised
    # to a -ve exponent without becoming Infinity
    TVG_FLOAT_LOW_VALUE_POW_BASE = {
        DType.FP32: math.ceil(
            math.pow(
                1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP32],
                1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
            )
            * 1000
        )
        / 1000,
        DType.FP16: math.ceil(
            math.pow(
                1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP16],
                1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
            )
            * 1000
        )
        / 1000,
        DType.BF16: math.ceil(
            math.pow(
                1.0 / TVG_FLOAT_HIGH_VALUE[DType.BF16],
                1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
            )
            * 1000
        )
        / 1000,
    }
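    # Example (illustrative): for FP32 the high base works out as
    # floor(max_fp32 ** (1/10)), roughly 7131, since raising that to the
    # exponent limit of 10 stays just below the FP32 maximum.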
    @staticmethod
    def tvgPow(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
        if error_name is not None:
            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, opName, dtypeList, shapeList, argsDict, error_name
            )
        dtype = dtypeList[0]
        # Different ranges for POW
        test_set = argsDict["s"]
        if test_set == 0:
            # Positive base with fractional exponent
            base_range = TosaTensorValuesGen._get_data_range(
                testGen,
                dtype,
                TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE,
                TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE,
            )
            exp_range = TosaTensorValuesGen._get_data_range(
                testGen, dtype, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_EXP
            )
            exp_round = False
        else:
            # Integer exponent
            exp_range = TosaTensorValuesGen._get_data_range(
                testGen, dtype, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_EXP
            )
            exp_round = True
            if test_set == 1:
                # Positive base
                base_range = TosaTensorValuesGen._get_data_range(
                    testGen,
                    dtype,
                    TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE,
                    TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE,
                )
            else:
                assert test_set == 2
                # Negative base
                # Supply new look up tables with negative values
                base_range = TosaTensorValuesGen._get_data_range(
                    testGen,
                    dtype,
                    {dtype: -TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE[dtype]},
                    {dtype: -TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE[dtype]},
                )

        data_range_list = (
            {
                "range": base_range,
            },
            {
                "range": exp_range,
                "round": exp_round,
            },
        )
        argsDict["data_range_list"] = data_range_list

        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, opName, dtypeList, shapeList, argsDict, error_name
        )

    @staticmethod
    def tvgLogRsqrt(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
        # LOG & RSQRT data range from lowest expressible positive number to
        # largest to avoid NaNs
        data_range = TosaTensorValuesGen._get_data_range(
            testGen,
            dtypeList[0],
            TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE,
            TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE,
        )
        if data_range:
            argsDict["data_range"] = data_range

        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, opName, dtypeList, shapeList, argsDict, error_name
        )

    # Set the EXP data range to the log of the largest to smallest values
    # to avoid infinities or making the result zero
    TVG_FLOAT_HIGH_VALUE_EXP = {
        DType.FP32: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
        DType.FP16: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
        DType.BF16: math.log(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
    }
    TVG_FLOAT_LOW_VALUE_EXP = {
        DType.FP32: math.log(TVG_FLOAT_LOW_VALUE[DType.FP32]),
        DType.FP16: math.log(TVG_FLOAT_LOW_VALUE[DType.FP16]),
        DType.BF16: math.log(TVG_FLOAT_LOW_VALUE[DType.BF16]),
    }

    @staticmethod
    def tvgExp(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
        data_range = TosaTensorValuesGen._get_data_range(
            testGen,
            dtypeList[0],
            TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_EXP,
            TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_EXP,
        )
        if data_range:
            argsDict["data_range"] = data_range

        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, opName, dtypeList, shapeList, argsDict, error_name
        )

    @staticmethod
    def tvgFullyConnected(
        testGen, opName, dtypeList, shapeList, argsDict, error_name=None
    ):
        dtype = dtypeList[0]
        if (
            error_name is None
            and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
            and dtype in (DType.BF16,)
        ):
            # TODO - Remove once BF16 enabled for DOT_PRODUCT compliance
            # Limit ranges for (non error & non compliance) FP tests by using
            # values that can be multiplied on any axis to not hit infinity/NaN
            IC = shapeList[0][1]
            highval_lookup = {
                dtype: math.pow(TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype], 1 / IC)
            }
            data_range = TosaTensorValuesGen._get_data_range(
                testGen, dtype, highval_lookup
            )
            assert data_range is not None
            argsDict["data_range"] = data_range

        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, opName, dtypeList, shapeList, argsDict, error_name
        )

    @staticmethod
    def tvgCast(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
        in_dtype = dtypeList[0]
        out_dtype = argsDict["out_type"]
        # Create look up to limit input tensor to output type maximums to avoid
        # FP infinities and saturation of integers
        out_range = testGen.getDTypeRange(out_dtype, high_inclusive=True)
        highval_lookup = {in_dtype: out_range[1]}
        data_range = TosaTensorValuesGen._get_data_range(
            testGen,
            in_dtype,
            highval_lookup,
        )

        assert data_range is not None
        argsDict["data_range"] = data_range

        return TosaTensorValuesGen.tvgLazyGenDefault(
            testGen, opName, dtypeList, shapeList, argsDict, error_name
        )
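    # Example (illustrative): casting FP32 -> FP16 limits inputs to +/-65504
    # (the FP16 maximum), so the cast cannot produce an FP16 infinity.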
    @staticmethod
    def tvgGather(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
        K = shapeList[0][1]

        # Fix the type of the indices tensor
        dtypeList[1] = DType.INT32

        dtype = dtypeList[0]
        if not gtu.dtypeIsSupportedByCompliance(dtype):
            # Test unsupported by data generator
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 2 and cCount == 0
            ), "Op.GATHER must have 2 placeholders, 0 consts"

            tens_ser_list = []
            for idx, shape in enumerate(shapeList):
                dtype = dtypeList[idx]
                if idx != 1:
                    arr = testGen.getRandTensor(shape, dtype)
                    tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
                else:
                    # Limit data range of indices tensor up to K (exclusive)
                    arr = testGen.getRandTensor(shape, dtype, (0, K))
                    # To match old functionality - create indices as CONST
                    tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
        else:
            # ERROR_IF or floating point test
            # Use inclusive values up to index K for indices tensor
            data_range_list = (
                {"range": None},
                {"range": (0, K - 1)},
            )
            argsDict["data_range_list"] = data_range_list

            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, opName, dtypeList, shapeList, argsDict, error_name
            )

    @staticmethod
    def tvgScatter(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
        K = shapeList[0][1]
        W = shapeList[2][1]

        # Work out an indices tensor here with data that doesn't exceed the
        # dimension K of the values_in tensor and does NOT repeat the same K
        # location as needed by the spec:
        # "It is not permitted to repeat the same output index within a single
        # SCATTER operation and so each output index occurs at most once."
        assert K >= W, "Op.SCATTER W must be smaller than or equal to K"

        # Fix the type of the indices tensor
        dtypeList[1] = DType.INT32

        dtype = dtypeList[0]
        if not gtu.dtypeIsSupportedByCompliance(dtype):
            # Test unsupported by data generator
            op = testGen.TOSA_OP_LIST[opName]
            pCount, cCount = op["operands"]
            assert (
                pCount == 3 and cCount == 0
            ), "Op.SCATTER must have 3 placeholders, 0 consts"

            tens_ser_list = []
            for idx, shape in enumerate(shapeList):
                dtype = dtypeList[idx]
                if idx != 1:
                    arr = testGen.getRandTensor(shape, dtype)
                    tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
                else:
                    # Create the indices array
                    assert dtype == DType.INT32, "Op.SCATTER unexpected indices type"
                    arr = []
                    for n in range(shape[0]):
                        # Get a shuffled list of output indices (0 to K-1) and
                        # limit length to W
                        arr.append(testGen.rng.permutation(K)[:W])
                    indices_arr = np.array(arr, dtype=np.int32)  # (N, W)
                    # To match old functionality - create indices as CONST
                    tens_ser_list.append(
                        testGen.ser.addConst(shape, dtype, indices_arr)
                    )

            return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
        else:
            # ERROR_IF or floating point test
            # Use inclusive values up to index K for indices tensor
            data_range_list = (
                {"range": None},
                {"range": (0, K - 1)},
                {"range": None},
            )
            argsDict["data_range_list"] = data_range_list

            return TosaTensorValuesGen.tvgLazyGenDefault(
                testGen, opName, dtypeList, shapeList, argsDict, error_name
            )
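# Note: rng.permutation(K)[:W] in tvgScatter yields W distinct indices in
# [0, K), honouring the spec rule that an output index may occur at most once
# per SCATTER.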
""" def __init__(self): pass @staticmethod def _add_data_generators(testGen, opName, dtype, arg_list, error_name): """Add extra tests for each type of data generator for this op.""" if ( error_name is None and "data_gen" in testGen.TOSA_OP_LIST[opName] and gtu.dtypeIsSupportedByCompliance(dtype) ): if dtype in [DType.FP16, DType.FP32, DType.BF16]: dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["fp"] else: dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["int"] else: # Error test or No data generator types listed - assume random dataGenTypesList = (gtu.DataGenType.PSEUDO_RANDOM,) # Expand arg list with other data generator types new_arg_list = [] for dg_type in dataGenTypesList: for arg_str, args_dict in arg_list: args_dict["dg_type"] = dg_type if dg_type == gtu.DataGenType.PSEUDO_RANDOM: if error_name is None: num_test_sets = ( args_dict["num_test_sets"] if "num_test_sets" in args_dict else 0 ) else: num_test_sets = 0 elif dg_type == gtu.DataGenType.DOT_PRODUCT: # Extra tests for each dot product test set dot_products = args_dict["dot_products"] if dot_products < testGen.TOSA_MI_DOT_PRODUCT_MIN: shape_info = ( " ({})".format(testGen.shapeStr(args_dict["shape"])) if "shape" in args_dict else "" ) print( f"Skipping {opName}{shape_info} dot product test as too few calculations {dot_products} < {testGen.TOSA_MI_DOT_PRODUCT_MIN}" ) continue # KS and acc_type is required by all dot product generators assert "ks" in args_dict assert "acc_type" in args_dict num_test_sets = testGen.TOSA_MI_DOT_PRODUCT_TEST_SETS if num_test_sets > 0: for s in range(0, num_test_sets): new_arg_str = f"{arg_str}_s{s}" if arg_str else f"s{s}" new_args_dict = args_dict.copy() new_args_dict["s"] = s new_arg_list.append((new_arg_str, new_args_dict)) else: # Default is a single test new_arg_list.append((arg_str, args_dict)) return new_arg_list @staticmethod def agNone(testGen, opName, shapeList, dtype, error_name=None): """A trivial argument generator for operators that don't take any non-tensor arguments""" arg_list = TosaArgGen._add_data_generators( testGen, opName, dtype, [("", {})], error_name, ) # Return list of tuples: (arg_str, args_dict) return arg_list @staticmethod def agPow(testGen, opName, shapeList, dtype, error_name=None): """Pow operator needs different test sets to cover random numbers without creating NaNs or Infs""" arg_list = TosaArgGen._add_data_generators( testGen, opName, dtype, [("", {"num_test_sets": 3})], error_name, ) # Return list of tuples: (arg_str, args_dict) return arg_list @staticmethod def agAxis(testGen, opName, shapeList, dtype, error_name=None): """Build the axis argument for operators that take a single axis""" arg_list = [] shape = shapeList[0] if error_name == ErrorIf.AxisSmallerZero: # Set too small axis axes = [testGen.rng.integers(-5, 0)] elif error_name == ErrorIf.AxisLargerRank: # Set too large axis axes = [testGen.rng.integers(len(shape) + 1, len(shape) + 10)] else: # Create tests for each dimension axes = range(0, len(shape)) opid = testGen.TOSA_OP_LIST[opName]["op"] for a in axes: args_dict = {"axis": int(a)} if opid == Op.REDUCE_SUM: args_dict["dot_products"] = gtu.product(shape) args_dict["shape"] = shape args_dict["ks"] = int(shape[a]) if a >= 0 and a < len(shape) else 1 args_dict["acc_type"] = dtype if dtype != DType.BF16 else DType.FP32 arg_list.append(("axis{}".format(a), args_dict)) arg_list = TosaArgGen._add_data_generators( testGen, opName, dtype, arg_list, error_name, ) # Return list of tuples: (arg_str, args_dict) return arg_list @staticmethod 
    @staticmethod
    def _calculate_sparsity(num_tests, sparsity_factor):
        sparsity = num_tests // sparsity_factor + 1
        # If there are only a small number of tests, just select them all
        if sparsity < 13:
            sparsity = 1
        # To get a variety of parameter combinations sparsity should not be a
        # multiple of 2, 3 or 5
        while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
            sparsity += 1
        return sparsity

    @staticmethod
    def agConv(testGen, opName, shapeList, dtypes, error_name=None):
        # Used by CONV2D, CONV3D and DEPTHWISE_CONV2D
        arg_list = []

        if testGen.args.level8k and error_name is not None:
            # Don't produce negative large tests
            return arg_list

        # Shape: Batches, (Depth), Height, Width, Channels
        ifm_shape = shapeList[0]
        # Shape: (OFM channels), (KD), KH, KW, IFM channels
        filter_shape = shapeList[1]

        accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)

        # Op type checks
        conv3d = opName.startswith("conv3d")
        depthwise = opName.startswith("depthwise")

        # Check the rank
        rank = 5 if conv3d else 4
        if error_name != ErrorIf.WrongRank:
            assert len(ifm_shape) == rank
            assert len(filter_shape) == rank

        # kernel rank omits channels
        k_rank = rank - 2
        k_pos = 0 if depthwise else 1
        k_shape = tuple(filter_shape[k_pos : (k_pos + k_rank)])
        # compliance size - KS
        k_size = gtu.product(k_shape)
        if not depthwise:
            k_size *= ifm_shape[-1]

        if not testGen.args.level8k:
            # Generate comprehensive argument lists
            # - except for named errors, which use specific invalid value(s)
            if error_name == ErrorIf.PadSmallerZero:
                p_vals = [testGen.rng.choice(range(-5, 0))]
            else:
                p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
            paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
            if error_name == ErrorIf.StrideSmallerOne:
                # Can't use stride=0, as it is used to derive output shape, as a divisor
                s_vals = [testGen.rng.choice(range(-5, 0))]
            else:
                # Stride must be greater than 1 to force non-integer error
                startStride = (
                    1 if error_name != ErrorIf.ConvOutputShapeNonInteger else 2
                )
                s_vals = [
                    x for x in range(startStride, testGen.args.max_conv_stride + 1)
                ]
            strides = {x for x in itertools.product(*([s_vals] * k_rank))}
            if error_name == ErrorIf.DilationSmallerOne:
                d_vals = [testGen.rng.choice(range(-5, 1))]
            else:
                d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
            dilations = {x for x in itertools.product(*([d_vals] * k_rank))}

            if not error_name and testGen.args.oversize:
                # add some oversize argument values
                if max(ifm_shape) < 64:
                    bigPadding = 9
                    paddings.update(
                        {
                            x
                            for x in itertools.product(
                                *([[0, bigPadding]] * (k_rank * 2))
                            )
                        }
                    )
                bigStride = 8
                strides.update(
                    {x for x in itertools.product(*([[1, bigStride]] * k_rank))}
                )
                bigDilation = 7
                dilations.update(
                    {x for x in itertools.product(*([[1, bigDilation]] * k_rank))}
                )
            max_dim_size = None

            # There are too many parameter combinations, so generate them sparsely,
            # very sparse for negative tests
            sparsity_factor = 2 if error_name else 120
            sparsity = TosaArgGen._calculate_sparsity(
                len(paddings) * len(strides) * len(dilations), sparsity_factor
            )
        else:
            # Only test 8k levels boundaries
            bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
            bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
            bigPadding = bigKernel

            dilation_shape = [1] * k_rank
            pad_shape = [0] * k_rank * 2
            if conv3d:
                # Small stride apart from for big kernel (see below) to keep
                # tensor size/calculation small
                stride_shape = [1] * k_rank
                for idx in range(k_rank):
                    pad_offset = idx * 2
                    if k_shape[idx] == bigKernel:
                        # Padding shape needs to account for tensor shape
                        pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
                        pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1
                        # Big stride to reduce output size
                        stride_shape[idx] = bigKernel
                    else:
                        # Account for kernel size
                        pad_shape[pad_offset] = k_shape[idx] - 1
            else:
                # Always have a large stride with extra padding and dilation to keep
                # tensor calculation reasonable
                stride_shape = [bigKernel] * k_rank
                for idx in range(k_rank):
                    # Dilation shape must account for kernel size
                    dilation_shape[idx] = bigKernel // k_shape[idx]
                    # Padding shape needs to accommodate tensor/kernel & dilation
                    pad_offset = idx * 2
                    pad_shape[pad_offset] = bigPadding - ifm_shape[idx + 1]
                    pad_shape[pad_offset + 1] = bigPadding - dilation_shape[idx] + 1

            strides = {tuple(stride_shape)}
            dilations = {tuple(dilation_shape)}
            paddings = {tuple(pad_shape)}
            # Create a limit for the output dimensions size
            max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL

            # Currently allow all combinations that are reasonable size
            sparsity = 1

        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                for d in sorted(list(dilations)):
                    if (
                        n % sparsity == 0
                        # the padded shape must exceed the dilation * kernel to get a
                        # positive sized output shape
                        and (ifm_shape[1] - 1 + p[0] + p[1]) > d[0] * (k_shape[0] - 1)
                        and (ifm_shape[2] - 1 + p[2] + p[3]) > d[1] * (k_shape[1] - 1)
                        and (
                            k_rank < 3
                            or (
                                (ifm_shape[3] - 1 + p[4] + p[5])
                                > d[2] * (k_shape[2] - 1)
                            )
                        )
                    ):
                        remainders = []
                        outputs = []
                        for index in range(k_rank):
                            pad_offset = index * 2
                            partial = (
                                ifm_shape[index + 1]
                                - 1
                                + p[pad_offset]
                                + p[pad_offset + 1]
                                - (k_shape[index] - 1) * d[index]
                            )
                            remainders.append(partial % s[index])
                            outputs.append((partial // s[index]) + 1)

                        if (
                            # the parameters must produce integer exact output
                            error_name != ErrorIf.ConvOutputShapeNonInteger
                            and max(remainders) == 0
                        ) or (
                            error_name == ErrorIf.ConvOutputShapeNonInteger
                            and max(remainders) > 0
                        ):
                            if (
                                max_dim_size is not None
                                and max(outputs) >= max_dim_size
                            ):
                                # Test will consume too much memory - skip it
                                continue

                            # Compliance - number of dot product calculations
                            if depthwise:
                                # N*OH*OW*C*M
                                dots = gtu.product(
                                    (ifm_shape[0], *outputs, *filter_shape[2:])
                                )
                            else:
                                # N*OH*OW*OC or N*OD*OH*OW*OC
                                dots = gtu.product(
                                    (ifm_shape[0], *outputs, filter_shape[0])
                                )
                            args_dict = {
                                "acc_type": accum_dtype,
                                "stride": s,
                                "pad": p,
                                "dilation": d,
                                "kernel": k_shape,
                                "ks": k_size,
                                "dot_products": dots,
                                "shape": ifm_shape,
                            }

                            # Support for larger values than 9 needs different delimiter
                            delim = "" if max(s + p + d) <= 9 else "x"
                            arg_list.append(
                                (
                                    "acc{}_st{}_pad{}_dilat{}".format(
                                        testGen.typeStr(accum_dtype),
                                        delim.join([str(x) for x in s]),
                                        delim.join([str(x) for x in p]),
                                        delim.join([str(x) for x in d]),
                                    ),
                                    args_dict,
                                )
                            )
                    n += 1

        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtypes[0],
            arg_list,
            error_name,
        )
        # Return list of tuples: (arg_str, args_dict)
        return arg_list
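    # Output-size check above (illustrative): ifm H=16, pad 1+1, kernel 3,
    # dilation 1, stride 2 gives partial = 16 - 1 + 2 - 2 = 15 and remainder
    # 15 % 2 = 1, so that combination only suits the
    # ConvOutputShapeNonInteger error test.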
arg_list = [(f"acc{testGen.typeStr(accum_dtype)}", args_dict)] arg_list = TosaArgGen._add_data_generators( testGen, opName, input_dtype, arg_list, error_name, ) # Return list of tuples: (arg_str, args_dict) return arg_list @staticmethod def agMatMul(testGen, opName, shapeList, dtype, error_name=None): # Get valid accumulate type(s) if dtype == DType.INT8: accum_dtypes = [DType.INT32] elif dtype == DType.INT16: accum_dtypes = [DType.INT48] elif dtype == DType.FP16: accum_dtypes = [DType.FP16, DType.FP32] elif dtype == DType.BF16: accum_dtypes = [DType.FP32] elif dtype == DType.FP32: accum_dtypes = [DType.FP32] elif error_name is None: assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}" if error_name == ErrorIf.WrongOutputType: # Get incorrect output dtype for ErrorIf case accum_dtypes = [gtu.get_wrong_output_type(opName, testGen.rng, dtype)] elif error_name == ErrorIf.WrongInputType: # Pick some potentially correct output dtype if input type is incorrect accum_dtypes = [DType.INT32] # Set up compliance info args_dict = { "ks": int(shapeList[0][2]), # Set KS = C, from input A (N,H,C) # Set dot_products = N*H*W "dot_products": gtu.product( (shapeList[0][0], shapeList[0][1], shapeList[1][2]) ), "shape": shapeList[0], } # Create arg tuple of string and dict arg_list = [] for a in accum_dtypes: d = args_dict.copy() d["acc_type"] = a arg_list.append((f"acc{testGen.typeStr(a)}", d)) arg_list = TosaArgGen._add_data_generators( testGen, opName, dtype, arg_list, error_name, ) # Return list of tuples: (arg_str, args_dict) return arg_list @staticmethod def agTransposeConv2D(testGen, opName, shapeList, dtypes, error_name=None): arg_list = [] if testGen.args.level8k and error_name is not None: # Don't produce negative large tests return arg_list ifm_shape = shapeList[0] filter_shape = shapeList[1] accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes) # Must be rank 4 if error_name != ErrorIf.WrongRank: assert len(ifm_shape) == 4 assert len(filter_shape) == 4 k_shape = tuple(filter_shape[1:3]) # compliance size - KS k_size = gtu.product((*k_shape, ifm_shape[3])) if not testGen.args.level8k: # Generate comprehensive argument lists # - except for named errors, which use specific invalid value(s) smallest_padding_size = -min(k_shape[0], k_shape[1]) + 1 if error_name == ErrorIf.PadLargerEqualKernel: max_filter_size = -max(k_shape[0], k_shape[1]) p_vals = [ testGen.rng.choice(range(max_filter_size - 10, max_filter_size)) ] else: p_vals = [ x for x in range( smallest_padding_size, testGen.args.max_conv_padding + 1 ) ] paddings = {x for x in itertools.product(*([p_vals] * 4))} if error_name == ErrorIf.StrideSmallerOne: # Can't use stride=0, as it is used to derive output shape, as a divisor s_vals = [testGen.rng.choice(range(-5, 0))] else: s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)] strides = {x for x in itertools.product(*([s_vals] * 2))} if not error_name and testGen.args.oversize: # add some oversize argument values if max(ifm_shape) < 64: bigPadding = 9 paddings.update( { x for x in itertools.product( *([[smallest_padding_size, bigPadding]] * 4) ) } ) bigStride = 8 strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))}) # There are too many parameter combinations, so generate them sparsely, # very sparse for negative tests sparsity_factor = 2 if error_name else 10 sparsity = len(paddings) * len(strides) // sparsity_factor + 1 # If there are only a small number of tests, just select them all if sparsity < 13: sparsity = 1 # To get a variety of parameter 
        else:
            # Only test 8k levels boundaries
            bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
            bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
            bigPadding = bigKernel

            pad_shape = [0] * (len(k_shape) * 2)
            stride_shape = [1] * len(k_shape)
            # The point at which input dimension combined with the stride will
            # create large output sizes!
            LARGE_SIZE = 2

            for idx in range(len(k_shape)):
                pad_offset = idx * 2
                if k_shape[idx] == bigKernel:
                    # Set large stride
                    stride_shape[idx] = bigKernel
                    # Use negative output padding to reduce shape size
                    pad_shape[pad_offset] = -(bigPadding - 1)
                    if ifm_shape[idx + 1] > LARGE_SIZE:
                        pad_shape[pad_offset + 1] = -(bigPadding - 1)
                else:
                    # The other dimension should be the bigKernel
                    alt_idx = 1 - idx
                    if (
                        k_shape[alt_idx] == bigKernel
                        and ifm_shape[alt_idx + 1] < LARGE_SIZE
                    ):
                        # As the input is small, the large stride won't
                        # affect the output so we can add some padding
                        pad_shape[pad_offset + 1] = bigPadding

            strides = {tuple(stride_shape)}
            paddings = {tuple(pad_shape)}

            # Currently allow all combinations that are reasonable size
            sparsity = 1

        n = 0
        for s in sorted(list(strides)):
            for p in sorted(list(paddings)):
                if n % sparsity == 0:
                    # Determine the output shape
                    oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + k_shape[0]
                    ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + k_shape[1]
                    os = [ifm_shape[0], oh, ow, filter_shape[0]]

                    # N*OH*OW*OC
                    dots = gtu.product((ifm_shape[0], oh, ow, filter_shape[0]))
                    args_dict = {
                        "acc_type": accum_dtype,
                        "stride": s,
                        "pad": p,
                        "kernel": k_shape,
                        "ks": k_size,
                        "dot_products": dots,
                        "shape": ifm_shape,
                        "out_shape": os,
                    }

                    # Support for larger values than 9 needs different delimiter
                    delim = "" if max(s + p) <= 9 else "x"
                    arg_list.append(
                        (
                            "acc{}_st{}_pad{}_os{}".format(
                                testGen.typeStr(accum_dtype),
                                delim.join([str(x) for x in s]),
                                delim.join([str(x) for x in p]),
                                "x".join([str(x) for x in os]),
                            ),
                            args_dict,
                        )
                    )
                n += 1

        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtypes[0],
            arg_list,
            error_name,
        )
        # Return list of tuples: (arg_str, args_dict)
        return arg_list

    @staticmethod
    def agPad(testGen, opName, shapeList, dtype, error_name=None):
        rank = len(shapeList[0])

        # Exhaustively test combinations of padding on each side of each dimension
        # - the range of padding values is defined by pad_min and pad_max
        # - for padding >9, the name format needs to be more distinctive
        pad_min, pad_max = 0, 1
        pad_values = [x for x in range(pad_min, pad_max + 1)]
        if error_name == ErrorIf.PadSmallerZero:
            pad_values = [x for x in range(-2, 0)]
        axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
        shape_pad_values = itertools.product(*([axis_pad_values] * rank))

        if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
            pad_const_int = testGen.getRandNumberDType(dtype)
            pad_const_fp = 0
        elif dtype in (DType.FP16, DType.BF16, DType.FP32):
            pad_const_int = 0
            pad_const_fp = testGen.getRandNumberDType(dtype)
        else:
            return []

        list_shape_pad_values = list(shape_pad_values)
        # If we are producing tests for rank 6 or greater use sparsity
        if len(list_shape_pad_values) > 1024:
            sparsity_factor = 2 if error_name else 120
            sparsity = TosaArgGen._calculate_sparsity(
                len(list_shape_pad_values), sparsity_factor
            )
        else:
            sparsity = 1

        # Build arg list
        arg_list = []
        for n, paddings in enumerate(list_shape_pad_values):
            paddings = list(paddings)
            args_valid = True
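            # paddings is a sequence of rank (before, after) pairs, e.g.
            # [(0, 1), (1, 0)] pads one element after dim 0 and one before dim 1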
            if error_name == ErrorIf.PadSmallerZero:
                # Prevent negative output shapes while ensuring still testing
                # for negative padding
                for i in range(rank):
                    dim_after_padding = (
                        paddings[i][0] + paddings[i][1] + shapeList[0][i]
                    )
                    if dim_after_padding < 1:
                        paddings[i] = (0, 0)
                    if all([p > -1 for p in paddings[i]]):
                        args_valid = False

            if args_valid and n % sparsity == 0:
                name = "pad"
                for r in range(rank):
                    before, after = paddings[r]
                    name = f"{name}{before}{after}"
                args_dict = {
                    "pad": np.array(paddings),
                    "pad_const_int": pad_const_int,
                    "pad_const_fp": pad_const_fp,
                }
                arg_list.append((name, args_dict))

        if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0:
            warnings.warn(f"No ErrorIf test created for input shape: {shapeList[0]}")

        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtype,
            arg_list,
            error_name,
        )
        # Return list of tuples: (arg_str, args_dict)
        return arg_list

    @staticmethod
    def agPooling(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        shape = shapeList[0]
        if error_name != ErrorIf.WrongRank:
            assert len(shape) == 4

        test_level8k = testGen.args.level8k and error_name is None

        startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2
        startKernel = 2
        startPad = 0
        if not test_level8k:
            # Generate comprehensive argument lists
            p_vals = [
                x for x in range(startPad, testGen.args.max_pooling_padding + 1)
            ]
            paddings = {x for x in itertools.product(*([p_vals] * 4))}
            # Stride must be greater than 1 to force non-integer error
            s_vals = [
                x for x in range(startStride, testGen.args.max_pooling_stride + 1)
            ]
            strides = {x for x in itertools.product(*([s_vals] * 2))}
            k_vals = [
                x for x in range(startKernel, testGen.args.max_pooling_kernel + 1)
            ]
            kernels = {x for x in itertools.product(*([k_vals] * 2))}
            max_dim_size = None
        else:
            # Only test 8k levels
            bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
            bigKernel = testGen.TOSA_8K_LEVEL_MAX_KERNEL
            strides = {(1, bigStride), (bigStride, 4)}
            kernels = {(1, bigKernel), (bigKernel, 3)}
            paddings = set()
            for s in sorted(list(strides)):
                for k in sorted(list(kernels)):
                    padding = []
                    for idx in range(len(k)):
                        total_padding = s[idx] - shape[idx + 1] + k[idx]
                        while total_padding < 0:
                            # Must meet: shape + padding > kernel
                            total_padding += s[idx]
                        if total_padding < k[idx]:
                            padding.extend([0, total_padding])
                        else:
                            # Note this may produce padding >= k[idx] which is not
                            # allowed - but will be ignored in the creation loop below
                            padding.extend(
                                [k[idx] - 1, total_padding - (k[idx] - 1)]
                            )
                    paddings.add(tuple(padding))
            # Create a limit for the output dimensions size
            max_dim_size = testGen.TOSA_8K_LEVEL_MAX_KERNEL

        if opName == "max_pool2d":
            accum_dtypes = [None]  # max_pool has no accumulate dtype
        elif dtype == DType.INT8 or dtype == DType.INT16:
            accum_dtypes = [DType.INT32]
        elif dtype == DType.FP16:
            accum_dtypes = [DType.FP16, DType.FP32]
        elif dtype == DType.BF16 or dtype == DType.FP32:
            accum_dtypes = [DType.FP32]
        elif error_name is None:
            assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
        else:
            # Set to something for the ErrorIf case which has
            # incorrect input data-type
            accum_dtypes = [DType.INT32]

        if not test_level8k:
            if testGen.args.oversize:
                # add some oversize argument values
                bigStride = 7
                bigKernel = 9
                strides.update(
                    {x for x in itertools.product(*([[startStride, bigStride]] * 2))}
                )
                kernels.update(
                    {x for x in itertools.product(*([[startKernel, bigKernel]] * 2))}
                )
                if max(shape) < 64:
                    # padding must be less than the kernel size
                    bigPadding = bigKernel - 1
                    paddings.update(
                        {
                            x
                            for x in itertools.product(
                                *([[startPad, bigPadding]] * 4)
                            )
                        }
                    )
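            # Selection in the loop below requires an integer-exact output,
            # e.g. H=32 with pad (1, 1), kernel 3, stride 2 gives
            # partial_h = 32 + 1 + 1 - 3 = 31 and output_h = 31 // 2 + 1 = 16
            # with remainder 1, so that combination is only used for
            # PoolingOutputShapeNonInteger error tests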
            # There are too many parameter combinations, so generate them sparsely,
            # very sparse for negative tests
            sparsity_factor = 2 if error_name else 500
            sparsity = (
                len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1
            )
        else:
            # We have already limited test output combinations for 8k tests
            sparsity = 1

        arg_str = (
            "acc{}_st{}_kern{}_pad{}"
            if accum_dtypes[0] is not None
            else "st{}_kern{}_pad{}"
        )

        def get_arg_list_element(accum, stride, pad, kern, dot_products=0, shape=[]):
            # Return tuple containing the formatted argument string and
            # the corresponding argument values in a dictionary

            # Support for larger values than 9 needs different delimiter
            delim = "" if max(stride + kern + pad) <= 9 else "x"
            arg_str_elems = [
                delim.join([str(x) for x in stride]),
                delim.join([str(x) for x in kern]),
                delim.join([str(x) for x in pad]),
            ]
            args_dict = {
                "stride": stride,
                "pad": pad,
                "kernel": kern,
                "dot_products": dot_products,  # Ignored for error tests
                "shape": shape,
                "ks": gtu.product(kern),  # avg_pool2d: KS = KX*KY
            }

            if accum is not None:
                arg_str_elems.insert(0, testGen.typeStr(accum))
                args_dict["acc_type"] = accum
            return (arg_str.format(*arg_str_elems), args_dict)

        n = 0
        for a in accum_dtypes:
            for s in sorted(list(strides)):
                for p in sorted(list(paddings)):
                    for k in sorted(list(kernels)):
                        if error_name in [
                            ErrorIf.StrideSmallerOne,
                            ErrorIf.KernelSmallerOne,
                            ErrorIf.PadSmallerZero,
                            ErrorIf.PadLargerEqualKernel,
                        ]:
                            sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
                                testGen, error_name, s, p, k
                            )
                            if None not in [sNew, pNew, kNew] and n % sparsity == 0:
                                arg_list.append(
                                    get_arg_list_element(
                                        a, sNew, pNew, kNew, shape=shape
                                    )
                                )
                        elif (
                            n % sparsity == 0
                            # padding must not exceed the kernel size
                            and p[0] < k[0]
                            and p[1] < k[0]
                            and p[2] < k[1]
                            and p[3] < k[1]
                            # the padded shape must exceed the kernel size
                            and (shape[1] + p[0] + p[1]) > k[0]
                            and (shape[2] + p[2] + p[3]) > k[1]
                        ):
                            partial_h = shape[1] + p[0] + p[1] - k[0]
                            partial_w = shape[2] + p[2] + p[3] - k[1]
                            remainder_h = partial_h % s[0]
                            remainder_w = partial_w % s[1]
                            output_h = partial_h // s[0] + 1
                            output_w = partial_w // s[1] + 1
                            # debug print(shape, remainder_h, remainder_w, "/", output_h, output_w)
                            if (
                                # the parameters must produce integer exact output
                                error_name != ErrorIf.PoolingOutputShapeNonInteger
                                and remainder_h == 0
                                and remainder_w == 0
                            ) or (
                                error_name == ErrorIf.PoolingOutputShapeNonInteger
                                and (remainder_h != 0 or remainder_w != 0)
                            ):
                                if (
                                    max_dim_size is not None
                                    and max(output_h, output_w) > max_dim_size
                                ):
                                    # Test will consume too much memory - skip it
                                    continue
                                # Dot products = N*OH*OW*C
                                dp = gtu.product(
                                    (shape[0], output_h, output_w, shape[3])
                                )
                                arg_list.append(
                                    get_arg_list_element(a, s, p, k, dp, shape)
                                )
                        n += 1

        # Now add data generator types
        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtype,
            arg_list,
            error_name,
        )

        # Return list of tuples: (arg_str, args_dict)
        return arg_list

    @staticmethod
    def agCast(testGen, opName, shapeList, inDtype, error_name=None):
        arg_list = []

        # Enumerate the output types here
        if error_name == ErrorIf.WrongOutputType:
            dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype)
        elif inDtype == DType.INT8:
            dtypeList = [
                DType.BOOL,
                DType.INT16,
                DType.INT32,
                DType.FP16,
                DType.BF16,
                DType.FP32,
            ]
        elif inDtype == DType.INT16:
            dtypeList = [
                DType.BOOL,
                DType.INT8,
                DType.INT32,
                DType.FP16,
                DType.BF16,
                DType.FP32,
            ]
        elif inDtype == DType.INT32:
            dtypeList = [
                DType.BOOL,
                DType.INT8,
                DType.INT16,
                DType.FP16,
                DType.BF16,
                DType.FP32,
            ]
        elif inDtype == DType.BOOL:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
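        # Note: FP16 <-> BF16 casts are not enumerated in the float branches below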
        elif inDtype == DType.FP16:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
        elif inDtype == DType.BF16:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
        elif inDtype == DType.FP32:
            dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.BF16]
        elif error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output type for incorrect input type
            dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
        else:
            raise Exception("Unexpected input dtype: {}".format(inDtype))

        for dtype in dtypeList:
            arg_list.append(
                ("out{}".format(testGen.typeStr(dtype)), {"out_type": dtype})
            )

        # Now add data generator types (based on the input dtype)
        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            inDtype,
            arg_list,
            error_name,
        )
        return arg_list

    @staticmethod
    def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
        arg_list = []

        # Enumerate the output types here
        for outDtype in [
            DType.UINT8,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.UINT16,
        ]:
            if (
                outDtype in [DType.UINT8, DType.INT8, DType.UINT16]
                and error_name == ErrorIf.OutputZeroPointNotZero
            ):
                continue
            if (
                outDtype != DType.UINT16
                and error_name == ErrorIf.U16OutputZeroPointNotValid
            ) or (
                inDtype != DType.UINT16
                and error_name == ErrorIf.U16InputZeroPointNotValid
            ):
                # ErrorIfs only valid with UINT16
                continue
            if (
                inDtype == DType.UINT8
                and outDtype not in [DType.INT8, DType.INT16]
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtypes for UINT8 are INT8/INT16, skip all others
                continue
            if (
                inDtype not in [DType.INT8, DType.INT16]
                and outDtype == DType.UINT8
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtypes for UINT8 are INT8/INT16, skip all others
                continue
            if (
                inDtype == DType.UINT16
                and outDtype != DType.INT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only output dtype for UINT16 is INT16, skip all others
                continue
            if (
                inDtype != DType.INT16
                and outDtype == DType.UINT16
                and error_name != ErrorIf.WrongOutputType
            ):
                # The only input dtype for UINT16 is INT16, skip all others
                continue
            if (
                error_name == ErrorIf.WrongOutputType
                and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, outDtype)
            ):
                continue

            for scale32 in [False, True]:
                if error_name == ErrorIf.ScaleTrue and not scale32:
                    continue
                elif error_name == ErrorIf.ScaleNotTrue and scale32:
                    continue
                for double_round in [False, True]:
                    if error_name == ErrorIf.ScaleNotTrue and not double_round:
                        continue
                    for per_channel in [False, True]:
                        if (
                            inDtype == DType.INT48
                            and scale32
                            and error_name != ErrorIf.ScaleTrue
                        ):
                            # Illegal condition.  Must be scale32=False
                            continue
                        if (
                            double_round
                            and not scale32
                            and error_name != ErrorIf.ScaleNotTrue
                        ):
                            # Illegal condition.  ERROR_IF(!scale32 && double_round)
                            continue

                        arg_list.append(
                            (
                                "out{}_sc{}_dr{}_pc{}".format(
                                    testGen.typeStr(outDtype),
                                    int(scale32),
                                    int(double_round),
                                    int(per_channel),
                                ),
                                [outDtype, scale32, double_round, per_channel],
                            )
                        )

        return arg_list

    @staticmethod
    def agMul(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        if dtype == DType.INT32:
            for p in range(testGen.args.num_rand_permutations):
                shift = testGen.randInt(0, 32)
                arg_list.append(("perm{}_shift{}".format(p, shift), {"shift": shift}))
        else:
            arg_list.append(("perm0_shift0", {"shift": 0}))

        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtype,
            arg_list,
            error_name,
        )
        # Return list of tuples: (arg_str, args_dict)
        return arg_list

    @staticmethod
    def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        arg_list.append(("roundTrue", [True]))
        arg_list.append(("roundFalse", [False]))

        return arg_list

    @staticmethod
    def agFFT2d(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        shape = shapeList[0]
        dot_products = gtu.product(shape)
        ks = 2 * shape[1] * shape[2]  # 2*H*W
        for inverse in (True, False):
            args_dict = {
                "dot_products": dot_products,
                "shape": shape,
                "ks": ks,
                "acc_type": dtype,
                "inverse": inverse,
            }
            arg_list.append((f"inverse{inverse}", args_dict))

        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtype,
            arg_list,
            error_name,
        )
        # Return list of tuples: (arg_str, args_dict)
        return arg_list

    # Helper function for reshape.  Gets some factors of a larger number.
    @staticmethod
    def getFactors(val, start=1):
        factors = []
        for i in range(start, int(np.sqrt(val)) + 1):
            if (val % i) == 0:
                factors.append(i)
        return factors

    @staticmethod
    def agReshape(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        origShape = shapeList[0]
        totalElements = gtu.product(origShape)
        factors = TosaArgGen.getFactors(totalElements)

        # Find new shapes up to the number of permutations asked for
        # This code is NOT fast.  Fortunately, the numbers are fairly small.
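        # e.g. totalElements = 24 gives factors [1, 2, 3, 4] (divisors up to
        # sqrt(24)); a rank 3 shape such as [2, 3, 4] is then built by
        # repeatedly picking a factor and re-factoring the remainder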
        for p in range(testGen.args.num_rand_permutations):
            # Rank from 1 to TOSA_TENSOR_MAX_RANK
            newRank = testGen.randInt(1, (testGen.TOSA_TENSOR_MAX_RANK + 1))
            if len(factors) < newRank:
                continue

            # escape_counter limits the generation of new shapes to a reasonable time
            for escape_counter in range(100):
                # Generate the new shape of the chosen new rank
                newShape = []
                remainingElements = totalElements
                shuffledFactors = testGen.rng.permutation(factors)
                for i in range(1, newRank):
                    # pick rank-1 factors
                    newShape.append(shuffledFactors[0])
                    remainingElements = remainingElements // shuffledFactors[0]
                    shuffledFactors = testGen.rng.permutation(
                        TosaArgGen.getFactors(remainingElements)
                    )
                newShape.append(remainingElements)

                # Check for duplicates
                duplicate = False
                for name, args_dict in arg_list:
                    if args_dict["new_shape"] == newShape:
                        duplicate = True
                        break

                if not duplicate:
                    outShape = "x".join([str(x) for x in newShape])
                    arg_list.append(
                        (
                            "perm{}_rank{}_out{}".format(p, newRank, outShape),
                            {"new_shape": newShape},
                        )
                    )
                    # Found an output shape for this permutation
                    break

        # Now add data generator types
        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtype,
            arg_list,
            error_name,
        )
        return arg_list

    @staticmethod
    def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        ifm_shape = shapeList[0]

        if error_name == ErrorIf.IndexOutsideBounds:
            incorrect_large_index = range(len(ifm_shape) + 1, 2 * len(ifm_shape) + 1)
            incorrect_small_index = range(-len(ifm_shape), 0)
            permutations = [p for p in itertools.permutations(incorrect_large_index)]
            permutations.extend(
                [p for p in itertools.permutations(incorrect_small_index)]
            )
        elif error_name == ErrorIf.IndexUsedTwice:
            # Create list with a duplicated index
            perm_range = list(range(len(ifm_shape)))
            index_choice = testGen.rng.choice(range(len(perm_range)))
            perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
            permutations = [p for p in itertools.permutations(perm_range)]
        else:
            # Get all permutations
            permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]

        # Limit to possible permutations from shape dimension or argument setting
        limit = min(len(permutations), testGen.args.num_rand_permutations)

        # Get random permutation generator that uses all permutations
        random_permutations = testGen.rng.permutation(permutations)

        # Create list of required amount of permutations
        arg_list = [
            ("perm{}".format(p), {"perms": random_permutations[p].tolist()})
            for p in range(limit)
        ]
        # Now add data generator types
        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtype,
            arg_list,
            error_name,
        )
        # Return list of tuples: (arg_str, args_dict)
        return arg_list

    @staticmethod
    def agSlice(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):
            start = []
            size = []

            valid = True

            for i in range(rank):
                if ifm_shape[i] > 1:
                    start.append(testGen.randInt(0, ifm_shape[i]))
                    size.append(testGen.randInt(0, ifm_shape[i] - start[i]))

                    # Invalid slice size?
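                    # (the lower bound of 0 means a zero size can be chosen
                    # and must be rejected)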
                    if size[i] == 0:
                        valid = False
                else:
                    start.append(0)
                    size.append(1)

            if valid:
                # If ERROR_IF test required then incorrect start, size will be returned
                start, size = TosaErrorIfArgGen.eiSliceErrorIf(
                    testGen, error_name, ifm_shape, start, size
                )
                arg_list.append(("perm{}".format(p), {"start": start, "size": size}))

        # Now add data generator types
        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtype,
            arg_list,
            error_name,
        )
        # Return list of tuples: (arg_str, args_dict)
        return arg_list

    @staticmethod
    def agTile(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        ifm_shape = shapeList[0]
        rank = len(ifm_shape)

        for p in range(testGen.args.num_rand_permutations):
            # Pick a few random, but small multiple values
            # because otherwise this has a tendency to generate
            # enormous tensors
            multiples = []
            for i in range(rank):
                if ifm_shape[i] > 1000:
                    # Multiple of 1 if ifm_shape dimension is large to reduce
                    # tensor size
                    multiples.append(1)
                elif max(ifm_shape) > 1000:
                    multiples.append(2)
                else:
                    multiples.append(testGen.randInt(1, 4))
            arg_list.append(("perm{}".format(p), {"multiples": multiples}))

        # Now add data generator types
        arg_list = TosaArgGen._add_data_generators(
            testGen,
            opName,
            dtype,
            arg_list,
            error_name,
        )
        # Return list of tuples: (arg_str, args_dict)
        return arg_list

    @staticmethod
    def agResize(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []
        ifm_shape = shapeList[0]

        def get_aspect_ratio_resize_params():
            common_aspect_ratios = ((3, 2), (16, 9), (4, 3))
            aspect_ratio = testGen.rng.choice(common_aspect_ratios)
            invert = testGen.rng.choice((False, True))
            letterbox = testGen.rng.choice((False, True))

            scale_y_n = aspect_ratio[0] if invert else aspect_ratio[1]
            scale_x_n = aspect_ratio[1] if invert else aspect_ratio[0]
            scale_y_d = scale_x_d = 1
            offset_x = offset_y = 0

            if letterbox:
                max_border = scale_y_n
                border_y = testGen.randInt(low=0, high=max_border)
                border_x = 0
            else:
                # Pillarboxing
                border_y = 0
                max_border = scale_x_n
                border_x = testGen.randInt(low=0, high=max_border)

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)

            return scale, offset, border
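        # e.g. aspect_ratio (3, 2) without invert gives scale (2, 1, 3, 1):
        # a 2x upscale in Y and 3x in X, with a letterbox border on Y or a
        # pillarbox border on X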
        def get_upscale_downscale_params():
            valid_params = False
            while not valid_params:
                upscale = testGen.rng.choice((False, True))

                # True if sampling begins from (0,0). Otherwise (-0.5,-0.5)
                origin_sampling = testGen.rng.choice((False, True))

                if upscale:
                    shift = testGen.randInt(low=1, high=4)
                    scale_x_d = scale_y_d = 1
                    scale_x_n = scale_y_n = (
                        1 << shift if origin_sampling else 2 << shift
                    )
                    border_x = border_y = 0 if origin_sampling else (1 << shift) - 1
                    offset_x = offset_y = 0 if origin_sampling else -(1 << shift) + 1
                else:
                    scale_x_n = 1
                    scale_y_n = 1

                    # Return list of valid scale_*_d values (max value 4) given input dim shape
                    def get_valid_denom(ifm_dim):
                        return [x for x in range(1, 5) if ifm_dim % x == 1]

                    # Generate list of valid downscale values and choose one randomly
                    valid_scale_y_ds = get_valid_denom(ifm_shape[1])
                    valid_scale_x_ds = get_valid_denom(ifm_shape[2])
                    if not valid_scale_y_ds and not valid_scale_x_ds:
                        # Bad parameters, skip
                        continue

                    if not valid_scale_y_ds:
                        scale_y_d = 1
                    else:
                        scale_y_d = testGen.rng.choice(valid_scale_y_ds)

                    if not valid_scale_x_ds:
                        scale_x_d = 1
                    else:
                        scale_x_d = testGen.rng.choice(valid_scale_x_ds)

                    border_x = border_y = 0
                    offset_y = testGen.randInt(0, 16 * scale_y_n)
                    offset_x = testGen.randInt(0, 16 * scale_x_n)

                valid_params = True

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)
            return scale, offset, border

        def get_rand_params():
            def fix_scale_to_max_scale(scale_n, scale_d, max_scale):
                scale = scale_n / scale_d
                if scale > max_scale:
                    factor = scale / max_scale
                    new_scale_d = math.ceil(scale_d * factor)
                    assert scale_n / new_scale_d <= max_scale
                    scale_d = new_scale_d
                return scale_d

            # Scale
            scale_y_n = testGen.randInt(low=1, high=(1 << 11))
            scale_x_n = testGen.randInt(low=1, high=(1 << 11))

            scale_y_d = testGen.randInt(low=1, high=(16 * scale_y_n))
            scale_x_d = testGen.randInt(low=1, high=(16 * scale_x_n))

            scale_y_d = fix_scale_to_max_scale(
                scale_y_n, scale_y_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
            )
            scale_x_d = fix_scale_to_max_scale(
                scale_x_n, scale_x_d, testGen.TOSA_8K_LEVEL_MAX_SCALE
            )

            # Offsets and border within the scale
            offset_y = testGen.randInt(low=-scale_y_n, high=(16 * scale_y_n))
            offset_x = testGen.randInt(low=-scale_x_n, high=(16 * scale_x_n))

            border_y = testGen.randInt(low=(-16 * scale_y_n), high=scale_y_n)
            border_x = testGen.randInt(low=(-16 * scale_x_n), high=scale_x_n)

            scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d)
            offset = (offset_y, offset_x)
            border = (border_y, border_x)
            return scale, offset, border

        def get_level_8k_params():
            # Create 64x scale - 64/1 to 2048/32
            scale_d = testGen.randInt(
                low=1, high=(1 << 11) / testGen.TOSA_8K_LEVEL_MAX_SCALE
            )
            scale_n = scale_d * testGen.TOSA_8K_LEVEL_MAX_SCALE
            # Create half to fifth scaling
            scale_d_alt = testGen.randInt(low=2, high=6)
            scale_n_alt = 1
            switch = testGen.rng.choice((False, True))
            if switch:
                scale = (scale_n_alt, scale_d_alt, scale_n, scale_d)
            else:
                scale = (scale_n, scale_d, scale_n_alt, scale_d_alt)

            offset_y = testGen.rng.choice((-scale[0], 0, (16 * scale[0]) - 1))
            offset_x = testGen.rng.choice((-scale[2], 0, (16 * scale[2]) - 1))
            offset = (offset_y, offset_x)
            border_y = testGen.rng.choice((-16 * scale[0], 0, scale[0] - 1))
            border_x = testGen.rng.choice((-16 * scale[2], 0, scale[2] - 1))
            border = (border_y, border_x)
            return scale, offset, border
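        # Output dimension formula checked in the loop below (per the TOSA
        # RESIZE definition):
        #   partial = (ifm_dim - 1) * scale_n - offset + border
        #   output = partial // scale_d + 1, integer-exact only when
        #   partial % scale_d == 0
        # e.g. H=16, scale 2/1, offset 0, border 0:
        #   partial_y = 15 * 2 = 30 -> output_y = 30 // 1 + 1 = 31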
        for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
            # Exclude illegal {mode, type} configurations.  Pick legal output types
            if mode == ResizeMode.NEAREST and dtype == DType.INT8:
                outputDTypeList = [DType.INT8]
            elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                outputDTypeList = [DType.INT16]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                outputDTypeList = [DType.INT32]
            elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                outputDTypeList = [DType.INT48]
            elif dtype == DType.FP16:
                outputDTypeList = [DType.FP16]
            elif dtype == DType.BF16:
                outputDTypeList = [DType.BF16]
            elif dtype == DType.FP32:
                outputDTypeList = [DType.FP32]
            elif error_name == ErrorIf.WrongInputType:
                # If an incorrect input type is used then we set a 'correct'
                # output type to avoid other errors
                outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
            else:
                continue

            arg_str = "mode{}_out{}_sc{}x{}x{}x{}_off{}x{}_bor{}x{}"

            for outputDType in outputDTypeList:
                perm = 0
                while perm < testGen.args.num_rand_permutations:
                    # Random choice of type of params we are testing
                    if not testGen.args.level8k:
                        _rnd_param_fn = testGen.rng.choice(
                            (
                                get_rand_params,
                                get_upscale_downscale_params,
                                get_aspect_ratio_resize_params,
                            )
                        )
                        scale, offset, border = _rnd_param_fn()
                    else:
                        scale, offset, border = get_level_8k_params()

                    # Expand params for bounds-checking
                    (scale_y_n, scale_y_d, scale_x_n, scale_x_d) = scale
                    (offset_y, offset_x) = offset
                    (border_y, border_x) = border

                    # Make sure output dimensions OH and OW are integers
                    partial_output_y = (
                        (ifm_shape[1] - 1) * scale_y_n - offset_y + border_y
                    )
                    partial_output_x = (
                        (ifm_shape[2] - 1) * scale_x_n - offset_x + border_x
                    )
                    if error_name == ErrorIf.ResizeOutputShapeNonInteger:
                        # Look for non-integer test
                        if (
                            partial_output_y % scale_y_d == 0
                            and partial_output_x % scale_x_d == 0
                        ):
                            # Skip this test as it doesn't produce NonInteger output
                            if perm > 0:
                                perm += 1
                            continue
                    else:
                        # Alter the scaling factors to make the output integer
                        while partial_output_y % scale_y_d != 0:
                            scale_y_d -= 1
                        while partial_output_x % scale_x_d != 0:
                            scale_x_d -= 1
                        # Make sure we are still within max scaling
                        if (
                            scale_y_n / scale_y_d
                        ) > testGen.TOSA_8K_LEVEL_MAX_SCALE or (
                            scale_x_n / scale_x_d
                        ) > testGen.TOSA_8K_LEVEL_MAX_SCALE:
                            # Skip the test as it is using too large a scaling factor
                            if perm > 0:
                                perm += 1
                            continue

                    output_y = partial_output_y // scale_y_d + 1
                    output_x = partial_output_x // scale_x_d + 1
                    if (
                        output_y >= testGen.args.max_resize_output_dim
                        or output_x >= testGen.args.max_resize_output_dim
                    ) and error_name is None:
                        # Skip positive test if output dim will be too high
                        # Avoid high test latency and OOM issues
                        if not testGen.args.level8k or perm > 0:
                            perm += 1
                        continue

                    if (
                        output_y <= 0
                        or output_y >= gtu.MAX_RESIZE_DIMENSION
                        or output_x <= 0
                        or output_x >= gtu.MAX_RESIZE_DIMENSION
                    ):
                        # Output dimensions out of scope
                        if error_name is not None and perm > 0:
                            # As long as we have one ERROR_IF test, don't worry
                            # about creating all the other permutations
                            perm += 1
                        continue

                    if error_name == ErrorIf.ResizeOutputShapeMismatch and (
                        (
                            output_y + scale_y_d >= gtu.MAX_RESIZE_DIMENSION
                            and output_y - scale_y_d < 1
                        )
                        or (
                            output_x + scale_x_d >= gtu.MAX_RESIZE_DIMENSION
                            and output_x - scale_x_d < 1
                        )
                    ):
                        # Can't create a negative test with these params as it
                        # will create invalid output size
                        if perm > 0:
                            perm += 1
                        continue

                    scale = [scale_y_n, scale_y_d, scale_x_n, scale_x_d]
                    offset = [offset_y, offset_x]
                    border = [border_y, border_x]
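                    # For ERROR_IF runs the valid parameters chosen above are
                    # rewritten into deliberately invalid ones by the helper
                    # below; positive runs keep them as-is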
                    # Common for all data types
                    if error_name is not None:
                        (
                            scale,
                            offset,
                            border,
                            outputDTypeNew,
                        ) = TosaErrorIfArgGen.eiResizeErrorIf(
                            testGen,
                            error_name,
                            mode,
                            dtype,
                            shapeList,
                            outputDType,
                            scale,
                            offset,
                            border,
                        )
                    else:
                        outputDTypeNew = outputDType

                    arg_to_append = (
                        arg_str.format(
                            "N" if mode == ResizeMode.NEAREST else "B",
                            testGen.typeStr(outputDTypeNew),
                            scale[0],
                            scale[1],
                            scale[2],
                            scale[3],
                            offset[0],
                            offset[1],
                            border[0],
                            border[1],
                        ),
                        [
                            mode,
                            scale,
                            offset,
                            border,
                            dtype,
                            outputDTypeNew,
                        ],
                    )
                    if arg_to_append in arg_list:
                        # Skip already generated test params
                        continue

                    # Valid permutation
                    perm += 1
                    arg_list.append(arg_to_append)

        return arg_list

    @staticmethod
    def agTable(testGen, opName, shapeList, dtype, error_name=None):
        arg_list = []

        if dtype == DType.INT8:
            table = np.int32(
                testGen.rng.integers(low=-128, high=128, size=[256])
            ).tolist()
        else:  # INT16
            table = np.int32(
                testGen.rng.integers(low=-32768, high=32768, size=[513])
            ).tolist()
            # Make sure all slopes are within REQUIRE min/max 16-bit int
            for idx in range(len(table) - 1):
                slope = table[idx + 1] - table[idx]
                # Alter the next table entry to force the slope to be ok
                if slope > 32767:
                    table[idx + 1] -= slope - 32767
                if slope < -32768:
                    table[idx + 1] -= slope + 32768
                slope = table[idx + 1] - table[idx]
                assert slope <= 32767 and slope >= -32768

        arg_list.append(
            (
                "",
                [table],
            )
        )
        return arg_list

    @staticmethod
    def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
        # CondIf generates the condition values here.
        # Convert to tensors in the build function, along with the
        # then and else blocks
        arg_list = []

        for c in [False, True]:
            arg_list.append(("cond{}".format(int(c)), [c]))

        return arg_list

    @staticmethod
    def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
        # While loop: 0 iterations, 1, more than 1
        arg_list = []

        for iter in [0, 1, 4]:
            arg_list.append(("iter{}".format(iter), [iter]))

        return arg_list
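
# Illustrative note on the agTable slope clamp above: with table[idx] = -32768
# and table[idx + 1] = 32767 the raw slope is 65535, so table[idx + 1] is
# reduced by 65535 - 32767 = 32768 to -1, giving a legal slope of 32767.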