# Copyright (c) 2020-2022, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import os
from copy import deepcopy

import numpy as np
import serializer.tosa_serializer as ts
from generator.tosa_arg_gen import TosaArgGen
from generator.tosa_arg_gen import TosaQuantGen
from generator.tosa_arg_gen import TosaTensorGen
from generator.tosa_arg_gen import TosaTensorValuesGen
from generator.tosa_error_if import ErrorIf
from generator.tosa_error_if import TosaErrorIfArgGen
from generator.tosa_error_if import TosaErrorValidator
from generator.tosa_error_if import TosaInvalidValidator
from generator.tosa_utils import usableDTypes
from tosa.DType import DType
from tosa.Op import Op


class TosaTestGen:
    """Generate serialized TOSA operator tests.

    Holds the serializer for the test currently being built (``self.ser``)
    and a seeded NumPy random generator (``self.rng``) used for tensor data
    and random attribute values.
    """

    # Maximum rank of tensor supported by test generator.
    TOSA_TENSOR_MAX_RANK = 6

    def __init__(self, args):
        """Store ``args`` and seed the random generator.

        Reads ``args.output_dir`` (base output directory) and
        ``args.random_seed``, then builds the op lists via
        ``createDynamicOpLists`` and ``initOpListDefaults``.
        """
        self.args = args
        self.basePath = args.output_dir
        self.random_seed = args.random_seed
        self.ser = None
        self.rng = np.random.default_rng(self.random_seed)
        self.createDynamicOpLists()
        self.initOpListDefaults()
        self.quantGen = TosaQuantGen()
        # Force makeShape to do a specific starting shape
        self.targetted_shape = None

    def createSerializer(self, opName, testPath):
        """Create a fresh serializer writing into ``<output_dir>/<opName>/<testPath>``.

        The directory is created if it does not already exist.
        """
        self.testPath = os.path.join(opName, testPath)
        fullPath = os.path.join(self.basePath, self.testPath)
        os.makedirs(fullPath, exist_ok=True)
        self.ser = ts.TosaSerializer(fullPath)

    def getSerializer(self):
        """Return the serializer created by the last ``createSerializer`` call."""
        return self.ser

    def serialize(self, testName):
        """Write ``<testName>.tosa`` and ``desc.json`` for the current test."""
        testDir = os.path.join(self.basePath, self.testPath)
        tosaFileName = "{}.tosa".format(testName)
        with open(os.path.join(testDir, tosaFileName), "wb") as fd:
            fd.write(self.ser.serialize())
        # Use an explicit encoding so the descriptor does not depend on the
        # platform's default locale.
        with open(os.path.join(testDir, "desc.json"), "w", encoding="utf-8") as fd:
            fd.write(self.ser.writeJson(tosaFileName))

    def resetRNG(self, seed=None):
        """Re-seed ``self.rng``; ``seed`` defaults to ``random_seed + 1``."""
        if seed is None:
            seed = self.random_seed + 1
        self.rng = np.random.default_rng(seed)

    def getRandTensor(self, shape, dtype):
        """Return random data of ``shape`` spanning the value range of ``dtype``.

        Integer types up to 32 bits are returned as ``np.int32``, INT48 as
        ``np.int64``, FLOAT as ``np.float32`` in ``[0, 1)`` and BOOL as
        ``np.bool_``.

        Raises:
            Exception: if ``dtype`` is not one of the handled types.
        """
        if dtype == DType.BOOL:
            return np.bool_(self.rng.choice(a=[False, True], size=shape))
        # TOSA specific INT4 weight range from -7 to 7
        elif dtype == DType.INT4:
            return np.int32(self.rng.integers(low=-7, high=8, size=shape))
        elif dtype == DType.INT8:
            return np.int32(self.rng.integers(low=-128, high=128, size=shape))
        elif dtype == DType.UINT8:
            return np.int32(self.rng.integers(low=0, high=256, size=shape))
        elif dtype == DType.INT16:
            return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
        elif dtype == DType.UINT16:
            return np.int32(self.rng.integers(low=0, high=65536, size=shape))
        elif dtype == DType.INT32:
            return np.int32(
                self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
            )
        elif dtype == DType.INT48:
            return np.int64(
                self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
            )
        elif dtype == DType.FLOAT:
            return np.float32(self.rng.random(size=shape))
        else:
            raise Exception("Unrecognized Dtype: {}".format(dtype))

    def buildPlaceholderTensors(self, shape_list, dtype_list):
        """Add one random-valued placeholder per (shape, dtype) pair and return them."""
        assert len(shape_list) == len(dtype_list)
        placeholders = []
        for shape, dtype in zip(shape_list, dtype_list):
            arr = self.getRandTensor(shape, dtype)
            placeholders.append(self.ser.addPlaceholder(shape, dtype, arr))
        return placeholders

    def buildConstTensors(self, shape_list, dtype_list):
        """Add one random-valued constant per (shape, dtype) pair and return them."""
        assert len(shape_list) == len(dtype_list)
        consts = []
        for shape, dtype in zip(shape_list, dtype_list):
            arr = self.getRandTensor(shape, dtype)
            consts.append(self.ser.addConst(shape, dtype, arr))
        return consts

    def makeShape(self, rank):
        """Return a random shape with ``rank`` dimensions as ``np.int32``.

        Each dimension is drawn from
        ``[tensor_shape_range[0], tensor_shape_range[1])``. When a target shape
        has been set (and is truthy), it is returned instead and ``rank`` is
        ignored.
        """
        if self.targetted_shape:
            return np.int32(self.targetted_shape)
        return np.int32(
            self.rng.integers(
                low=self.args.tensor_shape_range[0],
                high=self.args.tensor_shape_range[1],
                size=rank,
            )
        )

    def setTargetShape(self, shape):
        """Force ``makeShape`` to return ``shape`` while it is truthy."""
        self.targetted_shape = shape

    def randInt(self, low=0, high=256):
        """Return one random ``np.int32`` in ``[low, high)``."""
        return np.int32(self.rng.integers(low=low, high=high, size=1))[0]

    def getRandNumberDType(self, dtype):
        """Return a single random scalar representable in ``dtype``.

        FLOAT gives a float in ``[0, 1)``, BOOL a random boolean, INT48 an
        ``np.int64``, and the other integer types an ``np.int32``. The integer
        ranges match those used by ``getRandTensor``.

        Raises:
            Exception: if ``dtype`` is not one of the handled types.
        """
        if dtype == DType.FLOAT:
            return self.rng.random()
        elif dtype == DType.BOOL:
            return self.rng.choice([False, True])
        # TOSA specific INT4 weight range from -7 to 7
        elif dtype == DType.INT4:
            low, high = (-7, 8)
        elif dtype == DType.INT8:
            low, high = (-128, 128)
        elif dtype == DType.UINT8:
            low, high = (0, 256)
        elif dtype == DType.INT16:
            low, high = (-32768, 32768)
        elif dtype == DType.UINT16:
            low, high = (0, 65536)
        elif dtype == DType.INT32:
            low, high = (-(1 << 31), (1 << 31))
        elif dtype == DType.INT48:
            low, high = (-(1 << 47), (1 << 47))
            # Special size: 48-bit values do not fit in int32
            return np.int64(self.rng.integers(low, high, size=1))[0]
        else:
            raise Exception("Unknown dtype: {}".format(dtype))

        return np.int32(self.rng.integers(low, high, size=1))[0]

    def shapeStr(self, shape):
        """Render ``shape`` as ``"d0xd1x..."`` for use in test names."""
        return "x".join(str(dim) for dim in shape)

    def typeStr(self, t):
        """Return the short name used in test names for dtype ``t``.

        A list is rendered from its first two entries as ``"<t0>x<t1>"``.

        Raises:
            Exception: if ``t`` is not one of the handled types.
        """
        if isinstance(t, list):
            assert len(t) >= 2
            return "{}x{}".format(self.typeStr(t[0]), self.typeStr(t[1]))
        if t == DType.BOOL:
            return "b"
        elif t == DType.INT4:
            return "i4"
        elif t == DType.INT8:
            return "i8"
        elif t == DType.UINT8:
            return "u8"
        elif t == DType.INT16:
            return "i16"
        elif t == DType.UINT16:
            return "u16"
        elif t == DType.INT32:
            return "i32"
        elif t == DType.INT48:
            return "i48"
        elif t == DType.FLOAT:
            return "float"
        else:
            raise Exception("Unknown dtype, cannot convert to string: {}".format(t))

    def typeWidth(self, t):
        """Return the width in bits of dtype ``t`` (FLOAT is 32, BOOL is 1).

        Raises:
            Exception: if ``t`` is not one of the handled types.
        """
        if t == DType.INT4:
            return 4
        elif t == DType.INT8:
            return 8
        elif t == DType.UINT8:
            return 8
        elif t == DType.INT16:
            return 16
        elif t == DType.UINT16:
            return 16
        elif t == DType.INT32:
            return 32
        elif t == DType.INT48:
            return 48
        elif t == DType.FLOAT:
            return 32
        elif t == DType.BOOL:
            return 1
        else:
            raise Exception(f"Unknown dtype, cannot determine width: {t}")

    # Operator build functions.
    # The argument generators return a list of tuples
    # (stringDescriptor, [build_fcn_arg_list]), where the string descriptor is
    # used to generate the test name and the build_fcn_arg_list is expanded and
    # passed to one of the build_* functions below.

    def _validateErrorIfs(
        self, op, error_name, validator_fcns, input_list, output_list, **kwargs
    ):
        """Prepare operand name lists for ERROR_IF tests and run the validators.

        ``input_list`` / ``output_list`` are passed through
        ``TosaErrorIfArgGen.eiInvalidateInputOutputList``; ``num_operands`` is
        the total of ``op["operands"]``. Extra keyword arguments are forwarded
        unchanged to ``TosaErrorValidator.evValidateErrorIfs``.

        Returns:
            The (possibly modified) ``(input_list, output_list)`` pair, or
            ``None`` when ``evValidateErrorIfs`` rejects the test.
        """
        num_operands = sum(op["operands"])
        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
            self, error_name, input_list, output_list
        )
        if not TosaErrorValidator.evValidateErrorIfs(
            self.ser,
            validator_fcns,
            error_name,
            op=op,
            **kwargs,
            input_list=input_list,
            output_list=output_list,
            num_operands=num_operands,
        ):
            return None
        return input_list, output_list

    def _convErrorQinfo(self, error_name, ifm, result_tens, qinfo):
        """Return conv quant info for ``ifm``/``result_tens``.

        For a WrongInputType test whose input is not INT8/UINT8, a fresh
        ``ConvQuantInfo`` is built from the input and output dtypes; otherwise
        ``qinfo`` is returned unchanged.
        """
        if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
            DType.INT8,
            DType.UINT8,
        ):
            qinfo = ts.TosaSerializerQuantInfo()
            qinfo.ConvQuantInfo(
                TosaQuantGen.getQinfo(self, ifm.dtype),
                TosaQuantGen.getQinfo(self, result_tens.dtype),
            )
        return qinfo

    def _buildConvCommon(
        self,
        op,
        result_tens,
        ifm,
        weight,
        bias,
        strides,
        padding,
        dilations,
        validator_fcns,
        error_name,
        qinfo,
    ):
        """Shared tail of conv2d/conv3d/depthwise_conv2d: validate and add the op."""
        qinfo = self._convErrorQinfo(error_name, ifm, result_tens, qinfo)

        lists = self._validateErrorIfs(
            op,
            error_name,
            validator_fcns,
            [ifm.name, weight.name, bias.name],
            [result_tens.name],
            input_dtype=ifm.dtype,
            weight_dtype=weight.dtype,
            output_dtype=result_tens.dtype,
            qinfo=qinfo,
            pad=padding,
            stride=strides,
            dilation=dilations,
            input_shape=ifm.shape,
            weight_shape=weight.shape,
            output_shape=result_tens.shape,
        )
        if lists is None:
            return None
        input_list, output_list = lists

        attr = ts.TosaSerializerAttribute()
        attr.ConvAttribute(padding, strides, dilations)

        self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
        return result_tens

    def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
        """Build an elementwise unary operator on ``a``.

        ``op`` may also be a bare Op value (an ``int``); that case and IDENTITY
        are added directly, without ERROR_IF validation.
        """
        result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
        # build_placeholder passes an int; ABS and other ops pass an op dict
        if isinstance(op, int):
            self.ser.addOperator(op, a.name, result_tens.name, None, qinfo)
            return result_tens
        elif op["op"] == Op.IDENTITY:
            self.ser.addOperator(op["op"], a.name, result_tens.name, None, qinfo)
            return result_tens

        # Ensure new output type has correct qinfo
        if error_name == ErrorIf.WrongOutputType:
            if result_tens.dtype not in [DType.INT8, DType.UINT8]:
                qinfo = ts.TosaSerializerQuantInfo()
                qinfo.UnaryQuantInfo(
                    TosaQuantGen.getQinfo(self, a.dtype),
                    TosaQuantGen.getQinfo(self, result_tens.dtype),
                )

        lists = self._validateErrorIfs(
            op,
            error_name,
            validator_fcns,
            [a.name],
            [result_tens.name],
            input_dtype=a.dtype,
            output_dtype=result_tens.dtype,
            qinfo=qinfo,
            result_tensor=result_tens,
        )
        if lists is None:
            return None
        input_list, output_list = lists

        self.ser.addOperator(op["op"], input_list, output_list, None, qinfo)
        return result_tens

    def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
        """Build a broadcasting elementwise binary operator on ``a`` and ``b``."""
        result_tens = OutputShaper.binaryBroadcastOp(
            self.ser, self.rng, a, b, error_name
        )

        lists = self._validateErrorIfs(
            op,
            error_name,
            validator_fcns,
            [a.name, b.name],
            [result_tens.name],
            input1=a,
            input2=b,
            input_dtype=a.dtype,
            output_dtype=result_tens.dtype,
            result_tensor=result_tens,
        )
        if lists is None:
            return None
        input_list, output_list = lists

        self.ser.addOperator(op["op"], input_list, output_list)
        return result_tens

    def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
        """Build a non-broadcasting binary operator (no ERROR_IF validation)."""
        result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
        self.ser.addOperator(op["op"], [a.name, b.name], [result_tens.name])
        return result_tens

    def build_arithmetic_right_shift(
        self, op, a, b, round, validator_fcns=None, error_name=None
    ):
        """Build ARITHMETIC_RIGHT_SHIFT with the ``round`` attribute."""
        result_tens = OutputShaper.binaryBroadcastOp(
            self.ser, self.rng, a, b, error_name
        )

        lists = self._validateErrorIfs(
            op,
            error_name,
            validator_fcns,
            [a.name, b.name],
            [result_tens.name],
            input1=a,
            input2=b,
            input_dtype=a.dtype,
            output_dtype=result_tens.dtype,
            result_tensor=result_tens,
        )
        if lists is None:
            return None
        input_list, output_list = lists

        attr = ts.TosaSerializerAttribute()
        attr.ArithmeticRightShiftAttribute(round)

        self.ser.addOperator(op["op"], input_list, output_list, attr)
        return result_tens

    def build_mul(self, op, a, b, shift, validator_fcns=None, error_name=None):
        """Build MUL with the ``shift`` attribute.

        Non-FLOAT results are forced to INT32; a WrongOutputType test instead
        picks a random output type from INT8/INT16/INT48.
        """
        result_tens = OutputShaper.binaryBroadcastOp(
            self.ser, self.rng, a, b, error_name
        )

        # Special for multiply:
        # Force the result to INT32 for INT types
        if a.dtype != DType.FLOAT:
            result_tens.setDtype(DType.INT32)
        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
            outputDType = self.rng.choice(all_dtypes)
            result_tens.setDtype(outputDType)

        lists = self._validateErrorIfs(
            op,
            error_name,
            validator_fcns,
            [a.name, b.name],
            [result_tens.name],
            input1=a,
            input2=b,
            input_dtype=a.dtype,
            output_dtype=result_tens.dtype,
            result_tensor=result_tens,
        )
        if lists is None:
            return None
        input_list, output_list = lists

        attr = ts.TosaSerializerAttribute()
        attr.MulAttribute(shift)

        self.ser.addOperator(op["op"], input_list, output_list, attr)
        return result_tens

    def build_table(self, op, a, table, validator_fcns=None, error_name=None):
        """Build ``op`` on ``a`` with ``table`` attached as a TableAttribute."""
        result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)

        attr = ts.TosaSerializerAttribute()
        attr.TableAttribute(table)

        lists = self._validateErrorIfs(
            op,
            error_name,
            validator_fcns,
            [a.name],
            [result_tens.name],
            input_shape=a.shape,
            input_dtype=a.dtype,
            output_dtype=result_tens.dtype,
            result_tensor=result_tens,
        )
        if lists is None:
            return None
        input_list, output_list = lists

        self.ser.addOperator(op["op"], input_list, output_list, attr)
        return result_tens

    def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
        """Build SELECT choosing between ``a`` and ``b`` by ``cond``."""
        result_tens = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)

        lists = self._validateErrorIfs(
            op,
            error_name,
            validator_fcns,
            [cond.name, a.name, b.name],
            [result_tens.name],
            input1=cond,
            input2=a,
            input3=b,
            input_shape=a.shape,
            input_dtype=a.dtype,
            output_dtype=result_tens.dtype,
            result_tensor=result_tens,
        )
        if lists is None:
            return None
        input_list, output_list = lists

        self.ser.addOperator(op["op"], input_list, output_list)
        return result_tens

    def build_comparison(self, op, a, b, validator_fcns=None, error_name=None):
        """Build an elementwise comparison operator on ``a`` and ``b``."""
        result_tens = OutputShaper.binaryComparisonOp(
            self.ser, self.rng, a, b, error_name
        )

        lists = self._validateErrorIfs(
            op,
            error_name,
            validator_fcns,
            [a.name, b.name],
            [result_tens.name],
            input1=a,
            input2=b,
            input_shape=a.shape,
            input_dtype=a.dtype,
            output_shape=result_tens.shape,
            output_dtype=result_tens.dtype,
            result_tensor=result_tens,
        )
        if lists is None:
            return None
        input_list, output_list = lists

        self.ser.addOperator(op["op"], input_list, output_list)
        return result_tens

    def build_argmax(self, op, a, axis, validator_fcns, error_name):
        """Build ARGMAX of ``a`` along ``axis``."""
        result_tens = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)

        lists = self._validateErrorIfs(
            op,
            error_name,
            validator_fcns,
            [a.name],
            [result_tens.name],
            axis=axis,
            input_shape=a.shape,
            input_dtype=a.dtype,
            output_shape=result_tens.shape,
            output_dtype=result_tens.dtype,
            result_tensor=result_tens,
        )
        if lists is None:
            return None
        input_list, output_list = lists

        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)

        self.ser.addOperator(op["op"], input_list, output_list, attr)
        return result_tens

    def build_pool2d(
        self,
        op,
        input,
        stride,
        pad,
        kernel,
        validator_fcns=None,
        error_name=None,
        qinfo=None,
    ):
        """Build a 2D pooling operator with the given kernel, stride and pad."""
        result_tens = OutputShaper.pool2dOp(
            self.ser, self.rng, input, kernel, stride, pad, error_name
        )

        # A WrongInputType test with a non-8-bit input needs matching qinfo
        if error_name == ErrorIf.WrongInputType:
            if input.dtype not in [DType.INT8, DType.UINT8]:
                qinfo = ts.TosaSerializerQuantInfo()
                qinfo.UnaryQuantInfo(
                    TosaQuantGen.getQinfo(self, input.dtype),
                    TosaQuantGen.getQinfo(self, result_tens.dtype),
                )

        lists = self._validateErrorIfs(
            op,
            error_name,
            validator_fcns,
            [input.name],
            [result_tens.name],
            input_shape=input.shape,
            input_dtype=input.dtype,
            output_shape=result_tens.shape,
            output_dtype=result_tens.dtype,
            kernel=kernel,
            stride=stride,
            pad=pad,
            qinfo=qinfo,
            result_tensor=result_tens,
        )
        if lists is None:
            return None
        input_list, output_list = lists

        attr = ts.TosaSerializerAttribute()
        attr.PoolAttribute(kernel, stride, pad)

        self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
        return result_tens

    def build_conv2d(
        self,
        op,
        ifm,
        filter,
        bias,
        strides,
        padding,
        dilations,
        validator_fcns=None,
        error_name=None,
        qinfo=None,
    ):
        """Build CONV2D; ``padding`` must have 4 entries."""
        assert len(padding) == 4
        result_tens = OutputShaper.conv2dOp(
            self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
        )
        return self._buildConvCommon(
            op,
            result_tens,
            ifm,
            filter,
            bias,
            strides,
            padding,
            dilations,
            validator_fcns,
            error_name,
            qinfo,
        )

    def build_conv3d(
        self,
        op,
        ifm,
        filter,
        bias,
        strides,
        padding,
        dilations,
        validator_fcns=None,
        error_name=None,
        qinfo=None,
    ):
        """Build CONV3D; ``padding`` must have 6 entries."""
        assert len(padding) == 6
        result_tens = OutputShaper.conv3dOp(
            self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
        )
        return self._buildConvCommon(
            op,
            result_tens,
            ifm,
            filter,
            bias,
            strides,
            padding,
            dilations,
            validator_fcns,
            error_name,
            qinfo,
        )

    def build_transpose_conv2d(
        self,
        op,
        ifm,
        filter,
        bias,
        stride,
        out_pad,
        output_shape,
        validator_fcns=None,
        error_name=None,
        qinfo=None,
    ):
        """Build TRANSPOSE_CONV2D; ``out_pad`` must have 4 entries."""
        assert len(out_pad) == 4
        result_tens = OutputShaper.transposeConv2DOp(
            self.ser, self.rng, ifm, output_shape, error_name
        )
        qinfo = self._convErrorQinfo(error_name, ifm, result_tens, qinfo)

        lists = self._validateErrorIfs(
            op,
            error_name,
            validator_fcns,
            [ifm.name, filter.name, bias.name],
            [result_tens.name],
            input_dtype=ifm.dtype,
            weight_dtype=filter.dtype,
            output_dtype=result_tens.dtype,
            qinfo=qinfo,
            pad=out_pad,
            stride=stride,
            input_shape=ifm.shape,
            weight_shape=filter.shape,
            output_shape=result_tens.shape,
        )
        if lists is None:
            return None
        input_list, output_list = lists

        attr = ts.TosaSerializerAttribute()
        attr.TransposeConvAttribute(out_pad, stride, output_shape)

        self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
        return result_tens

    def build_depthwise_conv2d(
        self,
        op,
        ifm,
        filter,
        bias,
        strides,
        padding,
        dilations,
        validator_fcns=None,
        error_name=None,
        qinfo=None,
    ):
        """Build DEPTHWISE_CONV2D."""
        result_tens = OutputShaper.depthwiseConv2dOp(
            self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
        )
        return self._buildConvCommon(
            op,
            result_tens,
            ifm,
            filter,
            bias,
            strides,
            padding,
            dilations,
            validator_fcns,
            error_name,
            qinfo,
        )

    def build_fully_connected(
        self, op, ifm, filter, bias, validator_fcns=None, error_name=None, qinfo=None
    ):
        """Build FULLY_CONNECTED from ``ifm``, ``filter`` and ``bias``."""
        result_tens = OutputShaper.fullyConnectedOp(
            self.ser, self.rng, ifm, filter, error_name
        )

        lists = self._validateErrorIfs(
            op,
            error_name,
            validator_fcns,
            [ifm.name, filter.name, bias.name],
            [result_tens.name],
            input_shape=ifm.shape,
            input_dtype=ifm.dtype,
            weight_dtype=filter.dtype,
            output_shape=result_tens.shape,
            output_dtype=result_tens.dtype,
            qinfo=qinfo,
            result_tensor=result_tens,
        )
        if lists is None:
            return None
        input_list, output_list = lists

        self.ser.addOperator(op["op"], input_list, output_list, None, qinfo)
        return result_tens

    def build_matmul(self, op, a, b, validator_fcns=None, error_name=None, qinfo=None):
        """Build MATMUL of ``a`` and ``b``."""
        result_tens = OutputShaper.matmulOp(self.ser, self.rng, a, b, error_name)

        lists = self._validateErrorIfs(
            op,
            error_name,
            validator_fcns,
            [a.name, b.name],
            [result_tens.name],
            input_shape=a.shape,
            input_dtype=a.dtype,
            input2_shape=b.shape,
            input2_dtype=b.dtype,
            output_shape=result_tens.shape,
            output_dtype=result_tens.dtype,
            qinfo=qinfo,
            result_tensor=result_tens,
        )
        if lists is None:
            return None
        input_list, output_list = lists

        self.ser.addOperator(op["op"], input_list, output_list, None, qinfo)
        return result_tens

    def build_reduce(self, op, a, axis, validator_fcns, error_name=None):
        """Build a reduction operator on ``a`` along ``axis``."""
        result_tens = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)

        lists = self._validateErrorIfs(
            op,
            error_name,
            validator_fcns,
            [a.name],
            [result_tens.name],
            axis=axis,
            input_shape=a.shape,
            output_shape=result_tens.shape,
            input_dtype=a.dtype,
            output_dtype=result_tens.dtype,
            result_tensor=result_tens,
        )
        if lists is None:
            return None
        input_list, output_list = lists

        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)

        self.ser.addOperator(op["op"], input_list, output_list, attr)
        return result_tens

    def build_clamp(self, op, a, validator_fcns=None, error_name=None):
        """Build CLAMP with random min/max values of ``a``'s dtype.

        For a MaxSmallerMin test the two distinct values are deliberately
        swapped, so that max < min.
        """
        result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

        v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]

        if error_name == ErrorIf.MaxSmallerMin:
            # Make sure the numbers are different to invoke this error
            while v[0] == v[1]:
                v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]
            max_val = min(v)
            min_val = max(v)
        else:
            max_val = max(v)
            min_val = min(v)

        lists = self._validateErrorIfs(
            op,
            error_name,
            validator_fcns,
            [a.name],
            [result_tens.name],
            max_val=max_val,
            min_val=min_val,
            input_shape=a.shape,
            output_shape=result_tens.shape,
            input_dtype=a.dtype,
            output_dtype=result_tens.dtype,
            result_tensor=result_tens,
        )
        if lists is None:
            return None
        input_list, output_list = lists

        attr = ts.TosaSerializerAttribute()
        # FLOAT bounds go in the float slots, integer bounds in the int slots
        if a.dtype == DType.FLOAT:
            attr.ClampAttribute(0, 0, min_val, max_val)
        else:
            attr.ClampAttribute(min_val, max_val, 0, 0)

        self.ser.addOperator(op["op"], input_list, output_list, attr)
        return result_tens

    def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
        """Build LEAKY_RELU with a random FLOAT alpha (no ERROR_IF validation)."""
        result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
        attr = ts.TosaSerializerAttribute()

        attr.LeakyReluAttribute(self.getRandNumberDType(DType.FLOAT))

        self.ser.addOperator(op["op"], [a.name], [result_tens.name], attr)
        return result_tens

    # Needs an additional type/input
    def build_prelu(self, op, a, validator_fcns=None, error_name=None):
        """Build PRELU on ``a`` (no ERROR_IF validation)."""
        result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

        self.ser.addOperator(op["op"], [a.name], [result_tens.name])
        return result_tens

    def build_sigmoid(self, op, a, validator_fcns=None, error_name=None):
        """Build SIGMOID on ``a``."""
        result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

        lists = self._validateErrorIfs(
            op,
            error_name,
            validator_fcns,
            [a.name],
            [result_tens.name],
            input_shape=a.shape,
            output_shape=result_tens.shape,
            input_dtype=a.dtype,
            output_dtype=result_tens.dtype,
            result_tensor=result_tens,
        )
        if lists is None:
            return None
        input_list, output_list = lists

        self.ser.addOperator(op["op"], input_list, output_list)
        return result_tens

    def build_tanh(self, op, a, validator_fcns=None, error_name=None):
        """Build TANH on ``a``."""
        result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

        lists = self._validateErrorIfs(
            op,
            error_name,
            validator_fcns,
            [a.name],
            [result_tens.name],
            input_shape=a.shape,
            output_shape=result_tens.shape,
            input_dtype=a.dtype,
            output_dtype=result_tens.dtype,
            result_tensor=result_tens,
        )
        if lists is None:
            return None
        input_list, output_list = lists

        self.ser.addOperator(op["op"], input_list, output_list)
        return result_tens

    def build_concat(self, op, *a, validator_fcns=None, error_name=None):
        """Build CONCAT of the input tensors in ``a``.

        The last positional argument is the concat axis, not a tensor.
        """
        if error_name != ErrorIf.WrongInputType:
            assert isinstance(a[-1], int)

        # To store variable length list of input tensors we need to store axis along with it
        axis = a[-1]
        a = a[:-1]

        result_tens = OutputShaper.concatOp(
            self.ser, self.rng, axis, *a, error_name=error_name
        )

        lists = self._validateErrorIfs(
            op,
            error_name,
            validator_fcns,
            [tensor.name for tensor in a],
            [result_tens.name],
            axis=axis,
            input_shape=a[0].shape,
            output_shape=result_tens.shape,
            input_dtype=a[0].dtype,
            output_dtype=result_tens.dtype,
            inputs=a,
            result_tensor=result_tens,
        )
        if lists is None:
            return None
        input_list, output_list = lists

        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)

        self.ser.addOperator(op["op"], input_list, output_list, attr)
        return result_tens

    def build_pad(
        self,
        op,
        a,
        padding,
        pad_const_int,
        pad_const_float,
        validator_fcns=None,
        error_name=None,
        qinfo=None,
    ):
        """Build PAD; ``padding`` is flattened into the PadAttribute."""
        result_tens = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)

        attr = ts.TosaSerializerAttribute()
        attr.PadAttribute(padding.flatten(), pad_const_int, pad_const_float)

        lists = self._validateErrorIfs(
            op,
            error_name,
            validator_fcns,
            [a.name],
            [result_tens.name],
            input_shape=a.shape,
            output_shape=result_tens.shape,
            input_dtype=a.dtype,
            output_dtype=result_tens.dtype,
            pad=padding,
            qinfo=qinfo,
            result_tensor=result_tens,
        )
        if lists is None:
            return None
        input_list, output_list = lists

        self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
        return result_tens

    def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
        """Build RESHAPE of ``a`` to ``newShape``."""
        result_tens = OutputShaper.reshapeOp(
            self.ser, self.rng, a, newShape, error_name
        )

        lists = self._validateErrorIfs(
            op,
            error_name,
            validator_fcns,
            [a.name],
            [result_tens.name],
            input_shape=a.shape,
            output_shape=result_tens.shape,
            input_dtype=a.dtype,
            output_dtype=result_tens.dtype,
            result_tensor=result_tens,
        )
        if lists is None:
            return None
        input_list, output_list = lists

        attr = ts.TosaSerializerAttribute()
        attr.ReshapeAttribute(newShape)

        self.ser.addOperator(op["op"], input_list, output_list, attr)
        return result_tens

    def build_reverse(self, op, a, axis, validator_fcns=None, error_name=None):
        """Build REVERSE of ``a`` along ``axis``."""
        result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)

        lists = self._validateErrorIfs(
            op,
            error_name,
            validator_fcns,
            [a.name],
            [result_tens.name],
            axis=axis,
            input_shape=a.shape,
            output_shape=result_tens.shape,
            input_dtype=a.dtype,
            output_dtype=result_tens.dtype,
            result_tensor=result_tens,
        )
        if lists is None:
            return None
        input_list, output_list = lists

        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(axis)

        self.ser.addOperator(op["op"], input_list, output_list, attr)
        return result_tens

    def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
        """Build TRANSPOSE of ``a`` using permutation ``perms``."""
        result_tens = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)

        attr = ts.TosaSerializerAttribute()
        attr.TransposeAttribute(perms)

        lists = self._validateErrorIfs(
            op,
            error_name,
            validator_fcns,
            [a.name],
            [result_tens.name],
            input_shape=a.shape,
            output_shape=result_tens.shape,
            perms=perms,
            input_dtype=a.dtype,
            output_dtype=result_tens.dtype,
            result_tensor=result_tens,
        )
        if lists is None:
            return None
        input_list, output_list = lists

        self.ser.addOperator(op["op"], input_list, output_list, attr)
        return result_tens

    def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
        """Build SLICE of ``a`` with the given ``start`` and ``size``."""
        result_tens = OutputShaper.sliceOp(
            self.ser, self.rng, a, start, size, error_name
        )

        lists = self._validateErrorIfs(
            op,
            error_name,
            validator_fcns,
            [a.name],
            [result_tens.name],
            input_shape=a.shape,
            output_shape=result_tens.shape,
            input_dtype=a.dtype,
            output_dtype=result_tens.dtype,
            start=start,
            size=size,
            result_tensor=result_tens,
        )
        if lists is None:
            return None
        input_list, output_list = lists

        attr = ts.TosaSerializerAttribute()
        attr.SliceAttribute(start, size)

        self.ser.addOperator(op["op"], input_list, output_list, attr)
        return result_tens

    def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
        """Build TILE of ``a`` by ``multiples``."""
        result_tens = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)

        lists = self._validateErrorIfs(
            op,
            error_name,
            validator_fcns,
            [a.name],
            [result_tens.name],
            input_shape=a.shape,
            output_shape=result_tens.shape,
            input_dtype=a.dtype,
            output_dtype=result_tens.dtype,
            result_tensor=result_tens,
        )
        if lists is None:
            return None
        input_list, output_list = lists

        attr = ts.TosaSerializerAttribute()
        attr.TileAttribute(multiples)

        self.ser.addOperator(op["op"], input_list, output_list, attr)
        return result_tens

    def build_gather(self, op, values, validator_fcns=None, error_name=None):
        """Build GATHER from ``values`` with a freshly generated indices constant.

        The indices tensor has shape ``(values.shape[0], W)``, with ``W`` drawn
        from the tensor shape range. Its entries lie in ``[0, values.shape[1])``,
        so they stay within the dimensions of ``values``.
        """
        K = values.shape[1]
        W = self.randInt(
            self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]
        )
        indices_arr = np.int32(
            self.rng.integers(low=0, high=K, size=[values.shape[0], W])
        )  # (N, W)
        indices = self.ser.addConst(indices_arr.shape, DType.INT32, indices_arr)

        result_tens = OutputShaper.gatherOp(
            self.ser, self.rng, values, indices, error_name
        )

        lists = self._validateErrorIfs(
            op,
            error_name,
            validator_fcns,
            [values.name, indices.name],
            [result_tens.name],
            input_shape=values.shape,
            output_shape=result_tens.shape,
            input_dtype=values.dtype,
            output_dtype=result_tens.dtype,
            result_tensor=result_tens,
        )
        if lists is None:
            return None
        input_list, output_list = lists

        self.ser.addOperator(op["op"], input_list, output_list)
        return result_tens

    def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):

        # Create a new indicies tensor
        # here with data that doesn't exceed the dimensions of the values_in tensor

        K = values_in.shape[1]  # K
        W = input.shape[1]  # W
        indicies_arr = np.int32(
            self.rng.integers(low=0, high=K, size=[values_in.shape[0], W])
        )  # (N, W)
        indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)

        result_tens = OutputShaper.scatterOp(
            self.ser, self.rng, values_in, indicies, input, error_name
        )

        # Invalidate Input/Output list for error if checks.
