diff options
author | Jeremy Johnson <jeremy.johnson@arm.com> | 2024-02-08 11:45:44 +0000 |
---|---|---|
committer | Jeremy Johnson <jeremy.johnson@arm.com> | 2024-02-12 12:17:53 +0000 |
commit | 587cc84c2b8c4b0d030b5e257c9a32461c0969b9 (patch) | |
tree | 2465605214e9723636bad2645508300421967052 /verif/generator/tosa_test_gen.py | |
parent | 01e1c1c7f965ceb07e78a3b1ad063161c0f47b94 (diff) | |
download | reference_model-587cc84c2b8c4b0d030b5e257c9a32461c0969b9.tar.gz |
Update test builder internal interfaces
Move remaining ops from using testArgs to argsDict.
All tvg/build_fcn function interfaces updated.
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: Ie886fd931bd74608bda621363100bf8bfd7385e6
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r-- | verif/generator/tosa_test_gen.py | 237 |
1 files changed, 139 insertions, 98 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 2d471c0..4ead982 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -519,15 +519,18 @@ class TosaTestGen: return result_tens def build_arithmetic_right_shift( - self, op, a, b, round, validator_fcns=None, error_name=None + self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None ): - result_tens = OutputShaper.binaryBroadcastOp( + assert len(inputs) == 2 + a, b = inputs + round = args_dict["round"] + result_tensor = OutputShaper.binaryBroadcastOp( self.ser, self.rng, a, b, error_name ) # Invalidate Input/Output list for error if checks. input_list = [a.name, b.name] - output_list = [result_tens.name] + output_list = [result_tensor.name] pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( @@ -542,8 +545,8 @@ class TosaTestGen: input1=a, input2=b, input_dtype=a.dtype, - output_dtype=result_tens.dtype, - result_tensors=[result_tens], + output_dtype=result_tensor.dtype, + result_tensors=[result_tensor], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -554,7 +557,12 @@ class TosaTestGen: attr.ArithmeticRightShiftAttribute(round) self.ser.addOperator(op["op"], input_list, output_list, attr) - return result_tens + + compliance = self.tensorComplianceMetaData( + op, a.dtype, args_dict, result_tensor, error_name + ) + + return TosaTestGen.BuildInfo(result_tensor, compliance) def build_mul( self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None @@ -612,15 +620,20 @@ class TosaTestGen: return TosaTestGen.BuildInfo(result_tensor, compliance) - def build_table(self, op, a, table, validator_fcns=None, error_name=None): - result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name) + def build_table( + self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None + ): + assert len(inputs) == 1 + a = inputs[0] + table = args_dict["table"] + result_tensor = OutputShaper.tableOp(self.ser, self.rng, a, error_name) attr = ts.TosaSerializerAttribute() attr.TableAttribute(table) # Invalidate Input/Output list for error if checks. input_list = [a.name] - output_list = [result_tens.name] + output_list = [result_tensor.name] pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( @@ -634,8 +647,8 @@ class TosaTestGen: op=op, input_shape=a.shape, input_dtype=a.dtype, - output_dtype=result_tens.dtype, - result_tensors=[result_tens], + output_dtype=result_tensor.dtype, + result_tensors=[result_tensor], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -644,7 +657,11 @@ class TosaTestGen: self.ser.addOperator(op["op"], input_list, output_list, attr) - return result_tens + compliance = self.tensorComplianceMetaData( + op, a.dtype, args_dict, result_tensor, error_name + ) + + return TosaTestGen.BuildInfo(result_tensor, compliance) def build_select( self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None @@ -2075,15 +2092,20 @@ class TosaTestGen: def build_rescale( self, op, - val, - out_dtype, - scale32, - double_round, - per_channel, - validator_fcns, - error_name, + inputs, + args_dict, + validator_fcns=None, + error_name=None, + qinfo=None, ): - result_tens = OutputShaper.typeConversionOp( + assert len(inputs) == 1 + val = inputs[0] + out_dtype = args_dict["output_dtype"] + scale32 = args_dict["scale"] + double_round = args_dict["double_round"] + per_channel = args_dict["per_channel"] + + result_tensor = OutputShaper.typeConversionOp( self.ser, self.rng, val, out_dtype, error_name ) @@ -2203,7 +2225,7 @@ class TosaTestGen: # Invalidate Input/Output list for error if checks. input_list = [val.name] - output_list = [result_tens.name] + output_list = [result_tensor.name] pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( @@ -2224,7 +2246,7 @@ class TosaTestGen: double_round=double_round, input_list=input_list, output_list=output_list, - result_tensors=[result_tens], + result_tensors=[result_tensor], num_operands=num_operands, ): return None @@ -2243,7 +2265,12 @@ class TosaTestGen: ) self.ser.addOperator(op["op"], input_list, output_list, attr) - return result_tens + + compliance = self.tensorComplianceMetaData( + op, val.dtype, args_dict, result_tensor, error_name + ) + + return TosaTestGen.BuildInfo(result_tensor, compliance) def _get_condition_tensor(self, op, cond, error_name): if error_name == ErrorIf.CondIfCondNotMatchingBool: @@ -2263,11 +2290,21 @@ class TosaTestGen: return cond_tens def build_cond_if_const( - self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None + self, + op, + inputs, + args_dict, + validator_fcns=None, + error_name=None, + qinfo=None, ): # For cond_if with constants, we're supplied with then/else tensors that we ignore # (except for the generated shape) and the condition. Build Then/Else blocks # and fill them with const nodes for the body. + assert len(inputs) == 2 + then_tens, else_tens = inputs + + cond = args_dict["condition"] # Condition tensor cond_tens = self._get_condition_tensor(op, cond, error_name) @@ -2275,6 +2312,8 @@ class TosaTestGen: # Make then/else tensors out_shape = then_tens.shape + dtype = DType.INT32 + # Create an incorrect output shape for error_if tests if error_name in [ ErrorIf.CondIfOutputListThenGraphMismatch, @@ -2293,7 +2332,7 @@ class TosaTestGen: else_arr = np.int32(self.rng.integers(0, 256, size=out_shape)) # And the result tensor based on any of the outputs - result_tens = self.ser.addOutput(out_shape, DType.INT32) + result_tensor = self.ser.addOutput(out_shape, dtype) # Create the attribute with the names of the then/else blocks then_block = "THEN_BLOCK" @@ -2302,21 +2341,21 @@ class TosaTestGen: attr.CondIfAttribute(then_block, else_block) # Finally, build the op and the two blocks - self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr) + self.ser.addOperator(op["op"], [cond_tens.name], [result_tensor.name], attr) self.ser.addBasicBlock(then_block) # Build the actual then/else tensors inside their blocks if error_name == ErrorIf.CondIfOutputListThenGraphMismatch: - then_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr) + then_tens = self.ser.addConst(incorrect_shape, dtype, incorrect_arr) else: - then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr) + then_tens = self.ser.addConst(out_shape, dtype, then_arr) self.ser.addOutputTensor(then_tens) self.ser.addBasicBlock(else_block) if error_name == ErrorIf.CondIfOutputListElseGraphMismatch: - else_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr) + else_tens = self.ser.addConst(incorrect_shape, dtype, incorrect_arr) else: - else_tens = self.ser.addConst(out_shape, DType.INT32, else_arr) + else_tens = self.ser.addConst(out_shape, dtype, else_arr) self.ser.addOutputTensor(else_tens) if not TosaErrorValidator.evValidateErrorIfs( @@ -2329,18 +2368,32 @@ class TosaTestGen: ): return None - return result_tens + compliance = self.tensorComplianceMetaData( + op, dtype, args_dict, result_tensor, error_name + ) + + return TosaTestGen.BuildInfo(result_tensor, compliance) def build_cond_if_binary( - self, op, a, b, cond, validator_fcns=None, error_name=None + self, + op, + inputs, + args_dict, + validator_fcns=None, + error_name=None, + qinfo=None, ): # For cond_if with a binary op in the then/else blocks, take a and b and # alternately add or subtract them based on the condition + assert len(inputs) == 2 + a, b = inputs + + cond = args_dict["condition"] # Condition tensor cond_tens = self._get_condition_tensor(op, cond, error_name) - result_tens = self.ser.addOutput(a.shape, a.dtype) + result_tensor = self.ser.addOutput(a.shape, a.dtype) # Create the attribute with the names of the then/else blocks then_block = "THEN_BLOCK" @@ -2362,17 +2415,24 @@ class TosaTestGen: # Finally, build the op and the two blocks self.ser.addOperator( - op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr + op["op"], [cond_tens.name, a.name, b.name], [result_tensor.name], attr ) if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32): - then_op, else_op = Op.ADD, Op.SUB + then_op, else_op = self.TOSA_OP_LIST["add"], self.TOSA_OP_LIST["sub"] elif a.dtype in (DType.INT8, DType.INT16): - then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT + then_op, else_op = ( + self.TOSA_OP_LIST["logical_right_shift"], + self.TOSA_OP_LIST["logical_left_shift"], + ) else: assert False, f"No tests for DType: {a.dtype}" - for block, op in ((then_block, then_op), (else_block, else_op)): + # Determine the element-wise binary operation that compliance will need to + # check the results of + compliance_op = then_op if cond else else_op + + for block, block_op in ((then_block, then_op), (else_block, else_op)): self.ser.addBasicBlock(block) if ( error_name == ErrorIf.CondIfInputListThenGraphMismatch @@ -2398,7 +2458,7 @@ class TosaTestGen: self.ser.addInputTensor(a) self.ser.addInputTensor(b) tens = self.ser.addOutput(a.shape, a.dtype) - self.ser.addOperator(op, [a.name, b.name], [tens.name]) + self.ser.addOperator(block_op["op"], [a.name, b.name], [tens.name]) if not TosaErrorValidator.evValidateErrorIfs( self.ser, @@ -2412,9 +2472,19 @@ class TosaTestGen: ): return None - return result_tens + compliance = self.tensorComplianceMetaData( + compliance_op, a.dtype, args_dict, result_tensor, error_name + ) + + return TosaTestGen.BuildInfo(result_tensor, compliance) + + def build_while_loop( + self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None + ): + assert len(inputs) == 1 + a = inputs[0] + iter_val = args_dict["iterations"] - def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None): iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)]) cond_block = "COND_BLOCK" @@ -2533,7 +2603,11 @@ class TosaTestGen: ): return None - return acc_out + compliance = self.tensorComplianceMetaData( + op, a.dtype, args_dict, acc_out, error_name + ) + + return TosaTestGen.BuildInfo(acc_out, compliance) def build_fft2d( self, @@ -2891,7 +2965,7 @@ class TosaTestGen: return testList def serializeTest( - self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs + self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, argsDict ): try: op = self.TOSA_OP_LIST[opName] @@ -2947,60 +3021,27 @@ class TosaTestGen: # Extra meta data for the desc.json tensMeta = {} - # Check we are using the new testArgs interface with an argsDict dictionary - if isinstance(testArgs, dict): - # New interface with args info in dictionary - argsDict = testArgs - assert "dg_type" in argsDict - tvgInfo = tvgen_fcn( - self, opName, dtypeList, shapeList, argsDict, error_name - ) - if tvgInfo.dataGenDict: - tensMeta["data_gen"] = tvgInfo.dataGenDict - tens = tvgInfo.tensorList - - result = build_fcn( - self, - op, - tens, - argsDict, - validator_fcns=error_if_validators, - error_name=error_name, - qinfo=qinfo, - ) - else: - # Old interface with args info in a list - tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name) - - try: - if error_if_validators is None: - if qinfo is not None: - result = build_fcn(self, op, *tens, *testArgs, qinfo) - else: - result = build_fcn(self, op, *tens, *testArgs) - else: - if qinfo is not None: - result = build_fcn( - self, - op, - *tens, - *testArgs, - validator_fcns=error_if_validators, - error_name=error_name, - qinfo=qinfo, - ) - else: - result = build_fcn( - self, - op, - *tens, - *testArgs, - validator_fcns=error_if_validators, - error_name=error_name, - ) - except TypeError as e: - print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n") - raise e + # Check we are using the new interface with an argsDict dictionary + assert isinstance( + argsDict, dict + ), f"{opName} is not using new tvg/build_fcn interface" + + # New interface with args info in dictionary + assert "dg_type" in argsDict + tvgInfo = tvgen_fcn(self, opName, dtypeList, shapeList, argsDict, error_name) + if tvgInfo.dataGenDict: + tensMeta["data_gen"] = tvgInfo.dataGenDict + tens = tvgInfo.tensorList + + result = build_fcn( + self, + op, + tens, + argsDict, + validator_fcns=error_if_validators, + error_name=error_name, + qinfo=qinfo, + ) if result: # The test is valid, serialize it @@ -3874,7 +3915,7 @@ class TosaTestGen: "build_fcn": ( build_table, TosaTensorGen.tgBasic, - TosaTensorValuesGen.tvgDefault, + TosaTensorValuesGen.tvgLazyGenDefault, TosaArgGen.agTable, ), "types": [DType.INT8, DType.INT16], @@ -4686,7 +4727,7 @@ class TosaTestGen: "build_fcn": ( build_rescale, TosaTensorGen.tgBasic, - TosaTensorValuesGen.tvgDefault, + TosaTensorValuesGen.tvgLazyGenDefault, TosaArgGen.agRescale, ), "types": [ |