# Copyright (c) 2021-2022, ARM Limited. # SPDX-License-Identifier: Apache-2.0 import itertools import math import warnings import numpy as np from generator.tosa_error_if import ErrorIf from generator.tosa_error_if import TosaErrorIfArgGen from generator.tosa_utils import get_accum_dtype_from_tgTypes from generator.tosa_utils import get_wrong_output_type from generator.tosa_utils import MAX_RESIZE_DIMENSION from serializer.tosa_serializer import DTypeNames from tosa.DType import DType from tosa.Op import Op from tosa.ResizeMode import ResizeMode # DTypeNames, DType, Op and ResizeMode are convenience variables to the # flatc-generated types that should be enums, but aren't class TosaQuantGen: """QuantizedInfo random generator helper functions. Specify with 'qgen': in the operator defintion. """ def __init__(self): pass @staticmethod def getZeroPoint(testGen, dtype, error_name=None): if dtype == DType.INT8: if testGen.args.zeropoint is not None: return min(127, max(-128, testGen.args.zeropoint)) return testGen.randInt(-128, 128) elif dtype == DType.UINT8: if testGen.args.zeropoint is not None: return min(255, max(0, testGen.args.zeropoint)) return testGen.randInt(0, 256) elif error_name in [ ErrorIf.InputZeroPointNotZero, ErrorIf.WeightZeroPointNotZero, ErrorIf.OutputZeroPointNotZero, ]: zero_point = testGen.randInt(-128, 128) if zero_point == 0: zero_point = 1 return zero_point return 0 @staticmethod def qgUnary(testGen, op, dtype, error_name=None): if error_name == ErrorIf.InputZeroPointNotZero: qinfo = [ TosaQuantGen.getZeroPoint(testGen, dtype, error_name), TosaQuantGen.getZeroPoint(testGen, dtype), ] elif error_name == ErrorIf.OutputZeroPointNotZero: qinfo = [ TosaQuantGen.getZeroPoint(testGen, dtype), TosaQuantGen.getZeroPoint(testGen, dtype, error_name), ] else: qinfo = [ TosaQuantGen.getZeroPoint(testGen, dtype), TosaQuantGen.getZeroPoint(testGen, dtype), ] return qinfo @staticmethod def qgConv(testGen, op, dtype_or_dtypeList, error_name=None): if isinstance(dtype_or_dtypeList, list): # a list of [input, weights, accumulator] dtypes dtypeList = dtype_or_dtypeList else: # an int, [input, weights, accumulator] dtypes are the same dtypeList = [dtype_or_dtypeList] * 3 if error_name == ErrorIf.InputZeroPointNotZero: qinfo = [ TosaQuantGen.getZeroPoint(testGen, dtypeList[0], error_name), TosaQuantGen.getZeroPoint(testGen, dtypeList[1]), ] elif error_name == ErrorIf.WeightZeroPointNotZero: qinfo = [ TosaQuantGen.getZeroPoint(testGen, dtypeList[0]), TosaQuantGen.getZeroPoint(testGen, dtypeList[1], error_name), ] else: qinfo = [ TosaQuantGen.getZeroPoint(testGen, dtypeList[0]), TosaQuantGen.getZeroPoint(testGen, dtypeList[1]), ] return qinfo @staticmethod def qgMatmul(testGen, op, dtype, error_name=None): if error_name == ErrorIf.InputZeroPointNotZero: qinfo = [ TosaQuantGen.getZeroPoint(testGen, dtype, error_name), TosaQuantGen.getZeroPoint(testGen, dtype, error_name), ] else: qinfo = [ TosaQuantGen.getZeroPoint(testGen, dtype), TosaQuantGen.getZeroPoint(testGen, dtype), ] return qinfo @staticmethod def computeMultiplierAndShift(scaleFp, scale32): # Derived from computeMultiplierAndShiftTosaScale32 # Provide a floating-point scaling factor and the scale32 parameter # to compute the multiplier and shift if scale32: scaleBits = 31 else: scaleBits = 15 m, shift = math.frexp(scaleFp) if scaleFp < 0.0: m = -m multiplier = round(m * (1 << scaleBits)) assert multiplier <= (1 << scaleBits) if multiplier == (1 << scaleBits): multiplier = multiplier // 2 shift = shift + 1 shift = (-shift) + scaleBits # print('scalefp {} scaleBits {} m {} mult {} shift {}'.format( # scaleFp, scaleBits, m, multiplier, shift)) # Adjust multiplier such that shift is in allowed value range. if shift == 0: multiplier = multiplier // 4 shift = shift + 2 elif shift == 1: multiplier = multiplier // 2 shift = shift + 1 elif shift == 63: multiplier = multiplier * 2 shift = shift - 1 assert multiplier <= (1 << scaleBits) assert shift >= 2 and shift <= 62 return multiplier, shift class TosaTensorGen: """Tensor generators create a shape list for the placeholder and const tensor data operands for the operator. The actual random data is generated separately for each test. """ def __init__(self): pass @staticmethod def tgBasic(testGen, opName, rank, error_name=None): pl, const = opName["operands"] shape = testGen.makeShape(rank) # Constrict the overall size of the shape when creating ERROR_IF tests if error_name: shape = TosaErrorIfArgGen.eiRestrictDimensions(shape) shape_list = [] for i in range(pl + const): shape_list.append(shape.copy()) if error_name == ErrorIf.RankMismatch: if rank == 1 and i != 1: shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3])) elif i != 1: shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1])) return shape_list @staticmethod def tgNHWC(testGen, opName, rank, error_name=None): pl, const = opName["operands"] if error_name != ErrorIf.WrongRank: assert rank == 4 shape = testGen.makeShape(rank) # Constrict the batch size? if testGen.args.max_batch_size: shape[0] = (shape[0] % testGen.args.max_batch_size) + 1 # Constrict the overall size of the shape when creating ERROR_IF tests if error_name and error_name != ErrorIf.MaxDimExceeded: shape = TosaErrorIfArgGen.eiRestrictDimensions(shape) shape_list = [] for i in range(pl + const): shape_list.append(shape.copy()) return shape_list @staticmethod def tgScatter(testGen, opName, rank, error_name=None): pl, const = opName["operands"] assert pl == 2 assert const == 0 if error_name != ErrorIf.WrongRank: assert rank == 3 values_in_shape = testGen.makeShape(rank) # ignore max batch size if target shape is set if testGen.args.max_batch_size and not testGen.args.target_shapes: values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1 W = testGen.randInt( testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1] ) # Constrict W if one dimension is too large to keep tensor size reasonable if max(values_in_shape) > 5000: W = testGen.randInt(0, 16) input_shape = [values_in_shape[0], W, values_in_shape[2]] shape_list = [] shape_list.append(values_in_shape.copy()) shape_list.append(input_shape.copy()) return shape_list @staticmethod def tgBroadcastFuzz(testGen, op, rank, error_name=None): shape = testGen.makeShape(rank) pl, const = op["operands"] shape_list = [] # Choose one of the inputs to broadcast # Note: Simplifies OutputShaper code if we don't change first shape for errors bcast_idx = testGen.randInt(0 if error_name is None else 1, pl + const) for i in range(pl + const): shape_bcast = shape.copy() # If the chosen input, pick a random index to broadcast if i == bcast_idx: fuzz_idx = testGen.randInt(0, rank) if error_name == ErrorIf.DimensionMismatch: shape_bcast[fuzz_idx] += 1 elif error_name == ErrorIf.RankMismatch: # Add one rank to the shape (or more for rank of 1) extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1 shape_bcast = np.concatenate( (shape_bcast, testGen.makeShape(extra_ranks)) ) if rank != 1: # Either keep the extra rank, or remove it new_len = testGen.rng.choice([-2, len(shape_bcast)]) shape_bcast = shape_bcast[:new_len] else: shape_bcast[fuzz_idx] = 1 shape_list.append(shape_bcast) return shape_list @staticmethod def tgConv2D(testGen, op, rank, error_name=None): pl, const = op["operands"] if error_name != ErrorIf.WrongRank: assert rank == 4 # IFM dimensions are NHWC ifm_shape = testGen.makeShape(rank) # Constrict the batch size? if testGen.args.max_batch_size: ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1 # Constrict the overall size of the shape when creating ERROR_IF tests if error_name: ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions( ifm_shape, max_dim=24, max_items=10000 ) # Get the filter height/width from the operator parameters filter_hw = op["filter"] # Generate a random OFM depth ofm_depth = testGen.makeShape(1)[0] # The filter dimensions are OHWI filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]]) # The bias is OC bias_shape = np.asarray([ofm_depth]) return [ifm_shape, filter_shape, bias_shape] @staticmethod def tgConv3D(testGen, op, rank, error_name=None): pl, const = op["operands"] if error_name != ErrorIf.WrongRank: assert rank == 5 # IFM dimensions are NDHWC ifm_shape = testGen.makeShape(rank) # Constrict the batch size? if testGen.args.max_batch_size: ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1 # Constrict the overall size of the shape when creating ERROR_IF tests if error_name: ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions( ifm_shape, max_dim=24, max_items=10000 ) # Get the filter depth/height/width from the operator parameters filter_dhw = op["filter"] # Generate a random OFM channel ofm_channel = testGen.makeShape(1)[0] # The filter dimensions are ODHWI filter_shape = np.asarray( [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]] ) # The bias is OC bias_shape = np.asarray([ofm_channel]) return [ifm_shape, filter_shape, bias_shape] @staticmethod def tgTransposeConv2D(testGen, op, rank, error_name=None): pl, const = op["operands"] if error_name != ErrorIf.WrongRank: assert rank == 4 # IFM dimensions are NHWC ifm_shape = testGen.makeShape(rank) # Constrict the batch size? if testGen.args.max_batch_size: ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1 # Constrict the overall size of the shape when creating ERROR_IF tests if error_name: ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions( ifm_shape, max_dim=24, max_items=10000 ) # Get the filter height/width from the operator parameters filter_hw = op["filter"] # Generate a random OFM depth ofm_depth = testGen.makeShape(1)[0] # The filter dimensions are OHWI filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]]) # The bias is OC bias_shape = np.asarray([ofm_depth]) return [ifm_shape, filter_shape, bias_shape] @staticmethod def tgDepthwiseConv2D(testGen, op, rank, error_name=None): pl, const = op["operands"] if error_name != ErrorIf.WrongRank: assert rank == 4 assert pl == 1 and const == 2 # IFM dimensions are NHWC ifm_shape = testGen.makeShape(rank) # Constrict the batch size? if testGen.args.max_batch_size: ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1 # Constrict the overall size of the shape when creating ERROR_IF tests if error_name: ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions( ifm_shape, max_dim=24, max_items=10000 ) # Get the filter height/width from the operator parameters # Filter is KH, HW, C, M filter_hw = op["filter"] # Generate a random OFM depth, but don't let it get too big because # the output depth is M * C filter_m = ( testGen.makeShape(1)[0] % (testGen.args.tensor_shape_range[1] // 4) ) + 1 # The filter dimensions are HWCM filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m]) # The bias is M * C bias_shape = np.asarray([ifm_shape[3] * filter_m]) return [ifm_shape, filter_shape, bias_shape] @staticmethod def tgFullyConnected(testGen, op, rank, error_name=None): pl, const = op["operands"] if error_name != ErrorIf.WrongRank: assert rank == 2 input_shape = testGen.makeShape(rank) # Constrict the overall size of the shape when creating ERROR_IF tests if error_name: input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape) filter_oc = testGen.rng.integers( low=testGen.args.tensor_shape_range[0], high=testGen.args.tensor_shape_range[1], size=1, )[0] filter_shape = np.asarray([filter_oc, input_shape[1]]) bias_shape = np.asarray([filter_oc]) return [input_shape, filter_shape, bias_shape] @staticmethod def tgMatmul(testGen, op, rank, error_name=None): pl, const = op["operands"] if error_name != ErrorIf.WrongRank: assert rank == 3 assert pl == 2 and const == 0 a_shape = testGen.makeShape(rank) # Constrict the overall size of the shape when creating ERROR_IF tests if error_name: a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape) # Get a random number for b_oc even if target shape is defined b_oc = np.int32( testGen.rng.integers( low=testGen.args.tensor_shape_range[0], high=testGen.args.tensor_shape_range[1], size=1, ) )[0] # If N or H is large let b_oc be 1 to reduce output tensor size if max(a_shape) > 1000: b_oc = 1 b_shape = np.asarray([a_shape[0], a_shape[2], b_oc]) return [a_shape, b_shape] @staticmethod def tgConcat(testGen, opName, rank, error_name=None): pl, const = opName["operands"] shape = testGen.makeShape(rank) # Create extra tensors to concat. # Take into account value of pl when getting maximum number of concats num_tensors = testGen.randInt(0, 4) shape_list = [] for i in range(pl + const + num_tensors): if error_name == ErrorIf.ConcatInputRankMismatch and i != 0: remove = testGen.rng.choice([True, False]) wrongShape = shape.copy() if remove and len(shape) > 1: wrongShape = wrongShape[1:] else: wrongShape = list(wrongShape) wrongShape.append(testGen.rng.integers(1, 10)) shape_list.append(wrongShape) else: shape_list.append(shape.copy()) return shape_list @staticmethod def tgConcatConstInput(testGen, shapeList, axis, error_name=None): if error_name in [ ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank, ErrorIf.ConcatInputRankMismatch, ]: return shapeList # Split concat shape along axis to allow for multiple const inputs # without making too many large tensors if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList): # If axis can't be split we still need to invalidate other dimensions if error_name == ErrorIf.ConcatInputDimMismatch: for shape in shapeList[1:]: # Negative test shapeLists are created individually for each test, # so no need to copy the shape before altering it. shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10) return shapeList # Create copy of shape we are going to split (so we don't alter shapeList) shape = shapeList[0].copy() # Add original shape as first input new_shapeList = [shape.copy()] length_on_axis = shape[axis] remaining_length = length_on_axis for i in range(len(shapeList) - 2): # Calculate split on axis and remaining value split_shape_val = int(shape[axis] / 2) remaining_length = remaining_length - split_shape_val # Append new shape, and set remaining shape shape[axis] = split_shape_val new_shapeList.append(shape.copy()) # invalidate dimensions if error_name == ErrorIf.ConcatInputDimMismatch: shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10) else: shape[axis] = remaining_length if i == len(shapeList) - 3: new_shapeList.append(shape.copy()) return new_shapeList class TosaTensorValuesGen: """Tensor Value generators create the random data for each test.""" def __init__(self): pass @staticmethod def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, error_name=None): pCount, cCount = op["operands"] tens = [] tens.extend( testGen.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount]) ) tens.extend(testGen.buildConstTensors(shapeList[pCount:], dtypeList[pCount:])) return tens @staticmethod def tvgNegate(testGen, op, dtypeList, shapeList, testArgs, error_name=None): if dtypeList[0] == DType.INT32 and error_name is None: pCount, cCount = op["operands"] assert ( pCount == 1 and cCount == 0 ), "Op.NEGATE must have 1 placeholders, 0 consts" # Must create tensors with values within accumulator (int32) negatable # range max_val = (1 << 31) - 1 min_val = -max_val arr = np.int32( testGen.rng.integers(low=min_val, high=(max_val + 1), size=shapeList[0]) ) placeholders = [] placeholders.append( testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], arr) ) return placeholders else: return TosaTensorValuesGen.tvgDefault( testGen, op, dtypeList, shapeList, testArgs, error_name ) @staticmethod def tvgAddSub(testGen, op, dtypeList, shapeList, testArgs, error_name=None): if dtypeList[0] == DType.INT32 and error_name is None: # Make sure the operation does not cause value saturation - where # the number wraps due to limited number of bits to store the answer pCount, cCount = op["operands"] assert ( pCount == 2 and cCount == 0 ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts" placeholders = [] add = op["op"] == Op.ADD a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0]) b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1]) if add: res_arr = np.add(a_arr, b_arr, dtype=np.int64) else: res_arr = np.subtract(a_arr, b_arr, dtype=np.int64) # Work out the saturation limits max_i32 = (1 << 31) - 1 min_i32 = -(1 << 31) max_arr = np.full(shapeList[1], max_i32) min_arr = np.full(shapeList[1], min_i32) # Find how much values exceed the maximum/minimums sat_max_arr = np.maximum(res_arr - max_arr, 0) sat_min_arr = np.minimum(res_arr - min_arr, 0) if not add: # Swap saturation values and negate values as we need to perform opposite operations sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr # Create new array of unsaturated values by clipping values as needed b_unsat_arr = b_arr if (sat_max_arr != 0).any(): # Clip values that cause saturation b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32) # Reduce axes in unsaturated tensor to match original tensor for axis, dim in enumerate(b_arr.shape): if dim != b_unsat_arr.shape[axis]: assert ( dim == 1 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable" b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True) if (sat_min_arr != 0).any(): # Clip values that cause saturation b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32) # Reduce axes in unsaturated tensor to match original tensor for axis, dim in enumerate(b_arr.shape): if dim != b_unsat_arr.shape[axis]: assert ( dim == 1 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable" b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True) placeholders.append( testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr) ) placeholders.append( testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr) ) return placeholders else: return TosaTensorValuesGen.tvgDefault( testGen, op, dtypeList, shapeList, testArgs, error_name ) @staticmethod def tvgCondIfWhileLoop( testGen, op, dtypeList, shapeList, testArgs, error_name=None ): if dtypeList[0] in ( DType.INT32, DType.INT16, DType.INT8, ): # Limit input tensors with cond_if_binary or while_loop to stop # saturation of add/sub ops with int32 and keep all logical shift # values between 0 to 31 for int16 or int8 pCount, cCount = op["operands"] pRemain = pCount placeholders = [] for idx, shape in enumerate(shapeList[:]): if dtypeList[0] == DType.INT32: arr = testGen.getRandTensor(shapeList[idx], DType.INT16) else: arr = np.int32( testGen.rng.integers(low=0, high=32, size=shapeList[idx]) ) if pRemain > 0: placeholders.append( testGen.ser.addPlaceholder(shape, dtypeList[idx], arr) ) pRemain -= 1 else: placeholders.append( testGen.ser.addConst(shape, dtypeList[idx], arr) ) return placeholders else: return TosaTensorValuesGen.tvgDefault( testGen, op, dtypeList, shapeList, testArgs, error_name ) @staticmethod def tvgArithmeticRightShift( testGen, op, dtypeList, shapeList, testArgs, error_name=None ): pCount, cCount = op["operands"] # Force value of operand[1] to be within [0, num_bits] assert ( pCount == 2 and cCount == 0 ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts" placeholders = [] for idx, shape in enumerate(shapeList[:]): if idx == 1: if dtypeList[idx] == DType.INT8: arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape)) elif dtypeList[idx] == DType.INT16: arr = np.int32(testGen.rng.integers(low=0, high=16, size=shape)) elif dtypeList[idx] == DType.INT32: arr = np.int32(testGen.rng.integers(low=0, high=32, size=shape)) elif error_name == ErrorIf.WrongInputType: arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape)) else: raise Exception("OpArithmeticRightShift: invalid input dtype") else: arr = testGen.getRandTensor(shape, dtypeList[idx]) placeholders.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)) return placeholders @staticmethod def tvgSelect(testGen, op, dtypeList, shapeList, testArgs, error_name=None): # Set datatype of condition tensor to boolean dtypeList[0] = DType.BOOL return TosaTensorValuesGen.tvgDefault( testGen, op, dtypeList, shapeList, testArgs, error_name ) @staticmethod def tvgIntDiv(testGen, op, dtypeList, shapeList, testArgs, error_name=None): if error_name is None: pCount, cCount = op["operands"] assert ( pCount == 2 and cCount == 0 ), "Op.INTDIV must have 2 placeholders, 0 consts" placeholders = [] # Two invalid cases for Op.INTDIV: # 1. divisor == 0 # 2. dividend == -(1<<31) and divisor == -1 while True: dividend_arr = testGen.getRandTensor(shapeList[0], dtypeList[0]) divisor_arr = testGen.getRandTensor(shapeList[1], dtypeList[1]) if (divisor_arr == 0).any(): continue if (dividend_arr == -(2**31)).any() and (divisor_arr == -1).any(): continue break placeholders.append( testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr) ) placeholders.append( testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr) ) return placeholders else: return TosaTensorValuesGen.tvgDefault( testGen, op, dtypeList, shapeList, testArgs, error_name ) @staticmethod def tvgMul(testGen, op, dtypeList, shapeList, testArgs, error_name=None): if error_name is None: pCount, cCount = op["operands"] assert ( pCount == 2 and cCount == 0 ), "Op.MUL must have 2 placeholders, 0 consts" tens = [] if dtypeList[0] in (DType.FP16, DType.FLOAT): tens.extend(testGen.buildPlaceholderTensors(shapeList[:], dtypeList[:])) else: placeholders = [] # Make sure multiply result in int32 range shift = testArgs[0] if dtypeList[0] == DType.INT8: num_bits = 8 elif dtypeList[0] == DType.INT16: num_bits = 16 elif dtypeList[0] == DType.INT32: num_bits = 32 elif error_name == ErrorIf.WrongInputType: num_bits = 8 else: raise Exception("OpMul: invalid input dtype") for idx, shape in enumerate(shapeList[:]): low = -(2 ** (num_bits - 1)) high = (2 ** (num_bits - 1)) - 1 a_arr = np.int32( testGen.rng.integers(low=low, high=high, size=shapeList[0]) ) b_arr = np.int32( testGen.rng.integers(low=low, high=high, size=shapeList[1]) ) i = 0 while True: a_arr_64 = a_arr.astype(np.int64) b_arr_64 = b_arr.astype(np.int64) if shift > 0: rounding = 1 << (shift - 1) result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift else: result_arr = a_arr_64 * b_arr_64 if (result_arr > -(2**31)).all() and ( result_arr <= ((2**31) - 1) ).all(): break i = i + 1 a_arr = a_arr // 2 b_arr = b_arr // 2 placeholders.append( testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr) ) placeholders.append( testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr) ) tens.extend(placeholders) return tens else: return TosaTensorValuesGen.tvgDefault( testGen, op, dtypeList, shapeList, testArgs, error_name ) @staticmethod def tvgConcat(testGen, op, dtypeList, shapeList, testArgs, error_name=None): count = len(shapeList) - testGen.args.num_const_inputs_concat if count < 1: count = 1 if testGen.args.num_const_inputs_concat == 0: count = len(shapeList) # Ensure axis is an int testArgs[0] = int(testArgs[0]) shapeList = TosaTensorGen.tgConcatConstInput( testGen, shapeList, testArgs[0], error_name ) tens = [] tens.extend( testGen.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count]) ) tens.extend(testGen.buildConstTensors(shapeList[count:], dtypeList[count:])) return tens @staticmethod def tvgLogicalShift(testGen, op, dtypeList, shapeList, testArgs, error_name=None): pCount, cCount = op["operands"] assert ( pCount == 2 and cCount == 0 ), "Op.LOGICAL_LEFT_SHIFT or Op.LOGICAL_RIGHT_SHIFT must have 2 placeholders, 0 consts" values_arr = testGen.getRandTensor(shapeList[0], dtypeList[0]) shift_arr = np.int32(testGen.rng.integers(low=0, high=32, size=shapeList[1])) placeholders = [] placeholders.append( testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr) ) placeholders.append( testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr) ) return placeholders @staticmethod def tvgEqual(testGen, op, dtypeList, shapeList, testArgs, error_name=None): if error_name is None: pCount, cCount = op["operands"] assert ( pCount == 2 and cCount == 0 ), "Op.EQUAL must have 2 placeholders, 0 consts" a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0]) b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1]) # Using random numbers means that it will be very unlikely that # there are any matching (equal) values, therefore force that # there are twice the number of matching values as the tensor rank for num in range(0, len(shapeList[0]) * 2): a_index = [] b_index = [] # Choose an index in each axis for the whole shape for axis in range(0, len(shapeList[0])): # Index can be up to the largest dimension in both shapes index = np.int32( testGen.rng.integers( 0, max(shapeList[0][axis], shapeList[1][axis]) ) ) # Reduce the index down to a shape's dim for broadcasting a_index.append(min(shapeList[0][axis] - 1, index)) b_index.append(min(shapeList[1][axis] - 1, index)) a_arr[tuple(a_index)] = b_arr[tuple(b_index)] placeholders = [] placeholders.append( testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr) ) placeholders.append( testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr) ) return placeholders else: return TosaTensorValuesGen.tvgDefault( testGen, op, dtypeList, shapeList, testArgs, error_name ) @staticmethod def tvgReduceSum(testGen, op, dtypeList, shapeList, testArgs, error_name=None): if dtypeList[0] == DType.INT32: pCount, cCount = op["operands"] assert ( pCount == 1 and cCount == 0 ), "Op.REDUCE_SUM must have 1 placeholders, 0 consts" # Limit values so that the sum cannot exceed the range of an int32 during # summation of any axis range_val = int((1 << 31) / max(shapeList[0])) values_arr = np.int32( testGen.rng.integers(low=-range_val, high=range_val, size=shapeList[0]) ) placeholders = [] placeholders.append( testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr) ) return placeholders else: return TosaTensorValuesGen.tvgDefault( testGen, op, dtypeList, shapeList, testArgs, error_name ) class TosaArgGen: """Argument generators create exhaustive or random lists of attributes for operators that take attributes or other parameters. The return value is a list of (descriptive_name, [arglist]) tuples where the descriptive_name is appended to the test name and the arglist is expanded as arguments to the operator build function. """ def __init__(self): pass @staticmethod def agNone(testGen, opName, shapeList, dtype, error_name=None): """A trivial argument generator for operators that don't take any non-tensor arguments""" return [("", [])] @staticmethod def agAxis(testGen, opName, shapeList, dtype, error_name=None): """Build the axis argument for operators that take a single axis""" axes = [] shape = shapeList[0] if error_name == ErrorIf.AxisSmallerZero: small_axis = testGen.rng.integers(-5, 0) axes.append(("axis{}".format(small_axis), [small_axis])) elif error_name == ErrorIf.AxisLargerRank: large_axis = testGen.rng.integers(len(shape) + 1, len(shape) + 10) axes.append(("axis{}".format(large_axis), [large_axis])) else: for a in range(0, len(shape)): axes.append(("axis{}".format(a), [a])) return axes @staticmethod def agConv(testGen, opName, shapeList, dtypes, error_name=None): arg_list = [] ifm_shape = shapeList[0] filter_shape = shapeList[1] # determine the kernel shape from operator name (e.g. "conv2d_3x3" => [3,3]) k = [int(x) for x in opName.split("_")[-1].split("x")] accum_dtype = get_accum_dtype_from_tgTypes(dtypes) # Check the rank rank = 5 if opName.startswith("conv3d") else 4 if error_name != ErrorIf.WrongRank: assert len(ifm_shape) == rank assert len(filter_shape) == rank # kernel rank omits batch and channels k_rank = rank - 2 assert len(k) == k_rank # Generate comprehensive argument lists # - except for named errors, which use specific invalid value(s) if error_name == ErrorIf.PadSmallerZero: p_vals = [testGen.rng.choice(range(-5, 0))] else: p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)] paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))} if error_name == ErrorIf.StrideSmallerOne: # Can't use stride=0, as it is used to derive output shape, as a divisor s_vals = [testGen.rng.choice(range(-5, 0))] else: # Stride must be greater than 1 to force non-integer error startStride = 1 if error_name != ErrorIf.ConvOutputShapeNonInteger else 2 s_vals = [x for x in range(startStride, testGen.args.max_conv_stride + 1)] strides = {x for x in itertools.product(*([s_vals] * k_rank))} if error_name == ErrorIf.DilationSmallerOne: d_vals = [testGen.rng.choice(range(-5, 1))] else: d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)] dilations = {x for x in itertools.product(*([d_vals] * k_rank))} if not error_name and testGen.args.oversize: # add some oversize argument values if max(ifm_shape) < 64: bigPadding = 9 paddings.update( {x for x in itertools.product(*([[0, bigPadding]] * (k_rank * 2)))} ) bigStride = 8 strides.update({x for x in itertools.product(*([[1, bigStride]] * k_rank))}) bigDilation = 7 dilations.update( {x for x in itertools.product(*([[1, bigDilation]] * k_rank))} ) # There are too many parameter combinations, so generate them sparsely, # very sparse for negative tests sparsity_factor = 2 if error_name else 120 sparsity = len(paddings) * len(strides) * len(dilations) // sparsity_factor + 1 # If there are only a small number of tests, just select them all if sparsity < 13: sparsity = 1 # To get a variety of parameter combinations sparsity should not be a # multiple of 2, 3 or 5 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0: sparsity += 1 n = 0 for s in sorted(list(strides)): for p in sorted(list(paddings)): for d in sorted(list(dilations)): if ( n % sparsity == 0 # the padded shape must exceed the dilation * kernel to get a positive # sized output shape and (ifm_shape[1] - 1 + p[0] + p[1]) > d[0] * (k[0] - 1) and (ifm_shape[2] - 1 + p[2] + p[3]) > d[1] * (k[1] - 1) and ( k_rank < 3 or ((ifm_shape[3] - 1 + p[4] + p[5]) > d[2] * (k[2] - 1)) ) ): remainders = [] for index in range(k_rank): pad_offset = index * 2 remainders.append( ( ifm_shape[index + 1] - 1 + p[pad_offset] + p[pad_offset + 1] - (k[index] - 1) * d[index] ) % s[index] ) if ( # the parameters must produce integer exact output error_name != ErrorIf.ConvOutputShapeNonInteger and max(remainders) == 0 ) or ( error_name == ErrorIf.ConvOutputShapeNonInteger and max(remainders) > 0 ): arg_list.append( ( "acc{}_st{}_pad{}_dilat{}".format( testGen.typeStr(accum_dtype), "".join([str(x) for x in s]), "".join([str(x) for x in p]), "".join([str(x) for x in d]), ), [accum_dtype, s, p, d], ) ) n += 1 return arg_list @staticmethod def agFullyConnected(testGen, opName, shapeList, dtypes, error_name=None): if isinstance(dtypes, list) or isinstance(dtypes, tuple): input_dtype = dtypes[0] else: input_dtype = dtypes if error_name == ErrorIf.WrongOutputType: accum_dtype = get_wrong_output_type(opName, testGen.rng, input_dtype) elif error_name == ErrorIf.WrongInputType: # Pick some potentially correct output dtype if input type is incorrect accum_dtype = DType.INT32 else: accum_dtype = get_accum_dtype_from_tgTypes(dtypes) return [(f"acc{testGen.typeStr(accum_dtype)}", [accum_dtype])] @staticmethod def agMatMul(testGen, opName, shapeList, dtype, error_name=None): # Get valid accumulate type(s) if dtype == DType.INT8: accum_dtypes = [DType.INT32] elif dtype == DType.INT16: accum_dtypes = [DType.INT48] elif dtype == DType.FP16: accum_dtypes = [DType.FP16, DType.FLOAT] elif dtype == DType.FLOAT: accum_dtypes = [DType.FLOAT] elif error_name is None: assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}" if error_name == ErrorIf.WrongOutputType: # Get incorrect output dtype for ErrorIf case accum_dtypes = [get_wrong_output_type(opName, testGen.rng, dtype)] elif error_name == ErrorIf.WrongInputType: # Pick some potentially correct output dtype if input type is incorrect accum_dtypes = [DType.INT32] return [(f"acc{testGen.typeStr(a)}", [a]) for a in accum_dtypes] @staticmethod def agTransposeConv2D(testGen, opName, shapeList, dtypes, error_name=None): arg_list = [] ifm_shape = shapeList[0] filter_shape = shapeList[1] accum_dtype = get_accum_dtype_from_tgTypes(dtypes) # Must be rank 4 if error_name != ErrorIf.WrongRank: assert len(ifm_shape) == 4 assert len(filter_shape) == 4 # Generate comprehensive argument lists # - except for named errors, which use specific invalid value(s) smallest_padding_size = -min(filter_shape[1], filter_shape[2]) + 1 if error_name == ErrorIf.PadLargerEqualKernel: max_filter_size = -max(filter_shape[1], filter_shape[2]) p_vals = [testGen.rng.choice(range(max_filter_size - 10, max_filter_size))] else: p_vals = [ x for x in range(smallest_padding_size, testGen.args.max_conv_padding + 1) ] paddings = {x for x in itertools.product(*([p_vals] * 4))} if error_name == ErrorIf.StrideSmallerOne: # Can't use stride=0, as it is used to derive output shape, as a divisor s_vals = [testGen.rng.choice(range(-5, 0))] else: s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)] strides = {x for x in itertools.product(*([s_vals] * 2))} if not error_name and testGen.args.oversize: # add some oversize argument values if max(ifm_shape) < 64: bigPadding = 9 paddings.update( { x for x in itertools.product( *([[smallest_padding_size, bigPadding]] * 4) ) } ) bigStride = 8 strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))}) # There are too many parameter combinations, so generate them sparsely, # very sparse for negative tests sparsity_factor = 2 if error_name else 10 sparsity = len(paddings) * len(strides) // sparsity_factor + 1 # If there are only a small number of tests, just select them all if sparsity < 13: sparsity = 1 # To get a variety of parameter combinations sparsity should not be a # multiple of 2, 3 or 5 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0: sparsity += 1 n = 0 for s in sorted(list(strides)): for p in sorted(list(paddings)): if n % sparsity == 0: # Determine the output shape oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + filter_shape[1] ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + filter_shape[2] os = [ifm_shape[0], oh, ow, filter_shape[0]] arg_list.append( ( "acc{}_st{}_pad{}_os{}".format( testGen.typeStr(accum_dtype), "".join([str(x) for x in s]), "".join([str(x) for x in p]), "x".join([str(x) for x in os]), ), [accum_dtype, s, p, os], ) ) n += 1 return arg_list @staticmethod def agPad(testGen, opName, shapeList, dtype, error_name=None): arg_list = [] rank = len(shapeList[0]) # Exhaustively test combinations of padding on each side of each dimension # - the range of padding values is defined by pad_min and pad_max # - for padding >9, the name format needs to be more distinctive pad_min, pad_max = 0, 1 pad_values = [x for x in range(pad_min, pad_max + 1)] if error_name == ErrorIf.PadSmallerZero: pad_values = [x for x in range(-2, 0)] axis_pad_values = [x for x in itertools.product(pad_values, pad_values)] shape_pad_values = itertools.product(*([axis_pad_values] * rank)) if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]: pad_const_int = testGen.getRandNumberDType(dtype) pad_const_fp = 0 elif dtype in (DType.FP16, DType.FLOAT): pad_const_int = 0 pad_const_fp = testGen.getRandNumberDType(dtype) else: return [] for paddings in shape_pad_values: paddings = list(paddings) args_valid = True if error_name == ErrorIf.PadSmallerZero: # Prevent negative output shapes while ensuring still testing for negative padding for i in range(rank): dim_after_padding = ( paddings[i][0] + paddings[i][1] + shapeList[0][i] ) if dim_after_padding < 1: paddings[i] = (0, 0) if all([p > -1 for p in paddings[i]]): args_valid = False if args_valid: name = "pad" for r in range(rank): before, after = paddings[r] name = f"{name}{before}{after}" arg_list.append( (name, [np.array(paddings), pad_const_int, pad_const_fp]) ) if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0: warnings.warn(f"No ErrorIf test created for input shape: {shapeList[0]}") return arg_list @staticmethod def agPooling(testGen, opName, shapeList, dtype, error_name=None): arg_list = [] shape = shapeList[0] if error_name != ErrorIf.WrongRank: assert len(shape) == 4 # Generate comprehensive argument lists p_vals = [x for x in range(0, testGen.args.max_pooling_padding + 1)] paddings = {x for x in itertools.product(*([p_vals] * 4))} # Stride must be greater than 1 to force non-integer error startStride = 1 if error_name != ErrorIf.PoolingOutputShapeNonInteger else 2 s_vals = [x for x in range(startStride, testGen.args.max_pooling_stride + 1)] strides = {x for x in itertools.product(*([s_vals] * 2))} k_vals = [x for x in range(2, testGen.args.max_pooling_kernel + 1)] kernels = {x for x in itertools.product(*([k_vals] * 2))} if opName == "max_pool2d": accum_dtypes = [None] # max_pool has no accumulate dtype elif dtype == DType.INT8 or dtype == DType.INT16: accum_dtypes = [DType.INT32] elif dtype == DType.FP16: accum_dtypes = [DType.FP16, DType.FLOAT] elif dtype == DType.FLOAT: accum_dtypes = [DType.FLOAT] elif error_name is None: assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}" else: # Set to something for the ErrorIf case which has # incorrect input data-type accum_dtypes = [DType.INT32] if testGen.args.oversize: # add some oversize argument values bigStride = 7 strides.update( {x for x in itertools.product(*([[startStride, bigStride]] * 2))} ) bigKernel = 9 kernels.update({x for x in itertools.product(*([[2, bigKernel]] * 2))}) if max(shape) < 64: # padding must be less than the kernel size bigPadding = bigKernel - 1 paddings.update( {x for x in itertools.product(*([[0, bigPadding]] * 4))} ) # There are too many parameter combinations, so generate them sparsely, # very sparse for negative tests sparsity_factor = 2 if error_name else 500 sparsity = len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1 arg_str = ( "acc{}_st{}_kern{}_pad{}" if accum_dtypes[0] is not None else "st{}_kern{}_pad{}" ) def get_arg_list_element(accum, stride, pad, kern): # Return tuple containing the formatted argument string and # the corresponding argument values arg_str_elems = [ "".join([str(x) for x in stride]), "".join([str(x) for x in kern]), "".join([str(x) for x in pad]), ] # Note: different order to string arg_val_elems = [stride, pad, kern] if accum is not None: arg_str_elems.insert(0, testGen.typeStr(accum)) arg_val_elems.insert(0, accum) return (arg_str.format(*arg_str_elems), arg_val_elems) n = 0 for a in accum_dtypes: for s in sorted(list(strides)): for p in sorted(list(paddings)): for k in sorted(list(kernels)): if error_name in [ ErrorIf.StrideSmallerOne, ErrorIf.KernelSmallerOne, ErrorIf.PadSmallerZero, ErrorIf.PadLargerEqualKernel, ]: sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf( testGen, error_name, s, p, k ) if None not in [sNew, pNew, kNew] and n % sparsity == 0: arg_vals = [a, sNew, pNew, kNew] arg_list.append(get_arg_list_element(*arg_vals)) elif ( n % sparsity == 0 # padding must not exceed the kernel size and p[0] < k[0] and p[1] < k[0] and p[2] < k[1] and p[3] < k[1] # the padded shape must exceed the kernel size and (shape[1] + p[0] + p[1]) > k[0] and (shape[2] + p[2] + p[3]) > k[1] ): remainder_h = (shape[1] + p[0] + p[1] - k[0]) % s[0] remainder_w = (shape[2] + p[2] + p[3] - k[1]) % s[1] if ( # the parameters must produce integer exact output error_name != ErrorIf.PoolingOutputShapeNonInteger and remainder_h == 0 and remainder_w == 0 ) or ( error_name == ErrorIf.PoolingOutputShapeNonInteger and (remainder_h != 0 or remainder_w != 0) ): arg_vals = [a, s, p, k] arg_list.append(get_arg_list_element(*arg_vals)) n += 1 return arg_list @staticmethod def agCast(testGen, opName, shapeList, inDtype, error_name=None): arg_list = [] # Enumerate the output types here if error_name == ErrorIf.WrongOutputType: dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype) elif inDtype == DType.INT8: dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT] elif inDtype == DType.INT16: dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT] elif inDtype == DType.INT32: dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT] elif inDtype == DType.BOOL: dtypeList = [DType.INT8, DType.INT16, DType.INT32] elif inDtype == DType.FP16: dtypeList = [DType.INT8, DType.INT16, DType.INT32] elif inDtype == DType.FLOAT: dtypeList = [DType.INT8, DType.INT16, DType.INT32] elif error_name == ErrorIf.WrongInputType: # Pick some potentially correct output type for incorrect input type dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT] else: raise Exception("Unexpected input dtype: {}".format(inDtype)) for dtype in dtypeList: arg_list.append(("out{}".format(DTypeNames[dtype]), [dtype])) return arg_list @staticmethod def agRescale(testGen, opName, shapeList, inDtype, error_name=None): arg_list = [] # Enumerate the output types here for outDtype in [ DType.UINT8, DType.INT8, DType.INT16, DType.INT32, DType.UINT16, ]: if ( outDtype in [DType.UINT8, DType.INT8, DType.UINT16] and error_name == ErrorIf.OutputZeroPointNotZero ): continue if ( outDtype != DType.UINT16 and error_name == ErrorIf.U16OutputZeroPointNotValid ) or ( inDtype != DType.UINT16 and error_name == ErrorIf.U16InputZeroPointNotValid ): # ErrorIfs only valid with UINT16 continue if ( inDtype == DType.UINT8 and outDtype not in [DType.INT8, DType.INT16] and error_name != ErrorIf.WrongOutputType ): # The only output dtypes for UINT8 are INT8/INT16, skip all others continue if ( inDtype not in [DType.INT8, DType.INT16] and outDtype == DType.UINT8 and error_name != ErrorIf.WrongOutputType ): # The only input dtypes for UINT8 are INT8/INT16, skip all others continue if ( inDtype == DType.UINT16 and outDtype != DType.INT16 and error_name != ErrorIf.WrongOutputType ): # The only output dtype for UINT16 is INT16, skip all others continue if ( inDtype != DType.INT16 and outDtype == DType.UINT16 and error_name != ErrorIf.WrongOutputType ): # The only input dtype for UINT16 is INT16, skip all others continue if ( error_name == ErrorIf.WrongOutputType and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, outDtype) ): continue for scale32 in [False, True]: if error_name == ErrorIf.ScaleTrue and not scale32: continue elif error_name == ErrorIf.ScaleNotTrue and scale32: continue for double_round in [False, True]: if error_name == ErrorIf.ScaleNotTrue and not double_round: continue for per_channel in [False, True]: if ( inDtype == DType.INT48 and scale32 and error_name != ErrorIf.ScaleTrue ): # Illegal condition. Must be scale32=False continue if ( double_round and not scale32 and error_name != ErrorIf.ScaleNotTrue ): # Illegal condition. ERROR_IF(!scale32 && double_round) continue arg_list.append( ( "out{}_sc{}_dr{}_pc{}".format( DTypeNames[outDtype], int(scale32), int(double_round), int(per_channel), ), [outDtype, scale32, double_round, per_channel], ) ) return arg_list @staticmethod def agMul(testGen, opName, shapeList, dtype, error_name=None): arg_list = [] if dtype is DType.INT32: for p in range(testGen.args.num_rand_permutations): shift = testGen.randInt(0, 32) arg_list.append(("perm{}_shift{}".format(p, shift), [shift])) else: arg_list.append(("perm0_shift0", [0])) return arg_list @staticmethod def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None): arg_list = [] arg_list.append(("roundTrue", [True])) arg_list.append(("roundFalse", [False])) return arg_list # Helper function for reshape. Gets some factors of a larger number. @staticmethod def getFactors(val, start=1): factors = [] for i in range(start, int(np.sqrt(val)) + 1): if (val % i) == 0: factors.append(i) return factors @staticmethod def agReshape(testGen, opName, shapeList, dtype, error_name=None): arg_list = [] origShape = shapeList[0] totalElements = 1 for s in origShape: totalElements *= s # This code is NOT fast. Fortunately, the numbers are fairly small. factors = TosaArgGen.getFactors(totalElements) for p in range(testGen.args.num_rand_permutations): newRank = testGen.randInt(1, 7) if len(factors) < newRank: continue found = True # escape_counter breaks while loop if it continues on for too long escape_counter = 0 while found: newShape = [] # Generate newShape ensuring it isn't a duplicate remainingElements = totalElements shuffledFactors = testGen.rng.permutation(factors) for i in range(1, newRank): # pick rank-1 factors newShape.append(shuffledFactors[0]) remainingElements = remainingElements // shuffledFactors[0] shuffledFactors = testGen.rng.permutation( TosaArgGen.getFactors(remainingElements) ) newShape.append(remainingElements) # Check for duplicates found = False for name, other_shape in arg_list: if other_shape[0] == newShape: found = True break escape_counter += 1 if escape_counter >= 100: break if not found: arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape])) return arg_list @staticmethod def agTranspose(testGen, opName, shapeList, dtype, error_name=None): arg_list = [] ifm_shape = shapeList[0] if error_name == ErrorIf.IndexOutsideBounds: incorrect_large_index = range(len(ifm_shape) + 1, 2 * len(ifm_shape) + 1) incorrect_small_index = range(-len(ifm_shape), 0) permutations = [p for p in itertools.permutations(incorrect_large_index)] permutations.extend( [p for p in itertools.permutations(incorrect_small_index)] ) elif error_name == ErrorIf.IndexUsedTwice: # Create list with a duplicated index perm_range = list(range(len(ifm_shape))) index_choice = testGen.rng.choice(range(len(perm_range))) perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice] permutations = [p for p in itertools.permutations(perm_range)] else: # Get all permutations permutations = [p for p in itertools.permutations(range(len(ifm_shape)))] # Limit to possible permutations from shape dimension or argument setting limit = min(len(permutations), testGen.args.num_rand_permutations) # Get random permutation generator that uses all permutations random_permutations = testGen.rng.permutation(permutations) # Create list of required amount of permutations arg_list = [ ("perm{}".format(p), [random_permutations[p].tolist()]) for p in range(limit) ] return arg_list @staticmethod def agSlice(testGen, opName, shapeList, dtype, error_name=None): arg_list = [] ifm_shape = shapeList[0] rank = len(ifm_shape) for p in range(testGen.args.num_rand_permutations): start = [] size = [] valid = True for i in range(rank): if ifm_shape[i] > 1: start.append(testGen.randInt(0, ifm_shape[i])) size.append(testGen.randInt(0, ifm_shape[i] - start[i])) # Invalid slice size? if size[i] == 0: valid = False else: start.append(0) size.append(1) if valid: # If ERROR_IF test required then incorrect start, size will be returned start, size = TosaErrorIfArgGen.eiSliceErrorIf( testGen, error_name, ifm_shape, start, size ) arg_list.append(("perm{}".format(p), [start, size])) return arg_list @staticmethod def agTile(testGen, opName, shapeList, dtype, error_name=None): arg_list = [] ifm_shape = shapeList[0] rank = len(ifm_shape) for p in range(testGen.args.num_rand_permutations): # Pick a few random, but small multiple values # because otherwise this has a tendency to generate # enormous tensors multiples = [] for i in range(rank): if ifm_shape[i] > 1000: # Multiple of 1 if ifm_shape dimension is large to reduce # tensor size multiples.append(1) elif max(ifm_shape) > 1000: multiples.append(2) else: multiples.append(testGen.randInt(1, 4)) arg_list.append(("perm{}".format(p), [multiples])) return arg_list @staticmethod def agResize(testGen, opName, shapeList, dtype, error_name=None): arg_list = [] ifm_shape = shapeList[0] def get_aspect_ratio_resize_params(): common_aspect_ratios = ((3, 2), (16, 9), (4, 3)) aspect_ratio = testGen.rng.choice(common_aspect_ratios) invert = testGen.rng.choice((False, True)) letterbox = testGen.rng.choice((False, True)) scale_y_n = aspect_ratio[0] if invert else aspect_ratio[1] scale_x_n = aspect_ratio[1] if invert else aspect_ratio[0] scale_y_d = scale_x_d = 1 offset_x = offset_y = 0 if letterbox: max_border = scale_y_n border_y = testGen.randInt(low=0, high=max_border) border_x = 0 else: # Pillarboxing border_y = 0 max_border = scale_x_n border_x = testGen.randInt(low=0, high=max_border) scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d) offset = (offset_y, offset_x) border = (border_y, border_x) return scale, offset, border def get_upscale_downscale_params(): valid_params = False while not valid_params: upscale = testGen.rng.choice((False, True)) # True if sampling begins from (0,0). Otherwise (-0.5,-0.5) origin_sampling = testGen.rng.choice((False, True)) if upscale: shift = testGen.randInt(low=1, high=4) scale_x_d = scale_y_d = 1 scale_x_n = scale_y_n = ( 1 << shift if origin_sampling else 2 << shift ) border_x = border_y = 0 if origin_sampling else (1 << shift) - 1 offset_x = offset_y = 0 if origin_sampling else -(1 << shift) + 1 else: scale_x_n = 1 scale_y_n = 1 # Return list of valid scale_*_d values (max value 4) given input dim shape def get_valid_denom(ifm_dim): return [x for x in range(1, 5) if ifm_dim % x == 1] # Generate list of valid downscale values and choose one randomly valid_scale_y_ds = get_valid_denom(ifm_shape[1]) valid_scale_x_ds = get_valid_denom(ifm_shape[2]) if not valid_scale_y_ds and not valid_scale_x_ds: # Bad parameters, skip continue if not valid_scale_y_ds: scale_y_d = 1 else: scale_y_d = testGen.rng.choice(valid_scale_y_ds) if not valid_scale_x_ds: scale_x_d = 1 else: scale_x_d = testGen.rng.choice(valid_scale_x_ds) border_x = border_y = 0 offset_y = testGen.randInt(0, 16 * scale_y_n) offset_x = testGen.randInt(0, 16 * scale_x_n) valid_params = True scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d) offset = (offset_y, offset_x) border = (border_y, border_x) return scale, offset, border def get_rand_params(): # Scale scale_y_n = testGen.randInt(low=1, high=(1 << 11)) scale_x_n = testGen.randInt(low=1, high=(1 << 11)) scale_y_d = testGen.randInt(low=1, high=(16 * scale_y_n)) scale_x_d = testGen.randInt(low=1, high=(16 * scale_x_n)) # Offsets and border within the scale offset_y = testGen.randInt(low=-scale_y_n, high=(16 * scale_y_n)) offset_x = testGen.randInt(low=-scale_x_n, high=(16 * scale_x_n)) border_y = testGen.randInt(low=(-16 * scale_y_n), high=scale_y_n) border_x = testGen.randInt(low=(-16 * scale_x_n), high=scale_x_n) scale = (scale_y_n, scale_y_d, scale_x_n, scale_x_d) offset = (offset_y, offset_x) border = (border_y, border_x) return scale, offset, border for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]: # Exclude illegal {mode, type} configurations. Pick legal output types if mode == ResizeMode.NEAREST and dtype == DType.INT8: outputDTypeList = [DType.INT8] elif mode == ResizeMode.NEAREST and dtype == DType.INT16: outputDTypeList = [DType.INT16] elif mode == ResizeMode.BILINEAR and dtype == DType.INT8: outputDTypeList = [DType.INT32] elif mode == ResizeMode.BILINEAR and dtype == DType.INT16: outputDTypeList = [DType.INT48] elif dtype == DType.FP16: outputDTypeList = [DType.FP16] elif dtype == DType.FLOAT: outputDTypeList = [DType.FLOAT] elif error_name == ErrorIf.WrongInputType: # If an incorrect input type is used then we set a 'correct' # output type to avoid other errors outputDTypeList = [DType.INT8, DType.INT16, DType.INT32] else: continue arg_str = "mode{}_out{}_sc{}x{}x{}x{}_off{}x{}_bor{}x{}" for outputDType in outputDTypeList: perm = 0 while perm < testGen.args.num_rand_permutations: # Random choice of type of params we are testing _rnd_param_fn = testGen.rng.choice( ( get_rand_params, get_upscale_downscale_params, get_aspect_ratio_resize_params, ) ) scale, offset, border = _rnd_param_fn() # Expand params for bounds-checking (scale_y_n, scale_y_d, scale_x_n, scale_x_d) = scale (offset_y, offset_x) = offset (border_y, border_x) = border # Make sure output dimensions OH and OW are integers partial_output_y = ( (ifm_shape[1] - 1) * scale_y_n - offset_y + border_y ) partial_output_x = ( (ifm_shape[2] - 1) * scale_x_n - offset_x + border_x ) if error_name == ErrorIf.ResizeOutputShapeNonInteger: if ( partial_output_y % scale_y_d == 0 and partial_output_x % scale_x_d == 0 ): # Skip this test as it doesn't produce NonInteger output perm += 1 continue else: while partial_output_y % scale_y_d != 0: scale_y_d -= 1 while partial_output_x % scale_x_d != 0: scale_x_d -= 1 output_y = partial_output_y // scale_y_d + 1 output_x = partial_output_x // scale_x_d + 1 if ( output_y >= testGen.args.max_resize_output_dim or output_x >= testGen.args.max_resize_output_dim ) and error_name is None: # Skip positive test if output dim will be too high # Avoid high test latency and OOM issues perm += 1 continue if ( output_y <= 0 or output_y >= MAX_RESIZE_DIMENSION or output_x <= 0 or output_x >= MAX_RESIZE_DIMENSION ): # Output dimensions out of scope if error_name is not None and perm > 0: # As long as we have one ERROR_IF test, don't worry # about creating all the other permutations perm += 1 continue if error_name == ErrorIf.ResizeOutputShapeMismatch and ( ( output_y + scale_y_d >= MAX_RESIZE_DIMENSION and output_y - scale_y_d < 1 ) or ( output_x + scale_x_d >= MAX_RESIZE_DIMENSION and output_x - scale_x_d < 1 ) ): # Can't create a negative test with these params as it # will create invalid output size if perm > 0: perm += 1 continue scale = [scale_y_n, scale_y_d, scale_x_n, scale_x_d] offset = [offset_y, offset_x] border = [border_y, border_x] # Common for all data types if error_name is not None: ( scale, offset, border, outputDTypeNew, ) = TosaErrorIfArgGen.eiResizeErrorIf( testGen, error_name, mode, dtype, shapeList, outputDType, scale, offset, border, ) else: outputDTypeNew = outputDType arg_to_append = ( arg_str.format( "N" if mode == ResizeMode.NEAREST else "B", testGen.typeStr(outputDTypeNew), scale[0], scale[1], scale[2], scale[3], offset[0], offset[1], border[0], border[1], ), [ mode, scale, offset, border, dtype, outputDTypeNew, ], ) if arg_to_append in arg_list: # Skip already generated test params continue # Valid permutation perm += 1 arg_list.append(arg_to_append) return arg_list @staticmethod def agTable(testGen, opName, shapeList, dtype, error_name=None): arg_list = [] if dtype == DType.INT8: table = np.int32( testGen.rng.integers(low=-128, high=128, size=[256]) ).tolist() else: # INT16 table = np.int32( testGen.rng.integers(low=-32768, high=32768, size=[513]) ).tolist() # Make sure all slopes are within REQUIRE min/max 16-bit int for idx in range(len(table) - 1): slope = table[idx + 1] - table[idx] # Alter the next table entry to force the slope to be ok if slope > 32767: table[idx + 1] -= slope - 32767 if slope < -32768: table[idx + 1] -= slope + 32768 slope = table[idx + 1] - table[idx] assert slope <= 32767 and slope >= -32768 arg_list.append( ( "", [table], ) ) return arg_list def agCondIf(testGen, opName, shapeList, dtype, error_name=None): # CondIf generates the condition values here. # Convert to tensors in the build function, along with the # then and else blocks arg_list = [] for c in [False, True]: arg_list.append(("cond{}".format(int(c)), [c])) return arg_list def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None): # While loop: 0 iterations, 1, more than 1 arg_list = [] for iter in [0, 1, 4]: arg_list.append(("iter{}".format(iter), [iter])) return arg_list