input_list = [values_in.name, indicies.name, input.name] output_list = [result_tens.name] pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( self, error_name, input_list, output_list ) if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, op=op, input_shape=values_in.shape, output_shape=result_tens.shape, input_dtype=values_in.dtype, output_dtype=result_tens.dtype, result_tensor=result_tens, input_list=input_list, output_list=output_list, num_operands=num_operands, ): return None self.ser.addOperator(op["op"], input_list, output_list) return result_tens def build_resize( self, op, input, mode, stride, offset, shift, stride_fp, offset_fp, output_dims, input_dtype, output_dtype, validator_fcns, error_name=None, ): result_tens = OutputShaper.resizeOp( self.ser, self.rng, input, mode, stride, offset, shift, stride_fp, offset_fp, output_dims, input_dtype, output_dtype, error_name, ) # Invalidate Input/Output list for error if checks. 
input_list = [input.name] output_list = [result_tens.name] pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( self, error_name, input_list, output_list ) if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, op=op, mode=mode, shift=shift, input_dtype=input_dtype, output_dtype=output_dtype, input_shape=input.shape, output_shape=output_dims, offset=offset, offset_fp=offset_fp, stride=stride, stride_fp=stride_fp, input_list=input_list, output_list=output_list, result_tensor=result_tens, num_operands=num_operands, ): return None attr = ts.TosaSerializerAttribute() attr.ResizeAttribute( output_dims, stride, offset, shift, stride_fp, offset_fp, mode ) self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None): result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name) result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name) self.ser.addOperator( op, [val.name, val2.name], [result_tens.name, result_tens2.name] ) return result_tens def build_const(self, op, val, validator_fcns=None, error_name=None): self.ser.addOutputTensor(val) return val # Type Conversion def build_cast(self, op, val, out_dtype, validator_fcns=None, error_name=None): result_tens = OutputShaper.typeConversionOp( self.ser, self.rng, val, out_dtype, error_name ) # Invalidate Input/Output list for error if checks. 
input_list = [val.name] output_list = [result_tens.name] pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( self, error_name, input_list, output_list ) if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, op=op, input_shape=val.shape, output_shape=result_tens.shape, input_dtype=val.dtype, output_dtype=result_tens.dtype, result_tensor=result_tens, input_list=input_list, output_list=output_list, num_operands=num_operands, ): return None self.ser.addOperator(op["op"], input_list, output_list) return result_tens def build_rescale( self, op, val, out_dtype, scale32, double_round, per_channel, validator_fcns, error_name, ): result_tens = OutputShaper.typeConversionOp( self.ser, self.rng, val, out_dtype, error_name ) if per_channel: nc = val.shape[-1] else: nc = 1 in_type_width = self.typeWidth(val.dtype) out_type_width = self.typeWidth(out_dtype) if val.dtype == DType.INT8: input_zp = self.randInt(-128, 128) in_type_width += 1 elif val.dtype == DType.UINT8: input_zp = self.randInt(0, 256) in_type_width += 1 elif error_name in [ ErrorIf.InputZeroPointNotZero, ErrorIf.U16InputZeroPointNotValid, ]: input_zp = self.randInt(-128, 128) if input_zp == 0: input_zp = input_zp + self.rng.integers(1, 10) in_type_width += 1 elif val.dtype == DType.UINT16: # Must come after ErrorIf.U16InputZeroPointNotValid check input_zp = self.rng.choice([0, 32768]) in_type_width += 1 else: input_zp = 0 if out_dtype == DType.INT8: output_zp = self.randInt(-128, 128) out_type_width += 1 elif out_dtype == DType.UINT8: output_zp = self.randInt(0, 256) out_type_width += 1 elif error_name in [ ErrorIf.OutputZeroPointNotZero, ErrorIf.U16OutputZeroPointNotValid, ]: output_zp = self.randInt(-128, 128) if output_zp == 0: output_zp = output_zp + self.rng.integers(1, 10) out_type_width += 1 elif out_dtype == DType.UINT16: # Must come after ErrorIf.U16OutputZeroPointNotValid check output_zp = 
self.rng.choice([0, 32768]) out_type_width += 1 else: output_zp = 0 # Calculate scale based on: # scale = a *(2^output_width)/(2^input_width)) a = np.float32(self.rng.random(size=[nc])) scale_arr = a * np.float32((1 << out_type_width) / (1 << in_type_width)) if scale32: pass # Cap the scaling at 2^31 - 1 for scale32 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1) else: # Cap the scaling at 2^15 - 1 for scale16 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0) # print('{} {} -> {}'.format(out_type_width, in_type_width, scale_arr)) multiplier_arr = np.int32(np.zeros(shape=[nc])) shift_arr = np.int32(np.zeros(shape=[nc])) min_shift_value_arr = np.int64(np.zeros(shape=[nc])) max_shift_value_arr = np.int64(np.zeros(shape=[nc])) for i in range(nc): multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift( scale_arr[i], scale32 ) min_shift_value_arr[i] = -1 << (shift_arr[i] - 2) max_shift_value_arr[i] = (1 << (shift_arr[i] - 2)) - 1 # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp)) if scale32 and error_name is None: # Make sure random values are within apply_scale_32 specification # REQUIRES(value >= (-1<<(shift-2)) && value < (1<<(shift-2)) assert val.placeholderFilename values = np.load( os.path.join(self.basePath, self.testPath, val.placeholderFilename) ) val_adj = np.subtract(values, input_zp, dtype=np.int64) val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=np.int64) val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=np.int64) val_adj = np.add(val_adj, input_zp, dtype=values.dtype) if not np.all(np.array_equal(values, val_adj)): # Values changed so overwrite file with new values np.save( os.path.join(self.basePath, self.testPath, val.placeholderFilename), val_adj, False, ) # Invalidate Input/Output list for error if checks. 
input_list = [val.name] output_list = [result_tens.name] pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( self, error_name, input_list, output_list ) qinfo = (input_zp, output_zp) if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, op=op, input_dtype=val.dtype, output_dtype=out_dtype, input_shape=val.shape, qinfo=qinfo, scale32=scale32, double_round=double_round, input_list=input_list, output_list=output_list, result_tensor=result_tens, num_operands=num_operands, ): return None attr = ts.TosaSerializerAttribute() attr.RescaleAttribute( input_zp, output_zp, multiplier_arr, shift_arr, scale32, double_round, per_channel, ) self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens def build_cond_if_const( self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None ): # For cond_if with constants, we're supplied with then/else tensors that we ignore # (except for the generated shap) and the condition. Build Then/Else blocks # and fill them with const nodes for the body. 
# Condition tensor cond_tens = self.ser.addConst([], DType.BOOL, [cond]) # Make then/else tensors out_shape = then_tens.shape # Create an incorrect output shape for error_if tests if error_name in [ ErrorIf.CondIfOutputListThenGraphMismatch, ErrorIf.CondIfOutputListElseGraphMismatch, ]: incorrect_shape = deepcopy(then_tens.shape) for i in range(len(incorrect_shape)): incorrect_shape[i] += ( self.rng.choice([-3, -2, 2, 3]) if incorrect_shape[i] > 3 else self.rng.choice([1, 2, 4]) ) incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape)) then_arr = np.int32(self.rng.integers(0, 256, size=out_shape)) else_arr = np.int32(self.rng.integers(0, 256, size=out_shape)) # And the result tensor based on any of the outputs result_tens = self.ser.addOutput(out_shape, DType.INT32) # Create the attribute with the names of the then/else blocks then_block = "THEN_BLOCK" else_block = "ELSE_BLOCK" attr = ts.TosaSerializerAttribute() attr.CondIfAttribute(then_block, else_block) # Finally, build the op and the two blocks self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr) self.ser.startBasicBlock(then_block) # Build the actual then/else tensors inside their blocks if error_name == ErrorIf.CondIfOutputListThenGraphMismatch: then_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr) else: then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr) self.ser.addOutputTensor(then_tens) self.ser.startBasicBlock(else_block) if error_name == ErrorIf.CondIfOutputListElseGraphMismatch: else_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr) else: else_tens = self.ser.addConst(out_shape, DType.INT32, else_arr) self.ser.addOutputTensor(else_tens) if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, op=op, basicBlocks=self.ser.basicBlocks, ): return None return result_tens def build_cond_if_binary( self, op, a, b, cond, validator_fcns=None, error_name=None ): # For cond_if with a binary op 
in the then/else blocks, take a and b and # alternately add or subtract them based on the condition # Condition tensor cond_tens = self.ser.addConst([], DType.BOOL, [cond]) result_tens = self.ser.addOutput(a.shape, a.dtype) # Create the attribute with the names of the then/else blocks then_block = "THEN_BLOCK" else_block = "ELSE_BLOCK" attr = ts.TosaSerializerAttribute() attr.CondIfAttribute(then_block, else_block) if error_name in [ ErrorIf.CondIfInputListThenGraphMismatch, ErrorIf.CondIfInputListElseGraphMismatch, ErrorIf.CondIfOutputListElseGraphMismatch, ErrorIf.CondIfOutputListThenGraphMismatch, ]: incorrect_shape = a.shape.copy() for i in range(len(incorrect_shape)): incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3]) incorrect_block_input = deepcopy(a) incorrect_block_input.shape = incorrect_shape # Finally, build the op and the two blocks self.ser.addOperator( op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr ) if a.dtype in (DType.FLOAT, DType.INT32): then_op, else_op = Op.ADD, Op.SUB elif a.dtype in (DType.INT8, DType.INT16): then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT else: assert False, f"No tests for DType: {a.dtype}" for block, op in ((then_block, then_op), (else_block, else_op)): self.ser.startBasicBlock(block) if ( error_name == ErrorIf.CondIfInputListThenGraphMismatch and block == then_block ) or ( error_name == ErrorIf.CondIfInputListElseGraphMismatch and block == else_block ): self.ser.addInputTensor(incorrect_block_input) self.ser.addInputTensor(b) tens = self.ser.addOutput(a.shape, a.dtype) elif ( error_name == ErrorIf.CondIfOutputListThenGraphMismatch and block == then_block ) or ( error_name == ErrorIf.CondIfOutputListElseGraphMismatch and block == else_block ): self.ser.addInputTensor(a) self.ser.addInputTensor(b) tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype) else: self.ser.addInputTensor(a) self.ser.addInputTensor(b) tens = self.ser.addOutput(a.shape, a.dtype) 
self.ser.addOperator(op, [a.name, b.name], [tens.name]) if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, op=op, a=a, b=b, basicBlocks=self.ser.basicBlocks, ): return None return result_tens def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None): iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)]) cond_block = "COND_BLOCK" body_block = "BODY_BLOCK" attr = ts.TosaSerializerAttribute() attr.WhileLoopAttribute(cond_block, body_block) # Accumulator tensor # acc = self.ser.addOutput(a.shape, a.dtype) acc_init_val = np.int32(np.zeros(a.shape)) acc = self.ser.addPlaceholder(a.shape, a.dtype, acc_init_val) # Intermediate/output tensors for everything going through the loop iter_out = self.ser.addIntermediate(iter.shape, iter.dtype) a_out = self.ser.addIntermediate(a.shape, a.dtype) if error_name == ErrorIf.InputListOutputListMismatch: incorrect_acc = deepcopy(acc) for i in range(len(incorrect_acc.shape)): incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3]) acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype) else: acc_out = self.ser.addIntermediate(acc.shape, acc.dtype) # While_loop operator self.ser.addOperator( op["op"], [iter.name, a.name, acc.name], [iter_out.name, a_out.name, acc_out.name], attr, ) self.ser.addOutputTensor(acc_out) if error_name in [ ErrorIf.InputListCondGraphMismatch, ErrorIf.InputListBodyGraphInputMismatch, ErrorIf.InputListBodyGraphOutputMismatch, ]: incorrect_iter = deepcopy(iter) for i in range(len(incorrect_iter.shape)): incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3]) if len(incorrect_iter.shape) == 0: incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3])) incorrect_acc = deepcopy(acc) for i in range(len(incorrect_acc.shape)): incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3]) # COND block (input: iter, output: cond_tens ) self.ser.startBasicBlock(cond_block) if error_name == ErrorIf.InputListCondGraphMismatch: 
self.ser.addInputTensor(incorrect_iter) self.ser.addInputTensor(a) self.ser.addInputTensor(incorrect_acc) else: self.ser.addInputTensor(iter) self.ser.addInputTensor(a) self.ser.addInputTensor(acc) zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)]) if error_name == ErrorIf.CondGraphOutputNotMatchingBool: cond_tens = self.ser.addOutput( [], self.rng.choice([DType.INT8, DType.INT32, DType.FLOAT]) ) else: cond_tens = self.ser.addOutput([], DType.BOOL) self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name]) # BODY block (input: a, acc, iter, output: a, acc, iter) # Note that local intermediate tensors need to be declared here for the outputs self.ser.startBasicBlock(body_block) if error_name == ErrorIf.InputListBodyGraphInputMismatch: self.ser.addInputTensor(incorrect_iter) self.ser.addInputTensor(a) self.ser.addInputTensor(incorrect_acc) else: self.ser.addInputTensor(iter) self.ser.addInputTensor(a) self.ser.addInputTensor(acc) one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)]) if error_name == ErrorIf.InputListBodyGraphOutputMismatch: iter_body_out = self.ser.addIntermediate( incorrect_iter.shape, incorrect_iter.dtype ) acc_body_out = self.ser.addIntermediate( incorrect_acc.shape, incorrect_acc.dtype ) else: iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype) acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype) self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name]) self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name]) self.ser.addOutputTensor(iter_body_out) self.ser.addOutputTensor(a) self.ser.addOutputTensor(acc_body_out) if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, op=op, basicBlocks=self.ser.basicBlocks, ): return None return acc_out def create_filter_lists( self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None ): # Create a default testing rank range, 1-4 inclusive to keep test sizes reasonably small. 
default_test_rank_range = range(1, 5) if not shapeFilter: shapeFilter = [None] # Calculate the filters based on what is requested and what the operator allows rmin, rmax = op["rank"] if rankFilter is not None: cleanRankFilter = [] # Ensure rankFilter values are allowed by operator for rank in rankFilter: if rank >= rmin and rank <= rmax: cleanRankFilter.append(rank) elif rankFilter is None and shapeFilter[0] is None: # Ensure default behaviour is bounded by default range or by operator, # whichever is the smaller range of ranks. opRankRange = range(rmin, rmax + 1) cleanRankFilter = ( opRankRange if len(opRankRange) <= len(default_test_rank_range) else default_test_rank_range ) else: cleanRankFilter = range(rmin, rmax + 1) dtypes = op["types"] if dtypeFilter is not None: cleanDtypeFilter = [] # Create list of operator dtypes filtered by requested dtypes for dtype in dtypes: if dtype in dtypeFilter or ( isinstance(dtype, list) and dtype[0] in dtypeFilter ): cleanDtypeFilter.append(dtype) else: cleanDtypeFilter = dtypes if testType == "positive": filterDict = { "shapeFilter": shapeFilter, "rankFilter": cleanRankFilter, "dtypeFilter": cleanDtypeFilter, } return filterDict elif testType == "negative": if validator is not None: validator_info = validator(check=False, op=op) else: return None error_arguments = validator_info["param_reqs"] # Set parameters as required if error_arguments["rank"] is not None: rankFilter = error_arguments["rank"] else: rankFilter = cleanRankFilter if error_arguments["dtype"] is not None: dtypeFilter = error_arguments["dtype"] else: dtypeFilter = cleanDtypeFilter if error_arguments["shape"] is not None: shapeFilter = error_arguments["shape"] else: shapeFilter = shapeFilter[ :2 ] # Reduce number of shapes to keep test numbers small filterDict = { "shapeFilter": shapeFilter, "rankFilter": rankFilter, "dtypeFilter": dtypeFilter, } return filterDict def genOpTestList( self, opName, shapeFilter=[None], rankFilter=None, dtypeFilter=None, 
testType="positive", ): try: op = self.TOSA_OP_LIST[opName] except KeyError: raise Exception("Cannot find op with name {}".format(opName)) # Initialize a new random number generator self.rng = np.random.default_rng(self.random_seed) build_fcn, tgen_fcn, tvgen_fcn, agen_fcn = op["build_fcn"] # Test list consists of a tuple of: # (opName, testNameStr, dtype, shapeList, argumentsList) testList = [] if testType == "negative" and "error_if_validators" in op: error_if_validators = op["error_if_validators"] else: error_if_validators = [None] for validator in error_if_validators: if validator is not None: error_name = validator(check=False, op=op)["error_name"] else: error_name = None filterDict = self.create_filter_lists( op, shapeFilter, rankFilter, dtypeFilter, testType, validator ) if filterDict is None: return [] cleanRankFilter = filterDict["rankFilter"] cleanDtypeFilter = filterDict["dtypeFilter"] cleanShapeFilter = filterDict["shapeFilter"] # print(f"Error: {error_name}, Filters: S {cleanShapeFilter}, R {cleanRankFilter}, T {cleanDtypeFilter}") for r in cleanRankFilter: for t in cleanDtypeFilter: for shape in cleanShapeFilter: # Filter out by rank if shape is not None and len(shape) != r: continue self.setTargetShape(shape) shapeList = tgen_fcn(self, op, r, error_name) shapeStr = self.shapeStr(shapeList[0]) typeStr = self.typeStr(t) # Argument lists consists of tuples of the (str, []) string representation and the build function argument list argList = [] if agen_fcn: argList = agen_fcn(self, opName, shapeList, t, error_name) else: argList = [("", [])] for argStr, args in argList: if testType == "positive": if argStr: testStr = "{}_{}_{}_{}".format( opName, shapeStr, typeStr, argStr ) else: testStr = "{}_{}_{}".format( opName, shapeStr, typeStr ) elif testType == "negative": if argStr: testStr = "{}_ERRORIF_{}_{}_{}_{}".format( opName, error_name, shapeStr, typeStr, argStr ) else: testStr = "{}_ERRORIF_{}_{}_{}".format( opName, error_name, shapeStr, typeStr ) 
testList.append( (opName, testStr, t, error_name, shapeList, args) ) if testType == "positive": # Remove tests which are expected to fail but don't correlate to a ERROR_IF statement if "invalid_test_validators" in op: invalid_test_validators = op["invalid_test_validators"] clean_testList = [] for test in testList: for validator_fcn in invalid_test_validators: remove_test = False if validator_fcn( opName=test[0], input_dtype=test[2], shapeList=test[4], args=test[5], ): remove_test = True if not remove_test: clean_testList.append(test) testList = clean_testList return testList def serializeTest( self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs ): try: op = self.TOSA_OP_LIST[opName] except KeyError: raise Exception("Cannot find op with name {}".format(opName)) # Create a serializer self.createSerializer(opName, testStr) build_fcn, tgen_fcn, tvgen_fcn, agen_fcn = op["build_fcn"] if "error_if_validators" in op: error_if_validators = op["error_if_validators"] else: error_if_validators = None pCount, cCount = op["operands"] num_operands = pCount + cCount if isinstance(dtype_or_dtypeList, list): dtypeList = dtype_or_dtypeList elif op["op"] == Op.CONCAT: dtypeList = [dtype_or_dtypeList] * len(shapeList) else: dtypeList = [dtype_or_dtypeList] * (num_operands) if op["op"] != Op.CONCAT: assert ( len(shapeList) == num_operands ), "shapeList length {} must match number of operands {}".format( len(shapeList), num_operands ) assert ( len(dtypeList) == num_operands ), "dtypeList length {} must match number of operands {}".format( len(dtypeList), num_operands ) try: qgen = op["qgen"] except KeyError: qgen = None # Build the random tensor operands and the test tens = [] if qgen is not None: qinfo = qgen(self, op, dtype_or_dtypeList, error_name) else: qinfo = None tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, qinfo, error_name) try: if error_if_validators is None: if qinfo is not None: resultName = build_fcn(self, op, *tens, *testArgs, qinfo) 
else: resultName = build_fcn(self, op, *tens, *testArgs) else: if qinfo is not None: resultName = build_fcn( self, op, *tens, *testArgs, validator_fcns=error_if_validators, error_name=error_name, qinfo=qinfo, ) else: resultName = build_fcn( self, op, *tens, *testArgs, validator_fcns=error_if_validators, error_name=error_name, ) except TypeError as e: print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n") raise e if resultName: # The test is valid, serialize it self.serialize("test") else: # The test is not valid print(f"Invalid ERROR_IF test created: {opName} {testStr}") def createDynamicOpLists(self): # Dynamically create op lists for convolutions with a list of kernel sizes KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]] for k in KERNELS_2D: testName = "conv2d_{}x{}".format(k[0], k[1]) self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv2d_TEMPLATE"].copy() self.TOSA_OP_LIST[testName]["filter"] = k self.TOSA_OP_LIST[testName]["template"] = False testName = "depthwise_conv2d_{}x{}".format(k[0], k[1]) self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[ "depthwise_conv2d_TEMPLATE" ].copy() self.TOSA_OP_LIST[testName]["filter"] = k self.TOSA_OP_LIST[testName]["template"] = False testName = "transpose_conv2d_{}x{}".format(k[0], k[1]) self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[ "transpose_conv2d_TEMPLATE" ].copy() self.TOSA_OP_LIST[testName]["filter"] = k self.TOSA_OP_LIST[testName]["template"] = False KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]] for k in KERNELS_3D: testName = "conv3d_{}x{}x{}".format(k[0], k[1], k[2]) self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv3d_TEMPLATE"].copy() self.TOSA_OP_LIST[testName]["filter"] = k self.TOSA_OP_LIST[testName]["template"] = False # Delete any templates after having created any dynamic ops # This is a two-pass operation because it's bad practice to delete # keys from dictionaries while iterating keyList = [] for k in self.TOSA_OP_LIST: try: if 
self.TOSA_OP_LIST[k]["template"]: keyList.append(k) continue except KeyError: pass for k in keyList: del self.TOSA_OP_LIST[k] def initOpListDefaults(self): """Fill in default fields for ops if they aren't already specified. Look for missing required fields (datastructure linting).""" for op in self.TOSA_OP_LIST: # Required fields try: pl, c = self.TOSA_OP_LIST[op]["operands"] except (KeyError, ValueError, TypeError): raise Exception( "Op {} is missing a valid operand tuple in TOSA_OP_LIST".format(op) ) try: fcn, tgen, tvgen, arggen = self.TOSA_OP_LIST[op]["build_fcn"] except (KeyError, ValueError, TypeError): raise Exception( "Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format( op ) ) try: _ = self.TOSA_OP_LIST[op]["types"] except KeyError: raise Exception( "Op {} is missing a valid type list in TOSA_OP_LIST".format(op) ) try: _ = self.TOSA_OP_LIST[op]["op"] except KeyError: raise Exception( "Op {} is missing the Op field in TOSA_OP_LIST".format(op) ) # Put in default rank range, if missing try: _ = self.TOSA_OP_LIST[op]["rank"] except KeyError: self.TOSA_OP_LIST[op]["rank"] = self.DEFAULT_RANK_RANGE # Tensor operator list # 'op': op name # 'operands': tuple of (placeholder, const) operands # 'rank': optional, restricts rank to tuple inclusive of (min, max), # if not specified, defaults to (1, 4) # 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum) # 'types': array of datatypes to be tested TYPE_FP = [DType.FLOAT] TYPE_INT = [DType.INT8, DType.INT16, DType.INT32] # Excludes INT4 TYPE_INT_FP = [DType.INT8, DType.INT16, DType.INT32, DType.FLOAT] # Excludes INT4 TYPE_BOOL = [DType.BOOL] TYPE_FI32 = [DType.FLOAT, DType.INT32] TYPE_FIB = [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL] TYPE_FI16 = [DType.FLOAT, DType.INT16] TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT] TYPE_CONV = [ [DType.INT8, DType.INT4, DType.INT32], [DType.INT8, DType.INT8, DType.INT32], [DType.INT16, DType.INT8, 
DType.INT48], DType.FLOAT, ] DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK) TOSA_OP_LIST = { # Tensor operators "argmax": { "op": Op.ARGMAX, "operands": (1, 0), "rank": (1, 4), "build_fcn": ( build_argmax, TosaTensorGen.tgBasic, TosaTensorValuesGen.tvgDefault, TosaArgGen.agAxis, ), "types": TYPE_NARROW_INT_FP, "error_if_validators": ( TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evArgmaxOutputRankMismatch, TosaErrorValidator.evArgmaxOutputShapeMismatch, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), }, "avg_pool2d": { "op": Op.AVG_POOL2D, "operands": (1, 0), "rank": (4, 4), "build_fcn": ( build_pool2d, TosaTensorGen.tgNHWC, TosaTensorValuesGen.tvgDefault, TosaArgGen.agPooling, ), "qgen": TosaQuantGen.qgUnary, "types": TYPE_NARROW_INT_FP, "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,), "error_if_validators": ( TosaErrorValidator.evKernelSmallerOne, TosaErrorValidator.evStrideSmallerOne, TosaErrorValidator.evPadSmallerZero, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero, TosaErrorValidator.evPadLargerEqualKernel, TosaErrorValidator.evPoolingOutputShapeMismatch, TosaErrorValidator.evPoolingOutputShapeNonInteger, ), }, # Templated operator. 
Filled in by createDynamicOpLists "conv2d_TEMPLATE": { "op": Op.CONV2D, "operands": (1, 2), "rank": (4, 4), "build_fcn": ( build_conv2d, TosaTensorGen.tgConv2D, TosaTensorValuesGen.tvgDefault, TosaArgGen.agConv, ), "qgen": TosaQuantGen.qgConv, "types": TYPE_CONV, "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,), "error_if_validators": ( TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evWeightZeroPointNotZero, TosaErrorValidator.evPadSmallerZero, TosaErrorValidator.evStrideSmallerOne, TosaErrorValidator.evDilationSmallerOne, TosaErrorValidator.evWrongRank, TosaErrorValidator.evConvOutputShapeMismatch, TosaErrorValidator.evConvOutputShapeNonInteger, ), "template": True, }, # Templated operator. Filled in by createDynamicOpLists "conv3d_TEMPLATE": { "op": Op.CONV3D, "operands": (1, 2), "rank": (5, 5), "build_fcn": ( build_conv3d, TosaTensorGen.tgConv3D, TosaTensorValuesGen.tvgDefault, TosaArgGen.agConv, ), "qgen": TosaQuantGen.qgConv, "types": TYPE_CONV, "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,), "error_if_validators": ( TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evWeightZeroPointNotZero, TosaErrorValidator.evPadSmallerZero, TosaErrorValidator.evStrideSmallerOne, TosaErrorValidator.evDilationSmallerOne, TosaErrorValidator.evWrongRank, TosaErrorValidator.evConvOutputShapeMismatch, TosaErrorValidator.evConvOutputShapeNonInteger, ), "template": True, }, # Templated operator. 
Filled in by createDynamicOpLists "depthwise_conv2d_TEMPLATE": { "op": Op.DEPTHWISE_CONV2D, "operands": (1, 2), "filter": [1, 1], "rank": (4, 4), "build_fcn": ( build_depthwise_conv2d, TosaTensorGen.tgDepthwiseConv2D, TosaTensorValuesGen.tvgDefault, TosaArgGen.agConv, ), "qgen": TosaQuantGen.qgConv, "types": TYPE_CONV, "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,), "error_if_validators": ( TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evWeightZeroPointNotZero, TosaErrorValidator.evPadSmallerZero, TosaErrorValidator.evStrideSmallerOne, TosaErrorValidator.evDilationSmallerOne, TosaErrorValidator.evWrongRank, TosaErrorValidator.evConvOutputShapeMismatch, TosaErrorValidator.evConvOutputShapeNonInteger, ), "template": True, }, "fully_connected": { "op": Op.FULLY_CONNECTED, "operands": (1, 2), "rank": (2, 2), "build_fcn": ( build_fully_connected, TosaTensorGen.tgFullyConnected, TosaTensorValuesGen.tvgDefault, None, ), "qgen": TosaQuantGen.qgConv, "types": TYPE_CONV, "error_if_validators": ( TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evWeightZeroPointNotZero, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), }, "matmul": { "op": Op.MATMUL, "operands": (2, 0), "rank": (3, 3), "build_fcn": ( build_matmul, TosaTensorGen.tgMatmul, TosaTensorValuesGen.tvgDefault, None, ), "qgen": TosaQuantGen.qgMatmul, "types": TYPE_NARROW_INT_FP, "error_if_validators": ( TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), }, "max_pool2d": { "op": Op.MAX_POOL2D, "operands": (1, 
0), "rank": (4, 4), "build_fcn": ( build_pool2d, TosaTensorGen.tgNHWC, TosaTensorValuesGen.tvgDefault, TosaArgGen.agPooling, ), "types": TYPE_NARROW_INT_FP, "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,), "error_if_validators": ( TosaErrorValidator.evKernelSmallerOne, TosaErrorValidator.evStrideSmallerOne, TosaErrorValidator.evPadSmallerZero, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evPadLargerEqualKernel, TosaErrorValidator.evPoolingOutputShapeMismatch, TosaErrorValidator.evPoolingOutputShapeNonInteger, ), }, # Templated operator. Filled in by createDynamicOpLists "transpose_conv2d_TEMPLATE": { "op": Op.TRANSPOSE_CONV2D, "operands": (1, 2), "rank": (4, 4), "build_fcn": ( build_transpose_conv2d, TosaTensorGen.tgTransposeConv2D, TosaTensorValuesGen.tvgDefault, TosaArgGen.agTransposeConv2D, ), "qgen": TosaQuantGen.qgConv, "types": TYPE_CONV, "invalid_test_validators": ( TosaInvalidValidator.ivHeightWidthInvalid, TosaInvalidValidator.ivNonPositiveOutputShape, ), "error_if_validators": ( TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evWeightZeroPointNotZero, TosaErrorValidator.evPadSmallerZero, TosaErrorValidator.evStrideSmallerOne, TosaErrorValidator.evWrongRank, TosaErrorValidator.evConvOutputShapeMismatch, ), "template": True, }, # Activation functions "clamp": { "op": Op.CLAMP, "operands": (1, 0), "build_fcn": ( build_clamp, TosaTensorGen.tgBasic, TosaTensorValuesGen.tvgDefault, None, ), "types": TYPE_NARROW_INT_FP, "error_if_validators": ( TosaErrorValidator.evMaxSmallerMin, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, 
TosaErrorValidator.evWrongOutputList, ), }, "sigmoid": { "op": Op.SIGMOID, "operands": (1, 0), "build_fcn": ( build_sigmoid, TosaTensorGen.tgBasic, TosaTensorValuesGen.tvgDefault, None, ), "types": TYPE_FP, "error_if_validators": ( TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), }, "tanh": { "op": Op.TANH, "operands": (1, 0), "build_fcn": ( build_tanh, TosaTensorGen.tgBasic, TosaTensorValuesGen.tvgDefault, None, ), "types": TYPE_FP, "error_if_validators": ( TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), }, # Elementwise Binary Operators "add": { "op": Op.ADD, "operands": (2, 0), "build_fcn": ( build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, TosaTensorValuesGen.tvgAddSub, None, ), "types": TYPE_FI32, "error_if_validators": ( TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, ), }, "arithmetic_right_shift": { "op": Op.ARITHMETIC_RIGHT_SHIFT, "operands": (2, 0), "build_fcn": ( build_arithmetic_right_shift, TosaTensorGen.tgBroadcastFuzz, TosaTensorValuesGen.tvgArithmeticRightShift, TosaArgGen.agArithmeticRightShift, ), "types": TYPE_INT, "error_if_validators": ( TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, ), }, "bitwise_and": { "op": Op.BITWISE_AND, "operands": (2, 0), "build_fcn": ( build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, TosaTensorValuesGen.tvgDefault, None, ), "types": TYPE_INT, "error_if_validators": ( TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, 
TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, ), }, "bitwise_or": { "op": Op.BITWISE_OR, "operands": (2, 0), "build_fcn": ( build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, TosaTensorValuesGen.tvgDefault, None, ), "types": TYPE_INT, "error_if_validators": ( TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, ), }, "bitwise_xor": { "op": Op.BITWISE_XOR, "operands": (2, 0), "build_fcn": ( build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, TosaTensorValuesGen.tvgDefault, None, ), "types": TYPE_INT, "error_if_validators": ( TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, ), }, "intdiv": { "op": Op.INTDIV, "operands": (2, 0), "build_fcn": ( build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, TosaTensorValuesGen.tvgIntDiv, None, ), "types": [DType.INT32], "error_if_validators": ( TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, ), }, "logical_and": { "op": Op.LOGICAL_AND, "operands": (2, 0), "build_fcn": ( build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, TosaTensorValuesGen.tvgDefault, None, ), "types": TYPE_BOOL, "error_if_validators": ( TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, ), }, "logical_left_shift": { "op": Op.LOGICAL_LEFT_SHIFT, "operands": 
(2, 0), "build_fcn": ( build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, TosaTensorValuesGen.tvgLogicalShift, None, ), "types": TYPE_INT, "error_if_validators": ( TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, ), }, "logical_right_shift": { "op": Op.LOGICAL_RIGHT_SHIFT, "operands": (2, 0), "build_fcn": ( build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, TosaTensorValuesGen.tvgLogicalShift, None, ), "types": TYPE_INT, "error_if_validators": ( TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, ), }, "logical_or": { "op": Op.LOGICAL_OR, "operands": (2, 0), "build_fcn": ( build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, TosaTensorValuesGen.tvgDefault, None, ), "types": TYPE_BOOL, "error_if_validators": ( TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, ), }, "logical_xor": { "op": Op.LOGICAL_XOR, "operands": (2, 0), "build_fcn": ( build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, TosaTensorValuesGen.tvgDefault, None, ), "types": TYPE_BOOL, "error_if_validators": ( TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, ), }, "maximum": { "op": Op.MAXIMUM, "operands": (2, 0), "build_fcn": ( build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, TosaTensorValuesGen.tvgDefault, None, ), "types": TYPE_FI32, "error_if_validators": ( TosaErrorValidator.evRankMismatch, 
TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, ), }, "minimum": { "op": Op.MINIMUM, "operands": (2, 0), "build_fcn": ( build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, TosaTensorValuesGen.tvgDefault, None, ), "types": TYPE_FI32, "error_if_validators": ( TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, ), }, "mul": { "op": Op.MUL, "operands": (2, 0), "build_fcn": ( build_mul, TosaTensorGen.tgBroadcastFuzz, TosaTensorValuesGen.tvgMul, TosaArgGen.agMul, ), "types": TYPE_INT_FP, "error_if_validators": ( TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evRankMismatch, TosaErrorValidator.evDimensionMismatch, ), }, "pow": { "op": Op.POW, "operands": (2, 0), "build_fcn": ( build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, TosaTensorValuesGen.tvgDefault, None, ), "types": TYPE_FP, "error_if_validators": ( TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, ), }, "sub": { "op": Op.SUB, "operands": (2, 0), "build_fcn": ( build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, TosaTensorValuesGen.tvgAddSub, None, ), "types": TYPE_FI32, "error_if_validators": ( TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, ), }, "table": { "op": Op.TABLE, # Use the automatic generation functions to create the 
input array # but create the table tensor in the build function, as it may be # a different type from the input "operands": (1, 0), "build_fcn": ( build_table, TosaTensorGen.tgBasic, TosaTensorValuesGen.tvgDefault, TosaArgGen.agTable, ), "types": [DType.INT8, DType.INT16], "error_if_validators": ( TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), }, # Elementwise Unary operators "abs": { "op": Op.ABS, "operands": (1, 0), "build_fcn": ( build_unary, TosaTensorGen.tgBasic, TosaTensorValuesGen.tvgDefault, None, ), "types": TYPE_FI32, "error_if_validators": ( TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), }, "bitwise_not": { "op": Op.BITWISE_NOT, "operands": (1, 0), "build_fcn": ( build_unary, TosaTensorGen.tgBasic, TosaTensorValuesGen.tvgDefault, None, ), "types": TYPE_INT, "error_if_validators": ( TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), }, "ceil": { "op": Op.CEIL, "operands": (1, 0), "build_fcn": ( build_unary, TosaTensorGen.tgBasic, TosaTensorValuesGen.tvgDefault, None, ), "types": TYPE_FP, "error_if_validators": ( TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), }, "clz": { "op": Op.CLZ, "operands": (1, 0), "build_fcn": ( build_unary, TosaTensorGen.tgBasic, TosaTensorValuesGen.tvgDefault, None, ), "types": [DType.INT32], "error_if_validators": ( TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), }, "exp": { "op": Op.EXP, "operands": (1, 0), "build_fcn": ( build_unary, TosaTensorGen.tgBasic, TosaTensorValuesGen.tvgDefault, None, ), "types": 
TYPE_FP, "error_if_validators": ( TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), }, "floor": { "op": Op.FLOOR, "operands": (1, 0), "build_fcn": ( build_unary, TosaTensorGen.tgBasic, TosaTensorValuesGen.tvgDefault, None, ), "types": TYPE_FP, "error_if_validators": ( TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), }, "log": { "op": Op.LOG, "operands": (1, 0), "build_fcn": ( build_unary, TosaTensorGen.tgBasic, TosaTensorValuesGen.tvgDefault, None, ), "types": TYPE_FP, "error_if_validators": ( TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), }, "logical_not": { "op": Op.LOGICAL_NOT, "operands": (1, 0), "build_fcn": ( build_unary, TosaTensorGen.tgBasic, TosaTensorValuesGen.tvgDefault, None, ), "types": TYPE_BOOL, "error_if_validators": ( TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), }, "negate": { "op": Op.NEGATE, "operands": (1, 0), "build_fcn": ( build_unary, TosaTensorGen.tgBasic, TosaTensorValuesGen.tvgNegate, None, ), "qgen": TosaQuantGen.qgUnary, "types": TYPE_INT_FP, "error_if_validators": ( TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), }, "reciprocal": { "op": Op.RECIPROCAL, "operands": (1, 0), "build_fcn": ( build_unary, TosaTensorGen.tgBasic, TosaTensorValuesGen.tvgDefault, None, ), "types": TYPE_FP, "error_if_validators": ( TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, 
TosaErrorValidator.evWrongOutputList, ), }, "rsqrt": { "op": Op.RSQRT, "operands": (1, 0), "build_fcn": ( build_unary, TosaTensorGen.tgBasic, TosaTensorValuesGen.tvgDefault, None, ), "types": TYPE_FP, "error_if_validators": ( TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), }, # Elementwise Ternary operators "select": { "op": Op.SELECT, "operands": (3, 0), "build_fcn": ( build_select, TosaTensorGen.tgBroadcastFuzz, TosaTensorValuesGen.tvgSelect, None, ), "types": TYPE_FIB, "error_if_validators": ( TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, ), }, # Comparison operators "equal": { "op": Op.EQUAL, "operands": (2, 0), "build_fcn": ( build_comparison, TosaTensorGen.tgBroadcastFuzz, TosaTensorValuesGen.tvgEqual, None, ), "types": TYPE_FI32, "error_if_validators": ( TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, ), }, "greater_equal": { "op": Op.GREATER_EQUAL, "operands": (2, 0), "build_fcn": ( build_comparison, TosaTensorGen.tgBroadcastFuzz, TosaTensorValuesGen.tvgDefault, None, ), "types": TYPE_FI32, "error_if_validators": ( TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, ), }, "greater": { "op": Op.GREATER, "operands": (2, 0), "build_fcn": ( build_comparison, TosaTensorGen.tgBroadcastFuzz, TosaTensorValuesGen.tvgDefault, None, ), "types": TYPE_FI32, "error_if_validators": ( TosaErrorValidator.evRankMismatch, 
TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, ), }, # Reduction operators "reduce_all": { "op": Op.REDUCE_ALL, "operands": (1, 0), "rank": (1, 4), "build_fcn": ( build_reduce, TosaTensorGen.tgBasic, TosaTensorValuesGen.tvgDefault, TosaArgGen.agAxis, ), "types": TYPE_BOOL, "error_if_validators": ( TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), }, "reduce_any": { "op": Op.REDUCE_ANY, "operands": (1, 0), "rank": (1, 4), "build_fcn": ( build_reduce, TosaTensorGen.tgBasic, TosaTensorValuesGen.tvgDefault, TosaArgGen.agAxis, ), "types": TYPE_BOOL, "error_if_validators": ( TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), }, "reduce_max": { "op": Op.REDUCE_MAX, "operands": (1, 0), "rank": (1, 4), "build_fcn": ( build_reduce, TosaTensorGen.tgBasic, TosaTensorValuesGen.tvgDefault, TosaArgGen.agAxis, ), "types": TYPE_INT_FP, "error_if_validators": ( TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), }, "reduce_min": { "op": Op.REDUCE_MIN, "operands": (1, 0), "rank": (1, 4), "build_fcn": ( build_reduce, TosaTensorGen.tgBasic, TosaTensorValuesGen.tvgDefault, TosaArgGen.agAxis, ), "types": 
TYPE_INT_FP, "error_if_validators": ( TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), }, "reduce_product": { "op": Op.REDUCE_PRODUCT, "operands": (1, 0), "rank": (1, 4), "build_fcn": ( build_reduce, TosaTensorGen.tgBasic, TosaTensorValuesGen.tvgDefault, TosaArgGen.agAxis, ), "types": TYPE_FP, "error_if_validators": ( TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), }, "reduce_sum": { "op": Op.REDUCE_SUM, "operands": (1, 0), "rank": (1, 4), "build_fcn": ( build_reduce, TosaTensorGen.tgBasic, TosaTensorValuesGen.tvgReduceSum, TosaArgGen.agAxis, ), "types": TYPE_FI32, "error_if_validators": ( TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), }, # Data layout operators "concat": { "op": Op.CONCAT, "operands": (2, 0), "build_fcn": ( build_concat, TosaTensorGen.tgConcat, TosaTensorValuesGen.tvgConcat, TosaArgGen.agAxis, ), "types": TYPE_FIB, "error_if_validators": ( TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evConcatInputRankMismatch, TosaErrorValidator.evConcatShapeSumMismatch, TosaErrorValidator.evConcatInputDimMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongOutputList, ), }, "pad": { "op": Op.PAD, "operands": (1, 0), 
"rank": (1, 5), "build_fcn": ( build_pad, TosaTensorGen.tgBasic, TosaTensorValuesGen.tvgDefault, TosaArgGen.agPad, ), "qgen": TosaQuantGen.qgPad, "types": TYPE_FIB, "error_if_validators": ( TosaErrorValidator.evWrongInputType, TosaErrorValidator.evPadSmallerZero, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), }, "reshape": { "op": Op.RESHAPE, "operands": (1, 0), "build_fcn": ( build_reshape, TosaTensorGen.tgBasic, TosaTensorValuesGen.tvgDefault, TosaArgGen.agReshape, ), "types": TYPE_FIB, "error_if_validators": ( TosaErrorValidator.evTensorSizeInputOutputMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), }, "reverse": { "op": Op.REVERSE, "operands": (1, 0), "build_fcn": ( build_reverse, TosaTensorGen.tgBasic, TosaTensorValuesGen.tvgDefault, TosaArgGen.agAxis, ), "types": TYPE_FIB, "error_if_validators": ( TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), }, "slice": { "op": Op.SLICE, "operands": (1, 0), "rank": (1, 4), "build_fcn": ( build_slice, TosaTensorGen.tgBasic, TosaTensorValuesGen.tvgDefault, TosaArgGen.agSlice, ), "types": TYPE_FIB, "error_if_validators": ( TosaErrorValidator.evStartSmallerZero, TosaErrorValidator.evSizeSmallerEqualZero, TosaErrorValidator.evStartSizeOutsideBounds, TosaErrorValidator.evSizeOutputShapeMismatch, TosaErrorValidator.evInputSizeStartLengthMismatch, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), }, "tile": { "op": Op.TILE, "operands": (1, 0), "build_fcn": ( build_tile, TosaTensorGen.tgBasic, TosaTensorValuesGen.tvgDefault, TosaArgGen.agTile, 
), "types": TYPE_FIB, "error_if_validators": ( TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), }, "transpose": { "op": Op.TRANSPOSE, "operands": (1, 0), "rank": (1, 4), "build_fcn": ( build_transpose, TosaTensorGen.tgBasic, TosaTensorValuesGen.tvgDefault, TosaArgGen.agTranspose, ), "types": TYPE_FIB, "error_if_validators": ( TosaErrorValidator.evIndexOutsideBounds, TosaErrorValidator.evIndexUsedTwice, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), }, # Data nodes "const": { "op": Op.CONST, "operands": (0, 1), "build_fcn": ( build_const, TosaTensorGen.tgBasic, TosaTensorValuesGen.tvgDefault, None, ), "types": TYPE_FIB, }, "identity": { "op": Op.IDENTITY, "operands": (1, 0), "build_fcn": ( build_unary, TosaTensorGen.tgBasic, TosaTensorValuesGen.tvgDefault, None, ), "types": TYPE_FIB, }, # Scatter/Gather "gather": { "op": Op.GATHER, # Only specify 'values' tensor here. 'indices' is generated in op building stage "operands": (1, 0), "rank": (3, 3), "build_fcn": ( build_gather, TosaTensorGen.tgBasic, TosaTensorValuesGen.tvgDefault, None, ), "types": TYPE_INT_FP, "error_if_validators": ( TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evWrongRank, ), }, "scatter": { "op": Op.SCATTER, # Only specify 'values_in' tensor here. 
# 'indices' and 'input' are generated in op building stage "operands": (2, 0), "rank": (3, 3), "build_fcn": ( build_scatter, TosaTensorGen.tgScatter, TosaTensorValuesGen.tvgDefault, None, ), "types": TYPE_INT_FP, "error_if_validators": ( TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evWrongRank, ), }, # Image operations "resize": { "op": Op.RESIZE, "operands": (1, 0), "rank": (4, 4), "build_fcn": ( build_resize, TosaTensorGen.tgNHWC, TosaTensorValuesGen.tvgDefault, TosaArgGen.agResize, ), "types": [DType.INT8, DType.INT16, DType.FLOAT], "invalid_test_validators": ( TosaInvalidValidator.ivWrongDataTypeOrModeResize, TosaInvalidValidator.ivBadStride, ), "error_if_validators": ( TosaErrorValidator.evMaxDimExceeded, TosaErrorValidator.evStrideSmallerEqualZero, TosaErrorValidator.evStrideLargerDimension, TosaErrorValidator.evStrideLargerEqualMax, TosaErrorValidator.evOffsetSmallerEqualMin, TosaErrorValidator.evOffsetLargerEqualMax, TosaErrorValidator.evShiftNotZero, TosaErrorValidator.evShiftSmallerOne, TosaErrorValidator.evShiftLargerEleven, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evBatchMismatch, TosaErrorValidator.evChannelMismatch, ), }, # Type conversion "cast": { "op": Op.CAST, "operands": (1, 0), "build_fcn": ( build_cast, TosaTensorGen.tgBasic, TosaTensorValuesGen.tvgDefault, TosaArgGen.agCast, ), "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL], "error_if_validators": ( TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), }, "rescale": { "op": Op.RESCALE, "operands": (1, 0), "rank": (1, 4), "build_fcn": ( build_rescale, TosaTensorGen.tgBasic, 
TosaTensorValuesGen.tvgDefault, TosaArgGen.agRescale, ), "types": [ DType.UINT8, DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.UINT16, ], "error_if_validators": ( TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero, TosaErrorValidator.evU16InputZeroPointNotValid, TosaErrorValidator.evU16OutputZeroPointNotValid, TosaErrorValidator.evScaleTrue, TosaErrorValidator.evScaleNotTrue, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), }, # Custom # Not implemented. # Control flow operators # Two varients of cond_if, one that generates one of two constant tensors (no # inputs to the basic blocks, one output) and another that either adds or subtracts two tensors # (two inputs to the basic blocks, one output) "cond_if_const": { "op": Op.COND_IF, "operands": (0, 2), "build_fcn": ( build_cond_if_const, TosaTensorGen.tgBasic, TosaTensorValuesGen.tvgCondIfWhileLoop, TosaArgGen.agCondIf, ), "types": [DType.BOOL], "error_if_validators": ( TosaErrorValidator.evOutputListThenGraphMismatch, TosaErrorValidator.evOutputListElseGraphMismatch, ), }, "cond_if_binary": { "op": Op.COND_IF, "operands": (2, 0), "build_fcn": ( build_cond_if_binary, TosaTensorGen.tgBasic, TosaTensorValuesGen.tvgCondIfWhileLoop, TosaArgGen.agCondIf, ), "types": TYPE_INT_FP, "error_if_validators": ( TosaErrorValidator.evInputListThenGraphMismatch, TosaErrorValidator.evInputListElseGraphMismatch, TosaErrorValidator.evOutputListThenGraphMismatch, TosaErrorValidator.evOutputListElseGraphMismatch, ), }, # while_loop "while_loop": { "op": Op.WHILE_LOOP, "operands": (0, 1), "build_fcn": ( build_while_loop, TosaTensorGen.tgBasic, TosaTensorValuesGen.tvgCondIfWhileLoop, TosaArgGen.agWhileLoop, ), "types": [DType.INT32], "error_if_validators": ( TosaErrorValidator.evInputListOutputListMismatch, TosaErrorValidator.evInputListCondGraphMismatch, 
TosaErrorValidator.evInputListBodyGraphInputMismatch, TosaErrorValidator.evInputListBodyGraphOutputMismatch, TosaErrorValidator.evCondGraphOutputNotMatchingBool, ), }, }


