diff options
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r-- | verif/generator/tosa_test_gen.py | 703 |
1 files changed, 379 insertions, 324 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 3173906..7702753 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -20,6 +20,8 @@ from generator.tosa_error_if import ErrorIf from generator.tosa_error_if import TosaErrorIfArgGen from generator.tosa_error_if import TosaErrorValidator from generator.tosa_error_if import TosaInvalidValidator +from generator.tosa_random_gen import TosaHashRandomGenerator +from generator.tosa_random_gen import TosaRandomGenerator from schemavalidation.schemavalidation import TestDescSchemaValidator from tosa.DType import DType from tosa.Op import Op @@ -50,10 +52,10 @@ class TosaTestGen: self.basePath = args.output_dir self.random_seed = args.random_seed self.ser = None - self.rng = np.random.default_rng(self.random_seed) self.createDynamicOpLists() self.initOpListDefaults() self.quantGen = TosaQuantGen() + self.global_rng = None # Force makeShape to do a specific starting shape self.targetted_shape = None # JSON schema validation @@ -80,12 +82,18 @@ class TosaTestGen: vals.append(v) return tuple(sorted(vals)) - self.random_float_range = {} + self.random_dtype_range = { + DType.SHAPE: tuple(self.args.tensor_shape_range[0:2]) + } for dtype in (DType.FP32, DType.FP16, DType.BF16, DType.FP8E4M3, DType.FP8E5M2): - self.random_float_range[dtype] = convertFPRange( + self.random_dtype_range[dtype] = convertFPRange( args.tensor_fp_value_range, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype], ) + self.resetGlobalRNG() + + def resetGlobalRNG(self): + self.global_rng = TosaRandomGenerator(self.random_seed, self.random_dtype_range) def createSerializer(self, opName, testPath): self.testPath = os.path.join(opName, testPath) @@ -148,93 +156,7 @@ class TosaTestGen: with path_desc.open("w") as fd: json.dump(desc, fd, indent=1) - def resetRNG(self, seed=None): - if seed is None: - seed = self.random_seed + 1 - self.rng = np.random.default_rng(seed) - - def getDTypeRange(self, dtype, high_inclusive=False): - # Returns dtype value range boundaries (low, high) - # The high boundary is excluded in the range - # unless high_inclusive is True - if dtype in (DType.FP32, DType.FP16, DType.BF16, DType.FP8E4M3, DType.FP8E5M2): - return self.random_float_range[dtype] - elif dtype == DType.BOOL: - rng = (0, 2) - elif dtype == DType.UINT8: - rng = (0, 256) - elif dtype == DType.UINT16: - rng = (0, 65536) - elif dtype == DType.INT4: - # TOSA specific INT4 weight range from -7 to 7 - rng = (-7, 8) - elif dtype == DType.INT8: - rng = (-128, 128) - elif dtype == DType.INT16: - rng = (-32768, 32768) - elif dtype == DType.INT32: - rng = (-(1 << 31), (1 << 31)) - elif dtype == DType.SHAPE: - rng = tuple(self.args.tensor_shape_range[0:2]) - elif dtype == DType.INT48: - rng = (-(1 << 47), (1 << 47)) - else: - raise Exception("Unknown dtype: {}".format(dtype)) - - if not high_inclusive: - # Exclusive high: low <= range < high - return rng - else: - # Inclusive range: low <= range <= high - return (rng[0], rng[1] - 1) - - def getRandTensor(self, shape, dtype, data_range=None): - if data_range is None: - low, high = self.getDTypeRange(dtype) - else: - low, high = data_range - - if dtype == DType.BOOL: - return np.bool_(self.rng.choice(a=[False, True], size=shape)) - elif dtype == DType.INT4: - return np.int8(self.rng.integers(low=low, high=high, size=shape)) - elif dtype == DType.INT8: - return np.int8(self.rng.integers(low=low, high=high, size=shape)) - elif dtype == DType.UINT8: - return np.uint8(self.rng.integers(low=low, high=high, size=shape)) - elif dtype == DType.INT16: - return np.int16(self.rng.integers(low=low, high=high, size=shape)) - elif dtype == DType.UINT16: - return np.uint16(self.rng.integers(low=low, high=high, size=shape)) - elif dtype in (DType.INT48, DType.SHAPE): - return np.int64(self.rng.integers(low=low, high=high, size=shape)) - elif dtype in ( - DType.FP16, - DType.BF16, - DType.FP32, - DType.FP8E4M3, - DType.FP8E5M2, - ): - f_tensor = self.rng.uniform(low=low, high=high, size=shape) - - if dtype == DType.FP16: - return np.float16(f_tensor) - else: - f32_tensor = np.float32(f_tensor) - if dtype == DType.BF16: - # Floor the last 16 bits of each f32 value - return np.float32(gtu.vect_f32_to_bf16(f32_tensor)) - elif dtype == DType.FP8E4M3: - return np.float32(gtu.vect_f32_to_fp8e4m3(f32_tensor)) - elif dtype == DType.FP8E5M2: - return np.float32(gtu.vect_f32_to_fp8e5m2(f32_tensor)) - else: - return f32_tensor - else: - # All other integer types - return np.int32(self.rng.integers(low=low, high=high, size=shape)) - - def buildPlaceholderTensors(self, shape_list, dtype_list): + def buildPlaceholderTensors(self, rng, shape_list, dtype_list): placeholders = [] assert len(shape_list) == len(dtype_list) @@ -242,12 +164,12 @@ class TosaTestGen: arr = None for idx, shape in enumerate(shape_list): if not self.args.lazy_data_gen: - arr = self.getRandTensor(shape, dtype_list[idx]) + arr = rng.randTensor(shape, dtype_list[idx]) placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr)) return placeholders - def buildConstTensors(self, shape_list, dtype_list): + def buildConstTensors(self, rng, shape_list, dtype_list): consts = [] assert len(shape_list) == len(dtype_list) @@ -255,16 +177,16 @@ class TosaTestGen: arr = None for idx, shape in enumerate(shape_list): if not self.args.lazy_data_gen: - arr = self.getRandTensor(shape, dtype_list[idx]) + arr = rng.randTensor(shape, dtype_list[idx]) consts.append(self.ser.addConst(shape, dtype_list[idx], arr)) return consts - def makeShape(self, rank): + def makeShape(self, rng, rank): if self.targetted_shape: return np.int32(self.targetted_shape) return np.int32( - self.rng.integers( + rng.integers( low=self.args.tensor_shape_range[0], high=self.args.tensor_shape_range[1], size=rank, @@ -274,33 +196,6 @@ class TosaTestGen: def setTargetShape(self, shape): self.targetted_shape = shape - def randInt(self, low=0, high=256): - return np.int32(self.rng.integers(low=low, high=high, size=1))[0] - - def getRandNumberDType(self, dtype): - low, high = self.getDTypeRange(dtype) - - if dtype == DType.FP32: - return np.float32(self.rng.uniform(low=low, high=high)) - elif dtype == DType.FP16: - return np.float16(self.rng.uniform(low=low, high=high)) - elif dtype == DType.BF16: - rand_f32 = np.float32(self.rng.uniform(low=low, high=high)) - return gtu.vect_f32_to_bf16(rand_f32) - elif dtype == DType.FP8E4M3: - rand_f32 = np.float32(self.rng.uniform(low=low, high=high)) - return gtu.vect_f32_to_fp8e4m3(rand_f32) - elif dtype == DType.FP8E5M2: - rand_f32 = np.float32(self.rng.uniform(low=low, high=high)) - return gtu.vect_f32_to_fp8e5m2(rand_f32) - elif dtype == DType.BOOL: - return self.rng.choice([False, True]) - elif dtype == DType.INT48 or dtype == DType.SHAPE: - # Special size - return np.int64(self.rng.integers(low, high, size=1))[0] - - return np.int32(self.rng.integers(low, high, size=1))[0] - def shapeStr(self, shape): sStr = [] @@ -330,8 +225,8 @@ class TosaTestGen: shape[0] = min(shape[0], self.args.max_batch_size) return shape - def makeDimension(self): - return self.randInt( + def makeDimension(self, rng): + return rng.randInt( low=self.args.tensor_shape_range[0], high=self.args.tensor_shape_range[1] ) @@ -445,11 +340,18 @@ class TosaTestGen: return compliance def build_unary( - self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None + self, + rng, + op, + inputs, + args_dict, + validator_fcns=None, + error_name=None, + qinfo=None, ): assert len(inputs) == 1 a = inputs[0] - result_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name) + result_tensor = OutputShaper.unaryOp(self.ser, rng, a, error_name) assert not isinstance(op, int) @@ -457,8 +359,10 @@ class TosaTestGen: if error_name == ErrorIf.WrongOutputType: if result_tensor.dtype not in [DType.INT8, DType.UINT8]: qinfo = [ - TosaQuantGen.getZeroPoint(self, a.dtype), - TosaQuantGen.getZeroPoint(self, result_tensor.dtype), + TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, a.dtype), + TosaQuantGen.getZeroPoint( + rng, self.args.zeropoint, result_tensor.dtype + ), ] # Invalidate Input/Output list for error if checks. @@ -467,7 +371,7 @@ class TosaTestGen: pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( - self, error_name, input_list, output_list + rng, error_name, input_list, output_list ) if not TosaErrorValidator.evValidateErrorIfs( @@ -498,13 +402,11 @@ class TosaTestGen: return TosaTestGen.BuildInfo(result_tensor, compliance) def build_binary_broadcast( - self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None + self, rng, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None ): assert len(inputs) == 2 a, b = inputs - result_tensor = OutputShaper.binaryBroadcastOp( - self.ser, self.rng, a, b, error_name - ) + result_tensor = OutputShaper.binaryBroadcastOp(self.ser, rng, a, b, error_name) # Invalidate Input/Output list for error if checks. input_list = [a.name, b.name] @@ -512,7 +414,7 @@ class TosaTestGen: pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( - self, error_name, input_list, output_list + rng, error_name, input_list, output_list ) if not TosaErrorValidator.evValidateErrorIfs( @@ -539,20 +441,20 @@ class TosaTestGen: return TosaTestGen.BuildInfo(result_tensor, compliance) - def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None): - result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b) - self.ser.addOperator(op["op"], [a.name, b.name], [result_tens.name]) - return result_tens - def build_arithmetic_right_shift( - self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None + self, + rng, + op, + inputs, + args_dict, + validator_fcns=None, + error_name=None, + qinfo=None, ): assert len(inputs) == 2 a, b = inputs round = args_dict["round"] - result_tensor = OutputShaper.binaryBroadcastOp( - self.ser, self.rng, a, b, error_name - ) + result_tensor = OutputShaper.binaryBroadcastOp(self.ser, rng, a, b, error_name) # Invalidate Input/Output list for error if checks. input_list = [a.name, b.name] @@ -560,7 +462,7 @@ class TosaTestGen: pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( - self, error_name, input_list, output_list + rng, error_name, input_list, output_list ) if not TosaErrorValidator.evValidateErrorIfs( @@ -591,15 +493,20 @@ class TosaTestGen: return TosaTestGen.BuildInfo(result_tensor, compliance) def build_mul( - self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None + self, + rng, + op, + inputs, + args_dict, + validator_fcns=None, + error_name=None, + qinfo=None, ): # Note that mul is binary operator but it has a shift value tensor assert len(inputs) == 3 a, b, s = inputs - result_tensor = OutputShaper.binaryBroadcastOp( - self.ser, self.rng, a, b, error_name - ) + result_tensor = OutputShaper.binaryBroadcastOp(self.ser, rng, a, b, error_name) # Special for multiply: Force the result to INT32 for INT types if a.dtype not in (DType.FP16, DType.BF16, DType.FP32): @@ -607,7 +514,7 @@ class TosaTestGen: if error_name == ErrorIf.WrongOutputType: all_dtypes = [DType.INT8, DType.INT16, DType.INT48] - outputDType = self.rng.choice(all_dtypes) + outputDType = rng.choice(all_dtypes) result_tensor.setDtype(outputDType) # Invalidate Input/Output list for error if checks. @@ -616,7 +523,7 @@ class TosaTestGen: pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( - self, error_name, input_list, output_list + rng, error_name, input_list, output_list ) if not TosaErrorValidator.evValidateErrorIfs( @@ -644,12 +551,19 @@ class TosaTestGen: return TosaTestGen.BuildInfo(result_tensor, compliance) def build_table( - self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None + self, + rng, + op, + inputs, + args_dict, + validator_fcns=None, + error_name=None, + qinfo=None, ): assert len(inputs) == 1 a = inputs[0] table = args_dict["table"] - result_tensor = OutputShaper.tableOp(self.ser, self.rng, a, error_name) + result_tensor = OutputShaper.tableOp(self.ser, rng, a, error_name) attr = ts.TosaSerializerAttribute() attr.TableAttribute(table) @@ -660,7 +574,7 @@ class TosaTestGen: pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( - self, error_name, input_list, output_list + rng, error_name, input_list, output_list ) if not TosaErrorValidator.evValidateErrorIfs( @@ -687,14 +601,19 @@ class TosaTestGen: return TosaTestGen.BuildInfo(result_tensor, compliance) def build_select( - self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None + self, + rng, + op, + inputs, + args_dict, + validator_fcns=None, + error_name=None, + qinfo=None, ): assert len(inputs) == 3 cond, a, b = inputs - result_tensor = OutputShaper.selectOp( - self.ser, self.rng, cond, a, b, error_name - ) + result_tensor = OutputShaper.selectOp(self.ser, rng, cond, a, b, error_name) # Invalidate Input/Output list for error if checks. input_list = [cond.name, a.name, b.name] @@ -702,7 +621,7 @@ class TosaTestGen: pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( - self, error_name, input_list, output_list + rng, error_name, input_list, output_list ) if not TosaErrorValidator.evValidateErrorIfs( @@ -735,14 +654,19 @@ class TosaTestGen: return TosaTestGen.BuildInfo(result_tensor, compliance) def build_comparison( - self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None + self, + rng, + op, + inputs, + args_dict, + validator_fcns=None, + error_name=None, + qinfo=None, ): assert len(inputs) == 2 a, b = inputs - result_tensor = OutputShaper.binaryComparisonOp( - self.ser, self.rng, a, b, error_name - ) + result_tensor = OutputShaper.binaryComparisonOp(self.ser, rng, a, b, error_name) # Invalidate Input/Output list for error if checks. input_list = [a.name, b.name] @@ -750,7 +674,7 @@ class TosaTestGen: pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( - self, error_name, input_list, output_list + rng, error_name, input_list, output_list ) if not TosaErrorValidator.evValidateErrorIfs( @@ -783,12 +707,12 @@ class TosaTestGen: return TosaTestGen.BuildInfo(result_tensor, compliance) def build_argmax( - self, op, inputs, args_dict, validator_fcns, error_name, qinfo=None + self, rng, op, inputs, args_dict, validator_fcns, error_name, qinfo=None ): assert len(inputs) == 1 a = inputs[0] axis = args_dict["axis"] - result_tensor = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name) + result_tensor = OutputShaper.argmaxOp(self.ser, rng, a, axis, error_name) # Invalidate Input/Output list for error if checks. input_list = [a.name] @@ -796,7 +720,7 @@ class TosaTestGen: pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( - self, error_name, input_list, output_list + rng, error_name, input_list, output_list ) if not TosaErrorValidator.evValidateErrorIfs( @@ -828,6 +752,7 @@ class TosaTestGen: def build_pool2d( self, + rng, op, inputs, args_dict, @@ -846,15 +771,17 @@ class TosaTestGen: kernel = args_dict["kernel"] result_tensor = OutputShaper.pool2dOp( - self.ser, self.rng, input, kernel, stride, pad, error_name + self.ser, rng, input, kernel, stride, pad, error_name ) # Ensure new output type has correct qinfo if error_name == ErrorIf.WrongInputType: if input.dtype not in [DType.INT8, DType.UINT8]: qinfo = [ - TosaQuantGen.getZeroPoint(self, input.dtype), - TosaQuantGen.getZeroPoint(self, result_tensor.dtype), + TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, input.dtype), + TosaQuantGen.getZeroPoint( + rng, self.args.zeropoint, result_tensor.dtype + ), ] # Invalidate Input/Output list for error if checks. @@ -863,7 +790,7 @@ class TosaTestGen: pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( - self, error_name, input_list, output_list + rng, error_name, input_list, output_list ) if not TosaErrorValidator.evValidateErrorIfs( @@ -903,6 +830,7 @@ class TosaTestGen: def build_conv2d( self, + rng, op, inputs, args_dict, @@ -920,7 +848,7 @@ class TosaTestGen: assert len(padding) == 4 result_tensor = OutputShaper.conv2dOp( self.ser, - self.rng, + rng, ifm, filter, accum_dtype, @@ -936,8 +864,10 @@ class TosaTestGen: DType.UINT8, ): qinfo = [ - TosaQuantGen.getZeroPoint(self, ifm.dtype), - TosaQuantGen.getZeroPoint(self, result_tensor.dtype), + TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype), + TosaQuantGen.getZeroPoint( + rng, self.args.zeropoint, result_tensor.dtype + ), ] # Invalidate Input/Output list for error_if checks. @@ -945,7 +875,7 @@ class TosaTestGen: output_list = [result_tensor.name] num_operands = sum(op["operands"]) input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( - self, error_name, input_list, output_list + rng, error_name, input_list, output_list ) if not TosaErrorValidator.evValidateErrorIfs( @@ -985,6 +915,7 @@ class TosaTestGen: def build_conv3d( self, + rng, op, inputs, args_dict, @@ -1002,7 +933,7 @@ class TosaTestGen: assert len(padding) == 6 result_tensor = OutputShaper.conv3dOp( self.ser, - self.rng, + rng, ifm, filter, accum_dtype, @@ -1018,8 +949,10 @@ class TosaTestGen: DType.UINT8, ): qinfo = [ - TosaQuantGen.getZeroPoint(self, ifm.dtype), - TosaQuantGen.getZeroPoint(self, result_tensor.dtype), + TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype), + TosaQuantGen.getZeroPoint( + rng, self.args.zeropoint, result_tensor.dtype + ), ] # Invalidate Input/Output list for error_if checks. @@ -1027,7 +960,7 @@ class TosaTestGen: output_list = [result_tensor.name] num_operands = sum(op["operands"]) input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( - self, error_name, input_list, output_list + rng, error_name, input_list, output_list ) if not TosaErrorValidator.evValidateErrorIfs( @@ -1067,6 +1000,7 @@ class TosaTestGen: def build_transpose_conv2d( self, + rng, op, inputs, args_dict, @@ -1083,7 +1017,7 @@ class TosaTestGen: assert len(out_pad) == 4 result_tensor = OutputShaper.transposeConv2DOp( - self.ser, self.rng, ifm, output_shape, accum_dtype, error_name + self.ser, rng, ifm, output_shape, accum_dtype, error_name ) # Ensure new output type has correct qinfo @@ -1092,8 +1026,10 @@ class TosaTestGen: DType.UINT8, ): qinfo = [ - TosaQuantGen.getZeroPoint(self, ifm.dtype), - TosaQuantGen.getZeroPoint(self, result_tensor.dtype), + TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype), + TosaQuantGen.getZeroPoint( + rng, self.args.zeropoint, result_tensor.dtype + ), ] # Invalidate Input/Output list for error_if checks. @@ -1101,7 +1037,7 @@ class TosaTestGen: output_list = [result_tensor.name] num_operands = sum(op["operands"]) input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( - self, error_name, input_list, output_list + rng, error_name, input_list, output_list ) if not TosaErrorValidator.evValidateErrorIfs( @@ -1142,6 +1078,7 @@ class TosaTestGen: def build_depthwise_conv2d( self, + rng, op, inputs, args_dict, @@ -1158,7 +1095,7 @@ class TosaTestGen: result_tensor = OutputShaper.depthwiseConv2dOp( self.ser, - self.rng, + rng, ifm, filter, accum_dtype, @@ -1174,8 +1111,10 @@ class TosaTestGen: DType.UINT8, ): qinfo = [ - TosaQuantGen.getZeroPoint(self, ifm.dtype), - TosaQuantGen.getZeroPoint(self, result_tensor.dtype), + TosaQuantGen.getZeroPoint(rng, self.args.zeropoint, ifm.dtype), + TosaQuantGen.getZeroPoint( + rng, self.args.zeropoint, result_tensor.dtype + ), ] # Invalidate Input/Output list for error_if checks. @@ -1183,7 +1122,7 @@ class TosaTestGen: output_list = [result_tensor.name] num_operands = sum(op["operands"]) input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( - self, error_name, input_list, output_list + rng, error_name, input_list, output_list ) if not TosaErrorValidator.evValidateErrorIfs( @@ -1223,6 +1162,7 @@ class TosaTestGen: def build_fully_connected( self, + rng, op, inputs, args_dict, @@ -1235,7 +1175,7 @@ class TosaTestGen: accum_dtype = args_dict["acc_type"] result_tensor = OutputShaper.fullyConnectedOp( - self.ser, self.rng, ifm, filter, accum_dtype, error_name + self.ser, rng, ifm, filter, accum_dtype, error_name ) # Invalidate Input/Output list for error if checks. @@ -1244,7 +1184,7 @@ class TosaTestGen: pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( - self, error_name, input_list, output_list + rng, error_name, input_list, output_list ) if not TosaErrorValidator.evValidateErrorIfs( @@ -1278,13 +1218,20 @@ class TosaTestGen: return TosaTestGen.BuildInfo(result_tensor, compliance) def build_matmul( - self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None + self, + rng, + op, + inputs, + args_dict, + validator_fcns=None, + error_name=None, + qinfo=None, ): assert len(inputs) == 2 a, b = inputs accum_dtype = args_dict["acc_type"] result_tensor = OutputShaper.matmulOp( - self.ser, self.rng, a, b, accum_dtype, error_name + self.ser, rng, a, b, accum_dtype, error_name ) # Invalidate Input/Output list for error if checks. @@ -1293,7 +1240,7 @@ class TosaTestGen: pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( - self, error_name, input_list, output_list + rng, error_name, input_list, output_list ) if not TosaErrorValidator.evValidateErrorIfs( @@ -1328,12 +1275,12 @@ class TosaTestGen: return TosaTestGen.BuildInfo(result_tensor, compliance) def build_reduce( - self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None + self, rng, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None ): assert len(inputs) == 1 a = inputs[0] axis = args_dict["axis"] - result_tensor = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name) + result_tensor = OutputShaper.reduceOp(self.ser, rng, a, axis, error_name) # Invalidate Input/Output list for error if checks. input_list = [a.name] @@ -1341,7 +1288,7 @@ class TosaTestGen: pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( - self, error_name, input_list, output_list + rng, error_name, input_list, output_list ) if not TosaErrorValidator.evValidateErrorIfs( @@ -1377,19 +1324,26 @@ class TosaTestGen: return TosaTestGen.BuildInfo(result_tensor, compliance) def build_clamp( - self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None + self, + rng, + op, + inputs, + args_dict, + validator_fcns=None, + error_name=None, + qinfo=None, ): assert len(inputs) == 1 a = inputs[0] - result_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name) + result_tensor = OutputShaper.unaryOp(self.ser, rng, a, error_name) - v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)] + v = [rng.randNumberDType(a.dtype), rng.randNumberDType(a.dtype)] if error_name == ErrorIf.MaxSmallerMin: # Make sure the numbers are different to invoke this error while v[0] == v[1]: - v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)] + v = [rng.randNumberDType(a.dtype), rng.randNumberDType(a.dtype)] max_val = min(v) min_val = max(v) else: @@ -1402,7 +1356,7 @@ class TosaTestGen: pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( - self, error_name, input_list, output_list + rng, error_name, input_list, output_list ) if not TosaErrorValidator.evValidateErrorIfs( @@ -1449,29 +1403,20 @@ class TosaTestGen: return TosaTestGen.BuildInfo(result_tensor, compliance) - def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None): - result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name) - attr = ts.TosaSerializerAttribute() - - attr.LeakyReluAttribute(self.getRandNumberDType(DType.FP32)) - - self.ser.addOperator(op["op"], [a.name], [result_tens.name], attr) - return result_tens - - # Needs an additional type/input - def build_prelu(self, op, a, validator_fcns=None, error_name=None): - result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name) - - self.ser.addOperator(op["op"], [a.name], [result_tens.name]) - return result_tens - def build_activation( - self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None + self, + rng, + op, + inputs, + args_dict, + validator_fcns=None, + error_name=None, + qinfo=None, ): assert len(inputs) == 1 a = inputs[0] - result_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name) + result_tensor = OutputShaper.unaryOp(self.ser, rng, a, error_name) # Invalidate Input/Output list for error if checks. input_list = [a.name] @@ -1479,7 +1424,7 @@ class TosaTestGen: pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( - self, error_name, input_list, output_list + rng, error_name, input_list, output_list ) if not TosaErrorValidator.evValidateErrorIfs( @@ -1507,7 +1452,14 @@ class TosaTestGen: return TosaTestGen.BuildInfo(result_tensor, compliance) def build_concat( - self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None + self, + rng, + op, + inputs, + args_dict, + validator_fcns=None, + error_name=None, + qinfo=None, ): if op["op"] == Op.CONCAT_SHAPE: axis = 0 @@ -1517,7 +1469,7 @@ class TosaTestGen: assert type(axis) == int result_tensor = OutputShaper.concatOp( - self.ser, self.rng, axis, inputs, error_name=error_name + self.ser, rng, axis, inputs, error_name=error_name ) input_tensor_names = [] @@ -1530,7 +1482,7 @@ class TosaTestGen: pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( - self, error_name, input_list, output_list + rng, error_name, input_list, output_list ) if not TosaErrorValidator.evValidateErrorIfs( @@ -1567,6 +1519,7 @@ class TosaTestGen: def build_pad( self, + rng, op, inputs, args_dict, @@ -1581,7 +1534,7 @@ class TosaTestGen: pad_const_int = args_dict["pad_const_int"] pad_const_float = args_dict["pad_const_fp"] - result_tensor = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name) + result_tensor = OutputShaper.padOp(self.ser, rng, a, padding, error_name) # get pad_const_val_as_bytes from either pad_const_float or pad_const_int if gtu.dtypeIsFloat(a.dtype): @@ -1598,7 +1551,7 @@ class TosaTestGen: pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( - self, error_name, input_list, output_list + rng, error_name, input_list, output_list ) if not TosaErrorValidator.evValidateErrorIfs( @@ -1630,6 +1583,7 @@ class TosaTestGen: def build_dim( self, + rng, op, inputs, args_dict, @@ -1640,7 +1594,7 @@ class TosaTestGen: assert len(inputs) == 1 a = inputs[0] axis = args_dict["axis"] - result_tensor = OutputShaper.dimOp(self.ser, self.rng, a, axis, error_name) + result_tensor = OutputShaper.dimOp(self.ser, rng, a, axis, error_name) # Invalidate Input/Output list for error if checks. input_list = [a.name] @@ -1648,7 +1602,7 @@ class TosaTestGen: pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( - self, error_name, input_list, output_list + rng, error_name, input_list, output_list ) if not TosaErrorValidator.evValidateErrorIfs( @@ -1675,15 +1629,20 @@ class TosaTestGen: return TosaTestGen.BuildInfo(result_tensor, None) def build_reshape( - self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None + self, + rng, + op, + inputs, + args_dict, + validator_fcns=None, + error_name=None, + qinfo=None, ): assert len(inputs) == 2 a = inputs[0] shape = inputs[1] shape_attr = args_dict["new_shape"] - result_tensor = OutputShaper.reshapeOp( - self.ser, self.rng, a, shape_attr, error_name - ) + result_tensor = OutputShaper.reshapeOp(self.ser, rng, a, shape_attr, error_name) # Invalidate Input/Output list for error if checks. input_list = [a.name, shape.name] @@ -1691,7 +1650,7 @@ class TosaTestGen: pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( - self, error_name, input_list, output_list + rng, error_name, input_list, output_list ) if not TosaErrorValidator.evValidateErrorIfs( @@ -1719,12 +1678,19 @@ class TosaTestGen: return TosaTestGen.BuildInfo(result_tensor, compliance) def build_reverse( - self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None + self, + rng, + op, + inputs, + args_dict, + validator_fcns=None, + error_name=None, + qinfo=None, ): assert len(inputs) == 1 a = inputs[0] axis = args_dict["axis"] - result_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name) + result_tensor = OutputShaper.unaryOp(self.ser, rng, a, error_name) # Invalidate Input/Output list for error if checks. input_list = [a.name] @@ -1732,7 +1698,7 @@ class TosaTestGen: pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( - self, error_name, input_list, output_list + rng, error_name, input_list, output_list ) if not TosaErrorValidator.evValidateErrorIfs( @@ -1759,15 +1725,20 @@ class TosaTestGen: return TosaTestGen.BuildInfo(result_tensor, None) def build_transpose( - self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None + self, + rng, + op, + inputs, + args_dict, + validator_fcns=None, + error_name=None, + qinfo=None, ): assert len(inputs) == 1 a = inputs[0] perms = args_dict["perms"] - result_tensor = OutputShaper.transposeOp( - self.ser, self.rng, a, perms, error_name - ) + result_tensor = OutputShaper.transposeOp(self.ser, rng, a, perms, error_name) attr = ts.TosaSerializerAttribute() attr.TransposeAttribute(perms) @@ -1778,7 +1749,7 @@ class TosaTestGen: pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( - self, error_name, input_list, output_list + rng, error_name, input_list, output_list ) if not TosaErrorValidator.evValidateErrorIfs( @@ -1808,7 +1779,14 @@ class TosaTestGen: return TosaTestGen.BuildInfo(result_tensor, compliance) def build_slice( - self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None + self, + rng, + op, + inputs, + args_dict, + validator_fcns=None, + error_name=None, + qinfo=None, ): assert len(inputs) == 3 a, start_var, size_var = inputs @@ -1816,7 +1794,7 @@ class TosaTestGen: size_const = args_dict["size"] result_tensor = OutputShaper.sliceOp( - self.ser, self.rng, a, start_const, size_const, error_name + self.ser, rng, a, start_const, size_const, error_name ) # Invalidate Input/Output list for error if checks. @@ -1825,7 +1803,7 @@ class TosaTestGen: pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( - self, error_name, input_list, output_list + rng, error_name, input_list, output_list ) if not TosaErrorValidator.evValidateErrorIfs( @@ -1856,14 +1834,21 @@ class TosaTestGen: return TosaTestGen.BuildInfo(result_tensor, compliance) def build_tile( - self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None + self, + rng, + op, + inputs, + args_dict, + validator_fcns=None, + error_name=None, + qinfo=None, ): assert len(inputs) == 2 a = inputs[0] multiples = inputs[1] multiples_attr = args_dict["multiples"] result_tensor = OutputShaper.tileOp( - self.ser, self.rng, a, multiples_attr, error_name + self.ser, rng, a, multiples_attr, error_name ) # Invalidate Input/Output list for error if checks. @@ -1872,7 +1857,7 @@ class TosaTestGen: pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( - self, error_name, input_list, output_list + rng, error_name, input_list, output_list ) if not TosaErrorValidator.evValidateErrorIfs( @@ -1901,13 +1886,20 @@ class TosaTestGen: return TosaTestGen.BuildInfo(result_tensor, compliance) def build_gather( - self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None + self, + rng, + op, + inputs, + args_dict, + validator_fcns=None, + error_name=None, + qinfo=None, ): assert len(inputs) == 2 values, indices = inputs result_tensor = OutputShaper.gatherOp( - self.ser, self.rng, values, indices, error_name + self.ser, rng, values, indices, error_name ) # Invalidate Input/Output list for error if checks. @@ -1916,7 +1908,7 @@ class TosaTestGen: pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( - self, error_name, input_list, output_list + rng, error_name, input_list, output_list ) if not TosaErrorValidator.evValidateErrorIfs( @@ -1944,12 +1936,19 @@ class TosaTestGen: return TosaTestGen.BuildInfo(result_tensor, compliance) def build_scatter( - self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None + self, + rng, + op, + inputs, + args_dict, + validator_fcns=None, + error_name=None, + qinfo=None, ): assert len(inputs) == 3 values_in, indices, input = inputs result_tensor = OutputShaper.scatterOp( - self.ser, self.rng, values_in, indices, input, error_name + self.ser, rng, values_in, indices, input, error_name ) # Invalidate Input/Output list for error if checks. @@ -1958,7 +1957,7 @@ class TosaTestGen: pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( - self, error_name, input_list, output_list + rng, error_name, input_list, output_list ) if not TosaErrorValidator.evValidateErrorIfs( @@ -1987,6 +1986,7 @@ class TosaTestGen: def build_resize( self, + rng, op, inputs, args_dict, @@ -2008,7 +2008,7 @@ class TosaTestGen: result_tensor = OutputShaper.resizeOp( self.ser, - self.rng, + rng, input, mode, scale, @@ -2030,7 +2030,7 @@ class TosaTestGen: pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( - self, error_name, input_list, output_list + rng, error_name, input_list, output_list ) if not TosaErrorValidator.evValidateErrorIfs( @@ -2064,16 +2064,15 @@ class TosaTestGen: return TosaTestGen.BuildInfo(result_tensor, compliance) - def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None): - result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name) - result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name) - self.ser.addOperator( - op, [val.name, val2.name], [result_tens.name, result_tens2.name] - ) - return result_tens - def build_const( - self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None + self, + rng, + op, + inputs, + args_dict, + validator_fcns=None, + error_name=None, + qinfo=None, ): assert len(inputs) == 1 val = inputs[0] @@ -2087,14 +2086,21 @@ class TosaTestGen: # Type Conversion def build_cast( - self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None + self, + rng, + op, + inputs, + args_dict, + validator_fcns=None, + error_name=None, + qinfo=None, ): assert len(inputs) == 1 val = inputs[0] out_dtype = args_dict["out_type"] result_tensor = OutputShaper.typeConversionOp( - self.ser, self.rng, val, out_dtype, error_name + self.ser, rng, val, out_dtype, error_name ) # Invalidate Input/Output list for error if checks. @@ -2103,7 +2109,7 @@ class TosaTestGen: pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( - self, error_name, input_list, output_list + rng, error_name, input_list, output_list ) if not TosaErrorValidator.evValidateErrorIfs( @@ -2132,6 +2138,7 @@ class TosaTestGen: def build_rescale( self, + rng, op, inputs, args_dict, @@ -2151,7 +2158,7 @@ class TosaTestGen: multiplier_arr = args_dict["multiplier"] result_tensor = OutputShaper.typeConversionOp( - self.ser, self.rng, val, out_dtype, error_name + self.ser, rng, val, out_dtype, error_name ) if per_channel: @@ -2166,46 +2173,46 @@ class TosaTestGen: output_unsigned = False if val.dtype == DType.INT8: - input_zp = self.randInt(-128, 128) + input_zp = rng.randInt(-128, 128) in_type_width += 1 elif val.dtype == DType.UINT8: - input_zp = self.randInt(0, 256) + input_zp = rng.randInt(0, 256) in_type_width += 1 input_unsigned = True elif error_name in [ ErrorIf.InputZeroPointNotZero, ErrorIf.U16InputZeroPointNotValid, ]: - input_zp = self.randInt(-128, 128) + input_zp = rng.randInt(-128, 128) if input_zp == 0: - input_zp = input_zp + self.rng.integers(1, 10) + input_zp = input_zp + rng.integers(1, 10) in_type_width += 1 elif val.dtype == DType.UINT16: # Must come after ErrorIf.U16InputZeroPointNotValid check - input_zp = self.rng.choice([0, 32768]) + input_zp = rng.choice([0, 32768]) in_type_width += 1 input_unsigned = True else: input_zp = 0 if out_dtype == DType.INT8: - output_zp = self.randInt(-128, 128) + output_zp = rng.randInt(-128, 128) out_type_width += 1 elif out_dtype == DType.UINT8: - output_zp = self.randInt(0, 256) + output_zp = rng.randInt(0, 256) out_type_width += 1 output_unsigned = True elif error_name in [ ErrorIf.OutputZeroPointNotZero, ErrorIf.U16OutputZeroPointNotValid, ]: - output_zp = self.randInt(-128, 128) + output_zp = rng.randInt(-128, 128) if output_zp == 0: - output_zp = output_zp + self.rng.integers(1, 10) + output_zp = output_zp + rng.integers(1, 10) out_type_width += 1 elif out_dtype == DType.UINT16: # Must come after ErrorIf.U16OutputZeroPointNotValid check - output_zp = self.rng.choice([0, 32768]) + output_zp = rng.choice([0, 32768]) out_type_width += 1 output_unsigned = True else: @@ -2255,7 +2262,7 @@ class TosaTestGen: pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( - self, error_name, input_list, output_list + rng, error_name, input_list, output_list ) qinfo = (input_zp, output_zp) @@ -2296,13 +2303,13 @@ class TosaTestGen: return TosaTestGen.BuildInfo(result_tensor, compliance) - def _get_condition_tensor(self, op, cond, error_name): + def _get_condition_tensor(self, rng, op, cond, error_name): if error_name == ErrorIf.CondIfCondNotMatchingBool: - cond_type = gtu.get_wrong_output_type(op, self.rng, DType.BOOL) + cond_type = gtu.get_wrong_output_type(op, rng, DType.BOOL) else: cond_type = DType.BOOL if error_name == ErrorIf.CondIfCondShapeNotSizeOne: - choice = self.rng.choice([1, 2]) + choice = rng.choice([1, 2]) if choice == 1: cond_shape = [2] else: @@ -2315,6 +2322,7 @@ class TosaTestGen: def build_cond_if_const( self, + rng, op, inputs, args_dict, @@ -2331,7 +2339,7 @@ class TosaTestGen: cond = args_dict["condition"] # Condition tensor - cond_tens = self._get_condition_tensor(op, cond, error_name) + cond_tens = self._get_condition_tensor(rng, op, cond, error_name) # Make then/else tensors out_shape = then_tens.shape @@ -2346,14 +2354,14 @@ class TosaTestGen: incorrect_shape = deepcopy(then_tens.shape) for i in range(len(incorrect_shape)): incorrect_shape[i] += ( - self.rng.choice([-3, -2, 2, 3]) + rng.choice([-3, -2, 2, 3]) if incorrect_shape[i] > 3 - else self.rng.choice([1, 2, 4]) + else rng.choice([1, 2, 4]) ) - incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape)) + incorrect_arr = np.int32(rng.integers(0, 256, size=incorrect_shape)) - then_arr = np.int32(self.rng.integers(0, 256, size=out_shape)) - else_arr = np.int32(self.rng.integers(0, 256, size=out_shape)) + then_arr = np.int32(rng.integers(0, 256, size=out_shape)) + else_arr = np.int32(rng.integers(0, 256, size=out_shape)) # And the result tensor based on any of the outputs result_tensor = self.ser.addOutput(out_shape, dtype) @@ -2400,6 +2408,7 @@ class TosaTestGen: def build_cond_if_binary( self, + rng, op, inputs, args_dict, @@ -2415,7 +2424,7 @@ class TosaTestGen: cond = args_dict["condition"] # Condition tensor - cond_tens = self._get_condition_tensor(op, cond, error_name) + cond_tens = self._get_condition_tensor(rng, op, cond, error_name) result_tensor = self.ser.addOutput(a.shape, a.dtype) @@ -2433,7 +2442,7 @@ class TosaTestGen: ]: incorrect_shape = a.shape.copy() for i in range(len(incorrect_shape)): - incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3]) + incorrect_shape[i] += rng.choice([-3, -2, 2, 3]) incorrect_block_input = deepcopy(a) incorrect_block_input.shape = incorrect_shape @@ -2503,7 +2512,14 @@ class TosaTestGen: return TosaTestGen.BuildInfo(result_tensor, compliance) def build_while_loop( - self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None + self, + rng, + op, + inputs, + args_dict, + validator_fcns=None, + error_name=None, + qinfo=None, ): assert len(inputs) == 1 a = inputs[0] @@ -2528,7 +2544,7 @@ class TosaTestGen: if error_name == ErrorIf.InputListOutputListMismatch: incorrect_acc = deepcopy(acc) for i in range(len(incorrect_acc.shape)): - incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3]) + incorrect_acc.shape[i] += rng.choice([-3, -2, 2, 3]) acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype) else: acc_out = self.ser.addIntermediate(acc.shape, acc.dtype) @@ -2549,13 +2565,13 @@ class TosaTestGen: ]: incorrect_iter = deepcopy(iter) for i in range(len(incorrect_iter.shape)): - incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3]) + incorrect_iter.shape[i] += rng.choice([-3, -2, 2, 3]) if len(incorrect_iter.shape) == 0: - incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3])) + incorrect_iter.shape.append(rng.choice([-3, -2, 2, 3])) incorrect_acc = deepcopy(acc) for i in range(len(incorrect_acc.shape)): - incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3]) + incorrect_acc.shape[i] += rng.choice([-3, -2, 2, 3]) # COND block (input: iter, output: cond_tens ) self.ser.addBasicBlock(cond_block) @@ -2571,11 +2587,11 @@ class TosaTestGen: zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)]) if error_name == ErrorIf.CondGraphOutputNotMatchingBool: - cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32]) + cond_type = rng.choice([DType.INT8, DType.INT32, DType.FP32]) else: cond_type = DType.BOOL if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne: - choice = self.rng.choice([1, 2]) + choice = rng.choice([1, 2]) if choice == 1: cond_shape = [3] else: @@ -2635,6 +2651,7 @@ class TosaTestGen: def build_fft2d( self, + rng, op, inputs, args_dict, @@ -2646,7 +2663,7 @@ class TosaTestGen: val1, val2 = inputs inverse = args_dict["inverse"] - results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name) + results = OutputShaper.fft2dOp(self.ser, rng, val1, val2, error_name) input_names = [val1.name, val2.name] pCount, cCount = op["operands"] @@ -2657,7 +2674,7 @@ class TosaTestGen: output_dtypes = [res.dtype for res in results] input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList( - self, error_name, input_names, output_names + rng, error_name, input_names, output_names ) if not TosaErrorValidator.evValidateErrorIfs( @@ -2699,6 +2716,7 @@ class TosaTestGen: def build_rfft2d( self, + rng, op, inputs, args_dict, @@ -2708,7 +2726,7 @@ class TosaTestGen: ): assert len(inputs) == 1 val = inputs[0] - results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name) + results = OutputShaper.rfft2dOp(self.ser, rng, val, error_name) input_names = [val.name] pCount, cCount = op["operands"] @@ -2719,7 +2737,7 @@ class TosaTestGen: output_dtypes = [res.dtype for res in results] input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList( - self, error_name, input_names, output_names + rng, error_name, input_names, output_names ) if not TosaErrorValidator.evValidateErrorIfs( @@ -2755,12 +2773,19 @@ class TosaTestGen: return TosaTestGen.BuildInfo(results, compliance) def build_shape_op( - self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None + self, + rng, + op, + inputs, + args_dict, + validator_fcns=None, + error_name=None, + qinfo=None, ): assert len(inputs) == 2 a, b = inputs - result_tensor = OutputShaper.addShapeOp(self.ser, self.rng, a, b, error_name) + result_tensor = OutputShaper.addShapeOp(self.ser, rng, a, b, error_name) # Invalidate Input/Output list for error if checks. input_list = [a.name, b.name] @@ -2895,8 +2920,9 @@ class TosaTestGen: except KeyError: raise Exception("Cannot find op with name {}".format(opName)) - # Initialize a new random number generator - self.rng = np.random.default_rng(self.random_seed) + if not self.args.stable_rng: + # Initialize a new random number generator per op + self.resetGlobalRNG() _, tgen_fcn, _, agen_fcn = op["build_fcn"] @@ -2933,37 +2959,53 @@ class TosaTestGen: if shape is not None and len(shape) != r: continue self.setTargetShape(shape) - shapeList = tgen_fcn(self, op, r, error_name) + typeStr = self.typeStr(t) + if self.args.stable_rng: + shape_rng = TosaHashRandomGenerator( + self.random_seed, + [opName, r, typeStr], + self.random_dtype_range, + ) + else: + shape_rng = self.global_rng + shapeList = tgen_fcn(self, shape_rng, op, r, error_name) shapeStr = self.shapeStr(shapeList[0]) - typeStr = self.typeStr(t) # Argument lists consists of tuples of the (str, []) string representation and the build function argument list argList = [] if agen_fcn: - argList = agen_fcn(self, opName, shapeList, t, error_name) + if self.args.stable_rng: + arg_rng = TosaHashRandomGenerator( + self.random_seed, + [opName, shapeStr, typeStr], + self.random_dtype_range, + ) + else: + arg_rng = self.global_rng + + argList = agen_fcn( + self, arg_rng, opName, shapeList, t, error_name + ) else: argList = [("", [])] for argStr, args in argList: + # Create the test name string - for example: add_1x2x3_i32 if testType == "positive": - if argStr: - testStr = "{}_{}_{}_{}".format( - opName, shapeStr, typeStr, argStr - ) - else: - testStr = "{}_{}_{}".format( - opName, shapeStr, typeStr - ) - elif testType == "negative": - if argStr: - testStr = "{}_ERRORIF_{}_{}_{}_{}".format( - opName, error_name, shapeStr, typeStr, argStr - ) - else: - testStr = "{}_ERRORIF_{}_{}_{}".format( - opName, error_name, shapeStr, typeStr - ) + name_parts = [opName, shapeStr, typeStr] + else: + assert testType == "negative" + name_parts = [ + opName, + "ERRORIF", + error_name, + shapeStr, + typeStr, + ] + if argStr: + name_parts.append(argStr) + testStr = "_".join(name_parts) testList.append( (opName, testStr, t, error_name, shapeList, args) @@ -3038,8 +3080,18 @@ class TosaTestGen: # Build the random tensor operands and the test + # Set the random number generator + if self.args.stable_rng: + build_rng = TosaHashRandomGenerator( + self.random_seed, [testStr], self.random_dtype_range + ) + else: + build_rng = self.global_rng + if qgen is not None: - qinfo = qgen(self, op, dtype_or_dtypeList, error_name) + qinfo = qgen( + build_rng, self.args.zeropoint, op, dtype_or_dtypeList, error_name + ) else: qinfo = None @@ -3053,13 +3105,16 @@ class TosaTestGen: # New interface with args info in dictionary assert "dg_type" in argsDict - tvgInfo = tvgen_fcn(self, opName, dtypeList, shapeList, argsDict, error_name) + tvgInfo = tvgen_fcn( + self, build_rng, opName, dtypeList, shapeList, argsDict, error_name + ) if tvgInfo.dataGenDict: tensMeta["data_gen"] = tvgInfo.dataGenDict tens = tvgInfo.tensorList result = build_fcn( self, + build_rng, op, tens, argsDict, |