From 587cc84c2b8c4b0d030b5e257c9a32461c0969b9 Mon Sep 17 00:00:00 2001 From: Jeremy Johnson Date: Thu, 8 Feb 2024 11:45:44 +0000 Subject: Update test builder internal interfaces Move remaining ops from using testArgs to argsDict. All tvg/build_fcn function interfaces updated. Signed-off-by: Jeremy Johnson Change-Id: Ie886fd931bd74608bda621363100bf8bfd7385e6 --- verif/generator/tosa_arg_gen.py | 117 ++++++++++----- verif/generator/tosa_test_gen.py | 237 ++++++++++++++++++------------ verif/generator/tosa_verif_build_tests.py | 4 +- 3 files changed, 220 insertions(+), 138 deletions(-) diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index 33e74b5..7ec0cfe 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -636,18 +636,6 @@ class TosaTensorValuesGen: self.tensorList = tensorList self.dataGenDict = dataGenDict - @staticmethod - def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, error_name=None): - pCount, cCount = op["operands"] - - tens = [] - tens.extend( - testGen.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount]) - ) - tens.extend(testGen.buildConstTensors(shapeList[pCount:], dtypeList[pCount:])) - - return tens - # Default high value for random numbers TVG_FLOAT_HIGH_VALUE = { DType.FP32: (1 << 128) - (1 << (127 - 23)), @@ -969,7 +957,7 @@ class TosaTensorValuesGen: @staticmethod def tvgCondIfWhileLoop( - testGen, op, dtypeList, shapeList, testArgs, error_name=None + testGen, opName, dtypeList, shapeList, argsDict, error_name=None ): if dtypeList[0] in ( DType.INT32, @@ -979,9 +967,10 @@ class TosaTensorValuesGen: # Limit input tensors with cond_if_binary or while_loop to stop # saturation of add/sub ops with int32 and keep all logical shift # values between 0 to 31 for int16 or int8 + op = testGen.TOSA_OP_LIST[opName] pCount, cCount = op["operands"] pRemain = pCount - placeholders = [] + tens_ser_list = [] for idx, shape in enumerate(shapeList[:]): if dtypeList[0] == DType.INT32: arr = testGen.getRandTensor(shapeList[idx], DType.INT16) @@ -990,32 +979,33 @@ class TosaTensorValuesGen: testGen.rng.integers(low=0, high=32, size=shapeList[idx]) ) if pRemain > 0: - placeholders.append( + tens_ser_list.append( testGen.ser.addPlaceholder(shape, dtypeList[idx], arr) ) pRemain -= 1 else: - placeholders.append( + tens_ser_list.append( testGen.ser.addConst(shape, dtypeList[idx], arr) ) - return placeholders + return TosaTensorValuesGen.TVGInfo(tens_ser_list, None) else: - return TosaTensorValuesGen.tvgDefault( - testGen, op, dtypeList, shapeList, testArgs, error_name + return TosaTensorValuesGen.tvgLazyGenDefault( + testGen, opName, dtypeList, shapeList, argsDict, error_name ) @staticmethod def tvgArithmeticRightShift( - testGen, op, dtypeList, shapeList, testArgs, error_name=None + testGen, opName, dtypeList, shapeList, argsDict, error_name=None ): + op = testGen.TOSA_OP_LIST[opName] pCount, cCount = op["operands"] # Force value of operand[1] to be within [0, num_bits] assert ( pCount == 2 and cCount == 0 ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts" - placeholders = [] + tens_ser_list = [] for idx, shape in enumerate(shapeList[:]): if idx == 1: if dtypeList[idx] == DType.INT8: @@ -1030,23 +1020,23 @@ class TosaTensorValuesGen: raise Exception("OpArithmeticRightShift: invalid input dtype") else: arr = testGen.getRandTensor(shape, dtypeList[idx]) - placeholders.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)) + tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)) - return placeholders + return TosaTensorValuesGen.TVGInfo(tens_ser_list, None) @staticmethod - def tvgReshape(testGen, op, dtypeList, shapeList, argsDict, error_name=None): + def tvgReshape(testGen, opName, dtypeList, shapeList, argsDict, error_name=None): dtypeList[1] = DType.SHAPE shapeList[1] = [len(argsDict["new_shape"])] # Create a new list for the pre-generated data in argsDict["fixed_data"] argsDict["fixed_data"] = [None, argsDict["new_shape"]] return TosaTensorValuesGen.tvgLazyGenDefault( - testGen, op, dtypeList, shapeList, argsDict, error_name + testGen, opName, dtypeList, shapeList, argsDict, error_name ) @staticmethod - def tvgPad(testGen, op, dtypeList, shapeList, argsDict, error_name=None): + def tvgPad(testGen, opName, dtypeList, shapeList, argsDict, error_name=None): # argsDict["pad"] is 2D array, need to flatten it to get list of values pad_values = argsDict["pad"].flatten() dtypeList[1] = DType.SHAPE @@ -1055,11 +1045,11 @@ class TosaTensorValuesGen: argsDict["fixed_data"] = [None, pad_values] return TosaTensorValuesGen.tvgLazyGenDefault( - testGen, op, dtypeList, shapeList, argsDict, error_name + testGen, opName, dtypeList, shapeList, argsDict, error_name ) @staticmethod - def tvgSlice(testGen, op, dtypeList, shapeList, argsDict, error_name=None): + def tvgSlice(testGen, opName, dtypeList, shapeList, argsDict, error_name=None): dtypeList[1] = DType.SHAPE shapeList[1] = [len(argsDict["start"])] dtypeList[2] = DType.SHAPE @@ -1068,17 +1058,17 @@ class TosaTensorValuesGen: argsDict["fixed_data"] = [None, argsDict["start"], argsDict["size"]] return TosaTensorValuesGen.tvgLazyGenDefault( - testGen, op, dtypeList, shapeList, argsDict, error_name + testGen, opName, dtypeList, shapeList, argsDict, error_name ) @staticmethod - def tvgTile(testGen, op, dtypeList, shapeList, argsDict, error_name=None): + def tvgTile(testGen, opName, dtypeList, shapeList, argsDict, error_name=None): dtypeList[1] = DType.SHAPE shapeList[1] = [len(argsDict["multiples"])] argsDict["fixed_data"] = [None, argsDict["multiples"]] return TosaTensorValuesGen.tvgLazyGenDefault( - testGen, op, dtypeList, shapeList, argsDict, error_name + testGen, opName, dtypeList, shapeList, argsDict, error_name ) @staticmethod @@ -2776,10 +2766,23 @@ class TosaArgGen: int(double_round), int(per_channel), ), - [outDtype, scale32, double_round, per_channel], + { + "output_dtype": outDtype, + "scale": scale32, + "double_round": double_round, + "per_channel": per_channel, + }, ) ) + arg_list = TosaArgGen._add_data_generators( + testGen, + opName, + inDtype, + arg_list, + error_name, + ) + # Return list of tuples: (arg_str, args_dict) return arg_list @staticmethod @@ -2808,9 +2811,20 @@ class TosaArgGen: def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None): arg_list = [] - arg_list.append(("roundTrue", [True])) - arg_list.append(("roundFalse", [False])) + for round in (True, False): + args_dict = { + "round": round, + } + arg_list.append((f"round{round}", args_dict)) + arg_list = TosaArgGen._add_data_generators( + testGen, + opName, + dtype, + arg_list, + error_name, + ) + # Return list of tuples: (arg_str, args_dict) return arg_list @staticmethod @@ -3414,9 +3428,18 @@ class TosaArgGen: arg_list.append( ( "", - [table], + {"table": table}, ) ) + # Now add data generator types + arg_list = TosaArgGen._add_data_generators( + testGen, + opName, + dtype, + arg_list, + error_name, + ) + # Return list of tuples: (arg_str, args_dict) return arg_list def agCondIf(testGen, opName, shapeList, dtype, error_name=None): @@ -3426,15 +3449,33 @@ class TosaArgGen: arg_list = [] for c in [False, True]: - arg_list.append(("cond{}".format(int(c)), [c])) + arg_list.append(("cond{}".format(int(c)), {"condition": c})) + # Now add data generator types + arg_list = TosaArgGen._add_data_generators( + testGen, + opName, + dtype, + arg_list, + error_name, + ) + # Return list of tuples: (arg_str, args_dict) return arg_list def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None): # While loop: 0 iterations, 1, more than 1 arg_list = [] - for iter in [0, 1, 4]: - arg_list.append(("iter{}".format(iter), [iter])) + for iterations in [0, 1, 4]: + arg_list.append(("iter{}".format(iterations), {"iterations": iterations})) + # Now add data generator types + arg_list = TosaArgGen._add_data_generators( + testGen, + opName, + dtype, + arg_list, + error_name, + ) + # Return list of tuples: (arg_str, args_dict) return arg_list diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 2d471c0..4ead982 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -519,15 +519,18 @@ class TosaTestGen: return result_tens def build_arithmetic_right_shift( - self, op, a, b, round, validator_fcns=None, error_name=None + self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None ): - result_tens = OutputShaper.binaryBroadcastOp( + assert len(inputs) == 2 + a, b = inputs + round = args_dict["round"] + result_tensor = OutputShaper.binaryBroadcastOp( self.ser, self.rng, a, b, error_name ) # Invalidate Input/Output list for error if checks. input_list = [a.name, b.name] - output_list = [result_tens.name] + output_list = [result_tensor.name] pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( @@ -542,8 +545,8 @@ class TosaTestGen: input1=a, input2=b, input_dtype=a.dtype, - output_dtype=result_tens.dtype, - result_tensors=[result_tens], + output_dtype=result_tensor.dtype, + result_tensors=[result_tensor], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -554,7 +557,12 @@ class TosaTestGen: attr.ArithmeticRightShiftAttribute(round) self.ser.addOperator(op["op"], input_list, output_list, attr) - return result_tens + + compliance = self.tensorComplianceMetaData( + op, a.dtype, args_dict, result_tensor, error_name + ) + + return TosaTestGen.BuildInfo(result_tensor, compliance) def build_mul( self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None @@ -612,15 +620,20 @@ class TosaTestGen: return TosaTestGen.BuildInfo(result_tensor, compliance) - def build_table(self, op, a, table, validator_fcns=None, error_name=None): - result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name) + def build_table( + self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None + ): + assert len(inputs) == 1 + a = inputs[0] + table = args_dict["table"] + result_tensor = OutputShaper.tableOp(self.ser, self.rng, a, error_name) attr = ts.TosaSerializerAttribute() attr.TableAttribute(table) # Invalidate Input/Output list for error if checks. input_list = [a.name] - output_list = [result_tens.name] + output_list = [result_tensor.name] pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( @@ -634,8 +647,8 @@ class TosaTestGen: op=op, input_shape=a.shape, input_dtype=a.dtype, - output_dtype=result_tens.dtype, - result_tensors=[result_tens], + output_dtype=result_tensor.dtype, + result_tensors=[result_tensor], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -644,7 +657,11 @@ class TosaTestGen: self.ser.addOperator(op["op"], input_list, output_list, attr) - return result_tens + compliance = self.tensorComplianceMetaData( + op, a.dtype, args_dict, result_tensor, error_name + ) + + return TosaTestGen.BuildInfo(result_tensor, compliance) def build_select( self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None @@ -2075,15 +2092,20 @@ class TosaTestGen: def build_rescale( self, op, - val, - out_dtype, - scale32, - double_round, - per_channel, - validator_fcns, - error_name, + inputs, + args_dict, + validator_fcns=None, + error_name=None, + qinfo=None, ): - result_tens = OutputShaper.typeConversionOp( + assert len(inputs) == 1 + val = inputs[0] + out_dtype = args_dict["output_dtype"] + scale32 = args_dict["scale"] + double_round = args_dict["double_round"] + per_channel = args_dict["per_channel"] + + result_tensor = OutputShaper.typeConversionOp( self.ser, self.rng, val, out_dtype, error_name ) @@ -2203,7 +2225,7 @@ class TosaTestGen: # Invalidate Input/Output list for error if checks. input_list = [val.name] - output_list = [result_tens.name] + output_list = [result_tensor.name] pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( @@ -2224,7 +2246,7 @@ class TosaTestGen: double_round=double_round, input_list=input_list, output_list=output_list, - result_tensors=[result_tens], + result_tensors=[result_tensor], num_operands=num_operands, ): return None @@ -2243,7 +2265,12 @@ class TosaTestGen: ) self.ser.addOperator(op["op"], input_list, output_list, attr) - return result_tens + + compliance = self.tensorComplianceMetaData( + op, val.dtype, args_dict, result_tensor, error_name + ) + + return TosaTestGen.BuildInfo(result_tensor, compliance) def _get_condition_tensor(self, op, cond, error_name): if error_name == ErrorIf.CondIfCondNotMatchingBool: @@ -2263,11 +2290,21 @@ class TosaTestGen: return cond_tens def build_cond_if_const( - self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None + self, + op, + inputs, + args_dict, + validator_fcns=None, + error_name=None, + qinfo=None, ): # For cond_if with constants, we're supplied with then/else tensors that we ignore # (except for the generated shape) and the condition. Build Then/Else blocks # and fill them with const nodes for the body. + assert len(inputs) == 2 + then_tens, else_tens = inputs + + cond = args_dict["condition"] # Condition tensor cond_tens = self._get_condition_tensor(op, cond, error_name) @@ -2275,6 +2312,8 @@ class TosaTestGen: # Make then/else tensors out_shape = then_tens.shape + dtype = DType.INT32 + # Create an incorrect output shape for error_if tests if error_name in [ ErrorIf.CondIfOutputListThenGraphMismatch, @@ -2293,7 +2332,7 @@ class TosaTestGen: else_arr = np.int32(self.rng.integers(0, 256, size=out_shape)) # And the result tensor based on any of the outputs - result_tens = self.ser.addOutput(out_shape, DType.INT32) + result_tensor = self.ser.addOutput(out_shape, dtype) # Create the attribute with the names of the then/else blocks then_block = "THEN_BLOCK" @@ -2302,21 +2341,21 @@ class TosaTestGen: attr.CondIfAttribute(then_block, else_block) # Finally, build the op and the two blocks - self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr) + self.ser.addOperator(op["op"], [cond_tens.name], [result_tensor.name], attr) self.ser.addBasicBlock(then_block) # Build the actual then/else tensors inside their blocks if error_name == ErrorIf.CondIfOutputListThenGraphMismatch: - then_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr) + then_tens = self.ser.addConst(incorrect_shape, dtype, incorrect_arr) else: - then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr) + then_tens = self.ser.addConst(out_shape, dtype, then_arr) self.ser.addOutputTensor(then_tens) self.ser.addBasicBlock(else_block) if error_name == ErrorIf.CondIfOutputListElseGraphMismatch: - else_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr) + else_tens = self.ser.addConst(incorrect_shape, dtype, incorrect_arr) else: - else_tens = self.ser.addConst(out_shape, DType.INT32, else_arr) + else_tens = self.ser.addConst(out_shape, dtype, else_arr) self.ser.addOutputTensor(else_tens) if not TosaErrorValidator.evValidateErrorIfs( @@ -2329,18 +2368,32 @@ class TosaTestGen: ): return None - return result_tens + compliance = self.tensorComplianceMetaData( + op, dtype, args_dict, result_tensor, error_name + ) + + return TosaTestGen.BuildInfo(result_tensor, compliance) def build_cond_if_binary( - self, op, a, b, cond, validator_fcns=None, error_name=None + self, + op, + inputs, + args_dict, + validator_fcns=None, + error_name=None, + qinfo=None, ): # For cond_if with a binary op in the then/else blocks, take a and b and # alternately add or subtract them based on the condition + assert len(inputs) == 2 + a, b = inputs + + cond = args_dict["condition"] # Condition tensor cond_tens = self._get_condition_tensor(op, cond, error_name) - result_tens = self.ser.addOutput(a.shape, a.dtype) + result_tensor = self.ser.addOutput(a.shape, a.dtype) # Create the attribute with the names of the then/else blocks then_block = "THEN_BLOCK" @@ -2362,17 +2415,24 @@ class TosaTestGen: # Finally, build the op and the two blocks self.ser.addOperator( - op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr + op["op"], [cond_tens.name, a.name, b.name], [result_tensor.name], attr ) if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32): - then_op, else_op = Op.ADD, Op.SUB + then_op, else_op = self.TOSA_OP_LIST["add"], self.TOSA_OP_LIST["sub"] elif a.dtype in (DType.INT8, DType.INT16): - then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT + then_op, else_op = ( + self.TOSA_OP_LIST["logical_right_shift"], + self.TOSA_OP_LIST["logical_left_shift"], + ) else: assert False, f"No tests for DType: {a.dtype}" - for block, op in ((then_block, then_op), (else_block, else_op)): + # Determine the element-wise binary operation that compliance will need to + # check the results of + compliance_op = then_op if cond else else_op + + for block, block_op in ((then_block, then_op), (else_block, else_op)): self.ser.addBasicBlock(block) if ( error_name == ErrorIf.CondIfInputListThenGraphMismatch @@ -2398,7 +2458,7 @@ class TosaTestGen: self.ser.addInputTensor(a) self.ser.addInputTensor(b) tens = self.ser.addOutput(a.shape, a.dtype) - self.ser.addOperator(op, [a.name, b.name], [tens.name]) + self.ser.addOperator(block_op["op"], [a.name, b.name], [tens.name]) if not TosaErrorValidator.evValidateErrorIfs( self.ser, @@ -2412,9 +2472,19 @@ class TosaTestGen: ): return None - return result_tens + compliance = self.tensorComplianceMetaData( + compliance_op, a.dtype, args_dict, result_tensor, error_name + ) + + return TosaTestGen.BuildInfo(result_tensor, compliance) + + def build_while_loop( + self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None + ): + assert len(inputs) == 1 + a = inputs[0] + iter_val = args_dict["iterations"] - def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None): iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)]) cond_block = "COND_BLOCK" @@ -2533,7 +2603,11 @@ class TosaTestGen: ): return None - return acc_out + compliance = self.tensorComplianceMetaData( + op, a.dtype, args_dict, acc_out, error_name + ) + + return TosaTestGen.BuildInfo(acc_out, compliance) def build_fft2d( self, @@ -2891,7 +2965,7 @@ class TosaTestGen: return testList def serializeTest( - self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs + self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, argsDict ): try: op = self.TOSA_OP_LIST[opName] @@ -2947,60 +3021,27 @@ class TosaTestGen: # Extra meta data for the desc.json tensMeta = {} - # Check we are using the new testArgs interface with an argsDict dictionary - if isinstance(testArgs, dict): - # New interface with args info in dictionary - argsDict = testArgs - assert "dg_type" in argsDict - tvgInfo = tvgen_fcn( - self, opName, dtypeList, shapeList, argsDict, error_name - ) - if tvgInfo.dataGenDict: - tensMeta["data_gen"] = tvgInfo.dataGenDict - tens = tvgInfo.tensorList - - result = build_fcn( - self, - op, - tens, - argsDict, - validator_fcns=error_if_validators, - error_name=error_name, - qinfo=qinfo, - ) - else: - # Old interface with args info in a list - tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name) - - try: - if error_if_validators is None: - if qinfo is not None: - result = build_fcn(self, op, *tens, *testArgs, qinfo) - else: - result = build_fcn(self, op, *tens, *testArgs) - else: - if qinfo is not None: - result = build_fcn( - self, - op, - *tens, - *testArgs, - validator_fcns=error_if_validators, - error_name=error_name, - qinfo=qinfo, - ) - else: - result = build_fcn( - self, - op, - *tens, - *testArgs, - validator_fcns=error_if_validators, - error_name=error_name, - ) - except TypeError as e: - print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n") - raise e + # Check we are using the new interface with an argsDict dictionary + assert isinstance( + argsDict, dict + ), f"{opName} is not using new tvg/build_fcn interface" + + # New interface with args info in dictionary + assert "dg_type" in argsDict + tvgInfo = tvgen_fcn(self, opName, dtypeList, shapeList, argsDict, error_name) + if tvgInfo.dataGenDict: + tensMeta["data_gen"] = tvgInfo.dataGenDict + tens = tvgInfo.tensorList + + result = build_fcn( + self, + op, + tens, + argsDict, + validator_fcns=error_if_validators, + error_name=error_name, + qinfo=qinfo, + ) if result: # The test is valid, serialize it @@ -3874,7 +3915,7 @@ class TosaTestGen: "build_fcn": ( build_table, TosaTensorGen.tgBasic, - TosaTensorValuesGen.tvgDefault, + TosaTensorValuesGen.tvgLazyGenDefault, TosaArgGen.agTable, ), "types": [DType.INT8, DType.INT16], @@ -4686,7 +4727,7 @@ class TosaTestGen: "build_fcn": ( build_rescale, TosaTensorGen.tgBasic, - TosaTensorValuesGen.tvgDefault, + TosaTensorValuesGen.tvgLazyGenDefault, TosaArgGen.agRescale, ), "types": [ diff --git a/verif/generator/tosa_verif_build_tests.py b/verif/generator/tosa_verif_build_tests.py index d01e8a7..8012d93 100644 --- a/verif/generator/tosa_verif_build_tests.py +++ b/verif/generator/tosa_verif_build_tests.py @@ -321,7 +321,7 @@ def main(argv=None): testStrings = [] try: - for opName, testStr, dtype, error, shapeList, testArgs in testList: + for opName, testStr, dtype, error, shapeList, argsDict in testList: # Check for and skip duplicate tests if testStr in testStrings: print(f"Skipping duplicate test: {testStr}") @@ -331,7 +331,7 @@ def main(argv=None): results.append( ttg.serializeTest( - opName, testStr, dtype, error, shapeList, testArgs + opName, testStr, dtype, error, shapeList, argsDict ) ) except Exception as e: -- cgit v1.2.1