class OutputShaper:
    """Compute the expected output shape and dtype for common classes of operations.

    When ``error_name`` names an ERROR_IF condition, several of these methods
    deliberately produce an incorrect output shape or dtype so that the
    generated test exercises that condition.
    """

    # Methods in this class compute the expected output shape and datatype
    # for common classes of operations
    def __init__(self):
        # No state to initialise.
        pass

    # These methods return arguments that can be used for
    # creating a new output tensor
    @staticmethod
    def binaryBroadcastOp(ser, rng, a, b, error_name=None):
        """Create the output tensor of an elementwise binary op with broadcasting.

        Shape rule:
        - When ``error_name`` is None, each output dimension is ``b``'s size
          wherever ``a`` has size 1, and ``a``'s size otherwise.
        - In any ERROR_IF case the output shape is a copy of ``a.shape``.

        Args:
            ser: serializer; ``ser.addOutput(shape, dtype)`` creates the tensor.
            rng: random generator, presumably a numpy ``Generator`` (only
                ``choice`` is called here); used only for
                ErrorIf.WrongOutputType.
            a, b: input tensors. They must share a dtype, and must share a rank
                unless ``error_name`` is ErrorIf.RankMismatch.
            error_name: ErrorIf condition to provoke, or None.

        Returns:
            The result of ``ser.addOutput``.
        """
        if error_name != ErrorIf.RankMismatch:
            assert len(a.shape) == len(b.shape)
        assert a.dtype == b.dtype

        shape = []
        for i in range(len(a.shape)):
            # Only broadcast a's size-1 dims for valid tests.
            if a.shape[i] == 1 and error_name is None:
                shape.append(b.shape[i])
            else:
                shape.append(a.shape[i])

        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FLOAT,
            ]
            # Any candidate except the correct dtype.
            # NOTE(review): candidate order comes from set iteration, and it
            # decides which dtype a given seed selects.
            wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
            outputDType = rng.choice(wrong_dtypes)
        else:
            outputDType = a.dtype

        return ser.addOutput(shape, outputDType)

    @staticmethod
    def binaryNonBroadcastOp(ser, a, b):
        """Create the output tensor of a binary op whose inputs must match exactly.

        Asserts that ``a`` and ``b`` have the same rank, the same dtype and the
        same size in every dimension. The output takes ``a``'s shape and dtype.

        Returns:
            The result of ``ser.addOutput``.
        """
        assert len(a.shape) == len(b.shape)
        assert a.dtype == b.dtype

        shape = []
        for i in range(len(a.shape)):
            assert a.shape[i] == b.shape[i]
            shape.append(a.shape[i])

        return ser.addOutput(shape, a.dtype)

    @staticmethod
    def unaryOp(ser, rng, a, error_name=None):
        """Create the output tensor of an elementwise unary op.

        The output shape is always ``a.shape``. The dtype is ``a.dtype``,
        except for ErrorIf.WrongOutputType: then it is drawn at random from
        INT8/INT16/INT32/INT48/FLOAT, excluding ``a.dtype``.

        Returns:
            The result of ``ser.addOutput``.
        """
        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FLOAT,
            ]
            wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
            outputDType = rng.choice(wrong_dtypes)
        else:
            outputDType = a.dtype

        return ser.addOutput(a.shape, outputDType)

    @staticmethod
    def selectOp(ser, rng, cond, a, b, error_name=None):
        """Create the output tensor of SELECT(cond, a, b).

        The output has ``cond``'s rank. Each dimension is chosen as follows:
        - where ``cond`` has size 1 and ``error_name`` is None, take the
          largest of the three inputs' sizes;
        - otherwise take ``cond``'s size.

        The dtype is ``a.dtype``; for ErrorIf.WrongOutputType a different
        dtype is drawn at random instead. The rank assertion is skipped for
        ErrorIf.RankMismatch.

        Returns:
            The result of ``ser.addOutput``.
        """
        if error_name != ErrorIf.RankMismatch:
            assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
        assert a.dtype == b.dtype

        shape = []
        for i in range(len(cond.shape)):
            if cond.shape[i] == 1 and error_name is None:
                shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
            else:
                shape.append(cond.shape[i])

        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FLOAT,
            ]
            wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
            outputDType = rng.choice(wrong_dtypes)
        else:
            outputDType = a.dtype

        return ser.addOutput(shape, outputDType)

    @staticmethod
    def binaryComparisonOp(ser, rng, a, b, error_name=None):
        """Create the BOOL output tensor of an elementwise comparison op.

        Shape rule: wherever ``a`` has size 1 and ``b`` has that dimension,
        the output takes ``b``'s size; otherwise it takes ``a``'s size.
        Unlike ``binaryBroadcastOp``, this broadcast is applied regardless of
        ``error_name``. The ``len(b.shape) > i`` guard keeps it in range when
        the ranks differ.

        For ErrorIf.WrongOutputType the dtype is drawn at random from
        INT8/INT16/INT32/INT48/FLOAT, none of which is BOOL.

        Returns:
            The result of ``ser.addOutput``.
        """
        if error_name != ErrorIf.RankMismatch:
            assert len(a.shape) == len(b.shape)
        assert a.dtype == b.dtype

        # Do broadcast
        shape = []
        for i in range(len(a.shape)):
            if a.shape[i] == 1 and len(b.shape) > i:
                shape.append(b.shape[i])
            else:
                shape.append(a.shape[i])

        if error_name == ErrorIf.WrongOutputType:
            wrong_dtypes = [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FLOAT,
            ]
            outputDType = rng.choice(wrong_dtypes)
        else:
            outputDType = DType.BOOL

        return ser.addOutput(shape, outputDType)

    @staticmethod
    def reduceOp(ser, rng, a, axis, error_name=None):
        """Create the output tensor of a reduction along ``axis``.

        Normally the output is ``a.shape`` with ``shape[axis]`` set to 1, and
        the dtype is ``a.dtype``.

        ERROR_IF handling:
        - AxisSmallerZero, AxisLargerRank: the shape is left unchanged.
        - ShapeOfAxisNotOne: the shape is left unchanged, except that if the
          input's axis size is already 1 it is replaced by
          ``rng.integers(2, 10)``, so it is never 1.
        - WrongOutputType: the dtype is drawn at random, excluding
          ``a.dtype``.

        NOTE(review): ``a.shape`` must support ``copy()`` and item assignment
        (e.g. a list); confirm against the serializer's tensor type.

        Returns:
            The result of ``ser.addOutput``.
        """
        shape = a.shape.copy()
        if error_name not in [
            ErrorIf.AxisSmallerZero,
            ErrorIf.AxisLargerRank,
            ErrorIf.ShapeOfAxisNotOne,
        ]:
            shape[axis] = 1
        if error_name == ErrorIf.ShapeOfAxisNotOne and shape[axis] == 1:
            shape[axis] = rng.integers(2, 10)

        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FLOAT,
            ]
            wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
            outputDType = rng.choice(wrong_dtypes)
        else:
            outputDType = a.dtype

        return ser.addOutput(shape, outputDType)

    @staticmethod
    def argmaxOp(ser, rng, a, axis, error_name=None):
        """Create the INT32 output tensor of ARGMAX along ``axis``.

        Normally the output is ``a.shape`` with ``axis`` removed. The axis is
        kept for AxisSmallerZero and AxisLargerRank.

        ERROR_IF handling:
        - ArgmaxOutputRankMismatch: randomly either drop the first remaining
          dimension (only if more than one remains) or append a size-1
          dimension.
        - ArgmaxOutputShapeMismatch: add ``rng.integers(1, 10)`` to every
          remaining dimension.
        - WrongOutputType: draw the dtype at random, excluding INT32.

        Returns:
            The result of ``ser.addOutput``.
        """
        shape = a.shape.copy()

        if error_name not in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank]:
            del shape[axis]

        if error_name == ErrorIf.ArgmaxOutputRankMismatch:
            remove = rng.choice([True, False])
            if remove and len(shape) > 1:
                del shape[0]
            else:
                shape.append(1)
        elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
            for i in range(len(shape)):
                shape[i] = shape[i] + rng.integers(1, 10)

        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FLOAT,
            ]
            wrong_dtypes = list(set(all_dtypes) - set([DType.INT32]))
            outputDType = rng.choice(wrong_dtypes)
        else:
            outputDType = DType.INT32

        return ser.addOutput(shape, outputDType)

    @staticmethod
    def conv2dOp(ser, rng, ifm, filter, strides, padding, dilations, error_name=None):
        """Create the output tensor of CONV2D.

        Each spatial size is
        ``(in - 1 + pad_a + pad_b - (k - 1) * dilation) // stride + 1``.
        H uses ``padding[0:2]``, ``filter.shape[1]``, ``dilations[0]`` and
        ``strides[0]``; W uses ``padding[2:4]``, ``filter.shape[2]``,
        ``dilations[1]`` and ``strides[1]``. The output channel count is
        ``filter.shape[0]``.

        Output dtype: INT8 -> INT32, INT16 -> INT48, FLOAT -> FLOAT. Any other
        input dtype gives INT32 under ErrorIf.WrongInputType and raises
        ``Exception`` otherwise.

        ERROR_IF handling:
        - ConvOutputShapeMismatch: grow H and/or W by a random multiple (1-3)
          of the matching stride.
        - WrongOutputType: replace the dtype with a random choice from
          ``usableDTypes`` excluding the correct one.

        Returns:
            The result of ``ser.addOutput``.
        """

        # IFM:    NHWC
        # Filter: OHWI
        # OFM:    NHWC

        h = (
            ifm.shape[1]
            - 1
            + padding[0]
            + padding[1]
            - (filter.shape[1] - 1) * dilations[0]
        ) // strides[0] + 1

        w = (
            ifm.shape[2]
            - 1
            + padding[2]
            + padding[3]
            - (filter.shape[2] - 1) * dilations[1]
        ) // strides[1] + 1

        if error_name == ErrorIf.ConvOutputShapeMismatch:
            choices = [1, 2, 3]
            # change: 1 -> grow H only, 2 -> W only, 3 -> both.
            change = rng.choice(choices)
            # increment in multiples of stride to not hit non-integer error case
            if change in [1, 3]:
                h = h + (rng.choice(choices) * strides[0])
            if change in [2, 3]:
                w = w + (rng.choice(choices) * strides[1])

        ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]

        if ifm.dtype == DType.INT8:
            out_dtype = DType.INT32
        elif ifm.dtype == DType.INT16:
            out_dtype = DType.INT48
        elif ifm.dtype == DType.FLOAT:
            out_dtype = DType.FLOAT
        elif error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output dtype if input type is incorrect
            out_dtype = DType.INT32
        else:
            raise Exception(f"Unsupported input dtype: {ifm.dtype}")

        if error_name == ErrorIf.WrongOutputType:
            wrong_dtypes = list(usableDTypes(excludes=[out_dtype]))
            out_dtype = rng.choice(wrong_dtypes)

        return ser.addOutput(ofm_shape, out_dtype)

    @staticmethod
    def conv3dOp(ser, rng, ifm, filter, strides, padding, dilations, error_name=None):
        """Create the output tensor of CONV3D.

        Uses the same size formula as ``conv2dOp``, applied to three spatial
        axes:
        - D: ``padding[0:2]``, ``filter.shape[1]``, ``dilations[0]``,
          ``strides[0]``;
        - H: ``padding[2:4]``, ``filter.shape[2]``, ``dilations[1]``,
          ``strides[1]``;
        - W: ``padding[4:6]``, ``filter.shape[3]``, ``dilations[2]``,
          ``strides[2]``.

        The output channel count is ``filter.shape[0]``. The dtype mapping
        and ERROR_IF handling match ``conv2dOp``, except that for
        ConvOutputShapeMismatch the choices are 1-4: 4 grows all three
        spatial dims.

        Returns:
            The result of ``ser.addOutput``.
        """

        # IFM:    NDHWC
        # Filter: ODHWI
        # OFM:    NDHWC

        d = (
            ifm.shape[1]
            - 1
            + padding[0]
            + padding[1]
            - (filter.shape[1] - 1) * dilations[0]
        ) // strides[0] + 1

        h = (
            ifm.shape[2]
            - 1
            + padding[2]
            + padding[3]
            - (filter.shape[2] - 1) * dilations[1]
        ) // strides[1] + 1

        w = (
            ifm.shape[3]
            - 1
            + padding[4]
            + padding[5]
            - (filter.shape[3] - 1) * dilations[2]
        ) // strides[2] + 1

        if error_name == ErrorIf.ConvOutputShapeMismatch:
            choices = [1, 2, 3, 4]
            # change: 1 -> D, 2 -> H, 3 -> W, 4 -> all three.
            change = rng.choice(choices)
            # increment in multiples of stride to not hit non-integer error case
            if change in [1, 4]:
                d = d + (rng.choice(choices) * strides[0])
            if change in [2, 4]:
                h = h + (rng.choice(choices) * strides[1])
            if change in [3, 4]:
                w = w + (rng.choice(choices) * strides[2])

        ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]

        if ifm.dtype == DType.INT8:
            out_dtype = DType.INT32
        elif ifm.dtype == DType.INT16:
            out_dtype = DType.INT48
        elif ifm.dtype == DType.FLOAT:
            out_dtype = DType.FLOAT
        elif error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output dtype if input type is incorrect
            out_dtype = DType.INT32
        else:
            raise Exception(f"Unsupported input dtype: {ifm.dtype}")

        if error_name == ErrorIf.WrongOutputType:
            wrong_dtypes = list(usableDTypes(excludes=[out_dtype]))
            out_dtype = rng.choice(wrong_dtypes)

        return ser.addOutput(ofm_shape, out_dtype)

    @staticmethod
    def depthwiseConv2dOp(
        ser, rng, ifm, filter, strides, padding, dilations, error_name=None
    ):
        """Create the output tensor of DEPTHWISE_CONV2D.

        The filter layout is HWCM, so the kernel height and width are
        ``filter.shape[0]`` and ``filter.shape[1]``. The output channel count
        is ``filter.shape[2] * filter.shape[3]`` (C * M).

        Spatial sizes use the same formula as ``conv2dOp``. The dtype mapping
        and ERROR_IF handling are identical to ``conv2dOp``.

        Returns:
            The result of ``ser.addOutput``.
        """
        # IFM:    NHWC
        # Filter: HWCM
        # OFM:    NHW C*M

        h = (
            ifm.shape[1]
            - 1
            + padding[0]
            + padding[1]
            - (filter.shape[0] - 1) * dilations[0]
        ) // strides[0] + 1

        w = (
            ifm.shape[2]
            - 1
            + padding[2]
            + padding[3]
            - (filter.shape[1] - 1) * dilations[1]
        ) // strides[1] + 1

        if error_name == ErrorIf.ConvOutputShapeMismatch:
            choices = [1, 2, 3]
            change = rng.choice(choices)
            # increment in multiples of stride to not hit non-integer error case
            if change in [1, 3]:
                h = h + (rng.choice(choices) * strides[0])
            if change in [2, 3]:
                w = w + (rng.choice(choices) * strides[1])

        ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]

        if ifm.dtype == DType.INT8:
            out_dtype = DType.INT32
        elif ifm.dtype == DType.INT16:
            out_dtype = DType.INT48
        elif ifm.dtype == DType.FLOAT:
            out_dtype = DType.FLOAT
        elif error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output dtype if input type is incorrect
            out_dtype = DType.INT32
        else:
            raise Exception(f"Unsupported input dtype: {ifm.dtype}")

        if error_name == ErrorIf.WrongOutputType:
            wrong_dtypes = list(usableDTypes(excludes=[out_dtype]))
            out_dtype = rng.choice(wrong_dtypes)

        return ser.addOutput(ofm_shape, out_dtype)

    @staticmethod
    def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
        """Create the output tensor of a 2D pooling op.

        Spatial sizes are ``(in + pad_a + pad_b - k) // stride + 1``:
        - H uses ``pad[0:2]``, ``kernel[0]`` and ``stride[0]``;
        - W uses ``pad[2:4]``, ``kernel[1]`` and ``stride[1]``.

        If either stride is <= 0 or any pad value is negative, H and W are
        both set to 1. The channel count and dtype follow ``ifm``.

        ERROR_IF handling:
        - PoolingOutputShapeMismatch: grow H and/or W by a random multiple
          (1-3) of the matching stride.
        - WrongOutputType: draw the dtype at random, excluding ``ifm.dtype``.

        Returns:
            The result of ``ser.addOutput``.
        """
        # input: NHWC
        if stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0:
            # If an incorrect stride or a negative pad is used, set dimensions
            # to 1; the test is invalid anyway.
            h = 1
            w = 1
        else:
            h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
            w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1

        if error_name == ErrorIf.PoolingOutputShapeMismatch:
            choices = [1, 2, 3]
            change = rng.choice(choices)
            # increment in multiples of stride to not hit non-integer error case
            if change in [1, 3]:
                h = h + (rng.choice(choices) * stride[0])
            if change in [2, 3]:
                w = w + (rng.choice(choices) * stride[1])
        ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]

        if error_name == ErrorIf.WrongOutputType:
            all_dtypes = [
                DType.INT8,
                DType.INT16,
                DType.INT32,
                DType.INT48,
                DType.FLOAT,
            ]
            wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype]))
            outputDType = rng.choice(wrong_dtypes)
        else:
            outputDType = ifm.dtype

        return ser.addOutput(ofm_shape, outputDType)

    @staticmethod
    def fullyConnectedOp(ser, rng, input, filter, error_name=None):
        # input: N, IC
        # filter: OC, IC
        # output: N, OC

        output_shape = [input.shape[0], filter.shape[0]]

        if error_name == ErrorIf.WrongOutputType:
            if input.dtype == DType.INT8:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT48,
                    DType.FLOAT,
                )
            elif input.dtype == DType.INT16:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.FLOAT,
                )
            elif input.dtype == DType.FLOAT:
                incorrect_types = (
                    DType.INT4,
                    DType.INT8,
                    DType.INT16,
                    DType.INT32,
                    DType.INT48,
                )
            out_dtype = rng.choice(a=incorrect_types)
        elif input.dtype == DType.INT8:
            out_dtype = DType.INT32
        elif input.dtype == DType.INT16:
            out_dtype = DType.INT48
        elif input.dtype == DType.FLOAT:
            out_dtype = DType.FLOAT
        elif error_name == ErrorIf.WrongInputType:
            # Pick some potentially correct output
dtype if input type is incorrect out_dtype = DType.INT32 else: raise Exception("Unsupported input dtype: {}".format(input.dtype)) return ser.addOutput(output_shape, out_dtype) @staticmethod def matmulOp(ser, rng, a, b, error_name=None): # a: N, H, C # b: N, C, W # out: N, H, W output_shape = [a.shape[0], a.shape[1], b.shape[2]] if error_name == ErrorIf.WrongOutputType: if a.dtype == DType.INT8: incorrect_types = ( DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT, ) elif a.dtype == DType.INT16: incorrect_types = ( DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT, ) elif a.dtype == DType.FLOAT: incorrect_types = ( DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48, ) out_dtype = rng.choice(a=incorrect_types) elif a.dtype == DType.INT8: out_dtype = DType.INT32 elif a.dtype == DType.INT16: out_dtype = DType.INT48 elif a.dtype == DType.FLOAT: out_dtype = DType.FLOAT elif error_name == ErrorIf.WrongInputType: # Pick some potentially correct output dtype if input type is incorrect out_dtype = DType.INT32 else: raise Exception("Unsupported input dtype for matmul: {}".format(a.dtype)) return ser.addOutput(output_shape, out_dtype) @staticmethod def concatOp(ser, rng, axis, *a, error_name=None): input1 = a[0] remaining_inputs = a[1:] # calculate the output shape, if possible, otherwise just use the first input shape output_shape = input1.shape.copy() if not ( # unable to concat tensors of different ranks error_name == ErrorIf.ConcatInputRankMismatch # unable to concat tensors along an invalid axis or error_name in [ErrorIf.AxisLargerRank, ErrorIf.AxisSmallerZero] ): for tensor in remaining_inputs: output_shape[axis] += tensor.shape[axis] if error_name == ErrorIf.ConcatShapeSumMismatch: output_shape[axis] += rng.integers(5, 10) if error_name == ErrorIf.WrongOutputType: all_dtypes = { DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT, } wrong_dtypes = list(all_dtypes - set([input1.dtype])) outputDType = rng.choice(wrong_dtypes) 
else: outputDType = input1.dtype return ser.addOutput(output_shape, outputDType) @staticmethod def padOp(ser, rng, a, padding, error_name=None): output_shape = a.shape.copy() for i in range(len(output_shape)): output_shape[i] = padding[i][0] + padding[i][1] + output_shape[i] # Fix negative output shape if error_if test causes it if error_name == ErrorIf.PadSmallerZero and min(output_shape) < 1: output_shape = [i if i >= 1 else 1 for i in output_shape] if error_name == ErrorIf.WrongOutputType: all_dtypes = [ DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT, ] wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) outputDType = rng.choice(wrong_dtypes) else: outputDType = a.dtype return ser.addOutput(output_shape, outputDType) @staticmethod def reshapeOp(ser, rng, a, shape, error_name=None): output_shape = shape.copy() if error_name == ErrorIf.TensorSizeInputOutputMismatch: for i in range(len(output_shape)): output_shape[i] = output_shape[i] + rng.integers(1, 10) if error_name == ErrorIf.WrongOutputType: all_dtypes = [ DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT, ] wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) outputDType = rng.choice(wrong_dtypes) else: outputDType = a.dtype return ser.addOutput(output_shape, outputDType) @staticmethod def sliceOp(ser, rng, a, start, size, error_name=None): if error_name == ErrorIf.WrongOutputType: all_dtypes = [ DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT, ] wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) outputDType = rng.choice(wrong_dtypes) else: outputDType = a.dtype if error_name == ErrorIf.SizeOutputShapeMismatch: output_shape = size.copy() for index in range(len(output_shape)): if output_shape[index] <= 2: output_shape[index] = output_shape[index] + rng.choice([1, 2]) else: output_shape[index] = output_shape[index] + rng.choice( [-2, -1, 1, 2] ) else: output_shape = size.copy() return ser.addOutput(output_shape, outputDType) @staticmethod def tileOp(ser, rng, 
a, multiples, error_name=None): output_shape = a.shape.copy() assert len(multiples) == len(output_shape) for i in range(len(output_shape)): output_shape[i] = a.shape[i] * multiples[i] if error_name == ErrorIf.WrongOutputType: all_dtypes = [ DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT, ] wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) outputDType = rng.choice(wrong_dtypes) else: outputDType = a.dtype return ser.addOutput(output_shape, outputDType) @staticmethod def transposeOp(ser, rng, a, perms, error_name=None): output_shape = a.shape.copy() assert len(perms) == len(output_shape) if error_name == ErrorIf.IndexOutsideBounds: for i in range(len(output_shape)): output_shape[i] = a.shape[0] else: for i in range(len(output_shape)): output_shape[i] = a.shape[perms[i]] if error_name == ErrorIf.WrongOutputType: all_dtypes = [ DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT, ] wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) outputDType = rng.choice(wrong_dtypes) else: outputDType = a.dtype return ser.addOutput(output_shape, outputDType) @staticmethod def gatherOp(ser, rng, values, indices, error_name=None): if error_name != ErrorIf.WrongRank: assert len(values.shape) == 3 assert len(indices.shape) == 2 assert values.shape[0] == indices.shape[0] output_shape = [values.shape[0], indices.shape[1], values.shape[2]] if error_name == ErrorIf.WrongOutputType: all_dtypes = [ DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT, ] wrong_dtypes = list(set(all_dtypes) - set([values.dtype])) outputDType = rng.choice(wrong_dtypes) else: outputDType = values.dtype return ser.addOutput(output_shape, outputDType) @staticmethod def scatterOp(ser, rng, values_in, indices, input, error_name=None): if error_name != ErrorIf.WrongRank: assert len(values_in.shape) == 3 assert len(indices.shape) == 2 assert len(input.shape) == 3 assert values_in.shape[0] == indices.shape[0] # N assert input.shape[1] == indices.shape[1] # W assert 
values_in.shape[2] == input.shape[2] # C output_shape = values_in.shape if error_name == ErrorIf.WrongOutputType: all_dtypes = [ DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT, ] wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype])) outputDType = rng.choice(wrong_dtypes) else: outputDType = values_in.dtype return ser.addOutput(output_shape, outputDType) @staticmethod def tableOp(ser, rng, input, error_name=None): # Same shape as the input, dtype dependent on input dtype if error_name != ErrorIf.WrongInputType: assert input.dtype == DType.INT16 or input.dtype == DType.INT8 output_dtype = DType.INT32 if input.dtype == DType.INT16 else DType.INT8 if error_name == ErrorIf.WrongOutputType: wrong_dtypes = [ DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT, ] wrong_dtypes.remove(output_dtype) output_dtype = rng.choice(wrong_dtypes) return ser.addOutput(input.shape, output_dtype) @staticmethod def resizeOp( serializer, rng, input, mode, stride, offset, shift, stride_fp, offset_fp, output_dims, input_dtype, output_dtype, error_name=None, ): if error_name == ErrorIf.WrongRank: output_dims = [ input.shape[0], output_dims[0], output_dims[0], input.shape[0], ] else: if error_name == ErrorIf.BatchMismatch: output_dims = [ input.shape[0] + rng.integers(1, 10), output_dims[0], output_dims[1], input.shape[3], ] elif error_name == ErrorIf.ChannelMismatch: output_dims = [ input.shape[0], output_dims[0], output_dims[1], input.shape[3] + rng.integers(1, 10), ] else: output_dims = [ input.shape[0], output_dims[0], output_dims[1], input.shape[3], ] return serializer.addOutput(output_dims, output_dtype) @staticmethod def typeConversionOp(ser, rng, val, out_dtype, error_name=None): return ser.addOutput(val.shape, out_dtype) @staticmethod def transposeConv2DOp(ser, rng, ifm, output_shape, error_name=None): if error_name == ErrorIf.ConvOutputShapeMismatch: choices = [1, 2, 3] change = rng.choice(choices) if change in [1, 3]: output_shape[1] = 
output_shape[1] + rng.choice(choices) if change in [2, 3]: output_shape[2] = output_shape[2] + rng.choice(choices) if ifm.dtype == DType.INT8: out_dtype = DType.INT32 elif ifm.dtype == DType.INT16: out_dtype = DType.INT48 elif ifm.dtype == DType.FLOAT: out_dtype = DType.FLOAT elif error_name == ErrorIf.WrongInputType: # Pick some potentially correct output dtype if input type is incorrect out_dtype = DType.INT32 else: raise Exception(f"Unsupported input dtype: {ifm.dtype}") if error_name == ErrorIf.WrongOutputType: wrong_dtypes = list(usableDTypes(excludes=[out_dtype])) out_dtype = rng.choice(wrong_dtypes) return ser.addOutput(output_shape, out_dtype)