From 729b035e447dda1f16df1aaee64b47c394e24f96 Mon Sep 17 00:00:00 2001 From: Les Bell Date: Wed, 24 Nov 2021 10:28:21 +0000 Subject: Do not generate tests that fail validation checks Change-Id: I33237ebfd946b9ec91352c2b0dc6298cc113cd77 Signed-off-by: Les Bell --- verif/tosa_test_gen.py | 198 ++++++++++++++++++++++++++++++------------------- 1 file changed, 123 insertions(+), 75 deletions(-) diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py index 22886d6..655cdfc 100644 --- a/verif/tosa_test_gen.py +++ b/verif/tosa_test_gen.py @@ -1565,7 +1565,17 @@ class TosaErrorValidator: @staticmethod def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs): - # Check ERROR_IF statements + """Check ERROR_IF statements are caught and set the expected result. + + Args: + serializer: the serializer to set the expected result in + validator_fcns: a sequence of validator functions to verify the result + error_name: the name of the ERROR_IF condition to check for + kwargs: keyword arguments for the validator functions + Returns: + True if the result matches the expected result; otherwise False + """ + overall_result = True for val_fcn in validator_fcns: val_result = val_fcn(True, **kwargs) validator_name = val_result['error_name'] @@ -1574,6 +1584,7 @@ class TosaErrorValidator: # expect an error IFF the error_name and validator_name match expected_result = error_result == (error_name == validator_name) + overall_result &= expected_result if expected_result and error_result: serializer.setExpectedReturnCode(2, error_reason) @@ -1591,6 +1602,8 @@ class TosaErrorValidator: v = valueToName(DType, v) print(f' {k} = {v}') + return overall_result + @staticmethod def evWrongInputType(check=False, **kwargs): error_result = False @@ -3447,7 +3460,7 @@ class TosaTestGen: elif t == DType.BOOL: return 1 else: - raise Exception("Unknown dtype, cannot convert to string: {}".format(t)) + raise Exception(f"Unknown dtype, cannot determine width: {t}") # Argument generators # Returns a list of tuples (stringDescriptor, [build_fcn_arg_list]) @@ -3481,7 +3494,7 @@ class TosaTestGen: num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) - TosaErrorValidator.evValidateErrorIfs( + if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, @@ -3493,7 +3506,8 @@ class TosaTestGen: input_list=input_list, output_list=output_list, num_operands=num_operands, - ) + ): + return None self.ser.addOperator(op['op'], input_list, output_list, None, qinfo) return result_tens @@ -3509,7 +3523,7 @@ class TosaTestGen: num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) - TosaErrorValidator.evValidateErrorIfs( + if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, @@ -3522,7 +3536,8 @@ class TosaTestGen: input_list=input_list, output_list=output_list, num_operands=num_operands, - ) + ): + return None self.ser.addOperator(op['op'], input_list, output_list) return result_tens @@ -3542,7 +3557,7 @@ class TosaTestGen: num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) - TosaErrorValidator.evValidateErrorIfs( + if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, @@ -3555,7 +3570,8 @@ class TosaTestGen: input_list=input_list, output_list=output_list, num_operands=num_operands, - ) + ): + return None attr = ts.TosaSerializerAttribute() attr.ArithmeticRightShiftAttribute(round) @@ -3582,7 +3598,7 @@ class TosaTestGen: num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) - TosaErrorValidator.evValidateErrorIfs( + if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, @@ -3595,7 +3611,8 @@ class TosaTestGen: input_list=input_list, output_list=output_list, num_operands=num_operands, - ) + ): + return None attr = ts.TosaSerializerAttribute() attr.MulAttribute(shift) @@ -3616,7 +3633,7 @@ class TosaTestGen: num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) - TosaErrorValidator.evValidateErrorIfs( + if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, @@ -3628,7 +3645,8 @@ class TosaTestGen: input_list=input_list, output_list=output_list, num_operands=num_operands, - ) + ): + return None self.ser.addOperator(op['op'], input_list, output_list, attr) @@ -3644,7 +3662,7 @@ class TosaTestGen: num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) - TosaErrorValidator.evValidateErrorIfs( + if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, @@ -3659,7 +3677,8 @@ class TosaTestGen: input_list=input_list, output_list=output_list, num_operands=num_operands, - ) + ): + return None self.ser.addOperator(op['op'], input_list, output_list,) return result_tens @@ -3674,7 +3693,7 @@ class TosaTestGen: num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) - TosaErrorValidator.evValidateErrorIfs( + if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, @@ -3689,7 +3708,8 @@ class TosaTestGen: input_list=input_list, output_list=output_list, num_operands=num_operands, - ) + ): + return None self.ser.addOperator(op['op'], input_list, output_list,) return result_tens @@ -3704,7 +3724,7 @@ class TosaTestGen: num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) - TosaErrorValidator.evValidateErrorIfs( + if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, @@ -3718,7 +3738,8 @@ class TosaTestGen: input_list=input_list, output_list=output_list, num_operands=num_operands, - ) + ): + return None attr = ts.TosaSerializerAttribute() attr.AxisAttribute(axis) @@ -3744,7 +3765,7 @@ class TosaTestGen: num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) - TosaErrorValidator.evValidateErrorIfs( + if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, @@ -3761,7 +3782,8 @@ class TosaTestGen: input_list=input_list, output_list=output_list, num_operands=num_operands, - ) + ): + return None attr = ts.TosaSerializerAttribute() attr.PoolAttribute(kernel, stride, pad) @@ -3788,7 +3810,7 @@ class TosaTestGen: num_operands = sum(op["operands"]) input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) - TosaErrorValidator.evValidateErrorIfs( + if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, @@ -3804,7 +3826,8 @@ class TosaTestGen: stride=strides, dilation=dilations, input_shape=ifm.shape, - ) + ): + return None attr = ts.TosaSerializerAttribute() attr.ConvAttribute(padding, strides, dilations) @@ -3833,7 +3856,7 @@ class TosaTestGen: num_operands = sum(op["operands"]) input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) - TosaErrorValidator.evValidateErrorIfs( + if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, @@ -3849,7 +3872,8 @@ class TosaTestGen: stride=strides, dilation=dilations, input_shape=ifm.shape, - ) + ): + return None attr = ts.TosaSerializerAttribute() attr.ConvAttribute(padding, strides, dilations) @@ -3878,7 +3902,7 @@ class TosaTestGen: num_operands = sum(op["operands"]) input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) - TosaErrorValidator.evValidateErrorIfs( + if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, @@ -3894,7 +3918,8 @@ class TosaTestGen: stride=stride, dilation=dilation, input_shape=ifm.shape, - ) + ): + return None attr = ts.TosaSerializerAttribute() attr.TransposeConvAttribute(outpad, stride, dilation, output_shape) @@ -3924,7 +3949,7 @@ class TosaTestGen: num_operands = sum(op["operands"]) input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) - TosaErrorValidator.evValidateErrorIfs( + if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, @@ -3940,7 +3965,8 @@ class TosaTestGen: stride=strides, dilation=dilations, input_shape=ifm.shape, - ) + ): + return None attr = ts.TosaSerializerAttribute() attr.ConvAttribute(padding, strides, dilations) @@ -3960,7 +3986,7 @@ class TosaTestGen: num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) - TosaErrorValidator.evValidateErrorIfs( + if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, @@ -3975,7 +4001,8 @@ class TosaTestGen: input_list=input_list, output_list=output_list, num_operands=num_operands, - ) + ): + return None self.ser.addOperator( op['op'], input_list, output_list, None, qinfo @@ -3992,7 +4019,7 @@ class TosaTestGen: num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) - TosaErrorValidator.evValidateErrorIfs( + if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, @@ -4008,7 +4035,8 @@ class TosaTestGen: input_list=input_list, output_list=output_list, num_operands=num_operands, - ) + ): + return None self.ser.addOperator(op['op'], input_list, output_list, None, qinfo) return result_tens @@ -4023,7 +4051,7 @@ class TosaTestGen: num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) - TosaErrorValidator.evValidateErrorIfs( + if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, @@ -4037,7 +4065,8 @@ class TosaTestGen: input_list=input_list, output_list=output_list, num_operands=num_operands, - ) + ): + return None attr = ts.TosaSerializerAttribute() attr.AxisAttribute(axis) @@ -4067,7 +4096,7 @@ class TosaTestGen: num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) - TosaErrorValidator.evValidateErrorIfs( + if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, @@ -4082,7 +4111,8 @@ class TosaTestGen: input_list=input_list, output_list=output_list, num_operands=num_operands, - ) + ): + return None attr = ts.TosaSerializerAttribute() if a.dtype == DType.FLOAT: @@ -4119,7 +4149,7 @@ class TosaTestGen: num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) - TosaErrorValidator.evValidateErrorIfs( + if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, @@ -4132,7 +4162,8 @@ class TosaTestGen: input_list=input_list, output_list=output_list, num_operands=num_operands, - ) + ): + return None self.ser.addOperator(op['op'], input_list, output_list) return result_tens @@ -4147,7 +4178,7 @@ class TosaTestGen: num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) - TosaErrorValidator.evValidateErrorIfs( + if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, @@ -4160,7 +4191,8 @@ class TosaTestGen: input_list=input_list, output_list=output_list, num_operands=num_operands, - ) + ): + return None self.ser.addOperator(op['op'], input_list, output_list) return result_tens @@ -4186,7 +4218,7 @@ class TosaTestGen: num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) - TosaErrorValidator.evValidateErrorIfs( + if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, @@ -4201,7 +4233,8 @@ class TosaTestGen: input_list=input_list, output_list=output_list, num_operands=num_operands, - ) + ): + return None attr = ts.TosaSerializerAttribute() attr.AxisAttribute(axis) @@ -4223,7 +4256,7 @@ class TosaTestGen: num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) - TosaErrorValidator.evValidateErrorIfs( + if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, @@ -4238,7 +4271,8 @@ class TosaTestGen: input_list=input_list, output_list=output_list, num_operands=num_operands, - ) + ): + return None self.ser.addOperator( op['op'], input_list, output_list, attr, qinfo @@ -4255,7 +4289,7 @@ class TosaTestGen: num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) - TosaErrorValidator.evValidateErrorIfs( + if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, @@ -4268,7 +4302,8 @@ class TosaTestGen: input_list=input_list, output_list=output_list, num_operands=num_operands, - ) + ): + return None attr = ts.TosaSerializerAttribute() attr.ReshapeAttribute(newShape) @@ -4286,7 +4321,7 @@ class TosaTestGen: num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) - TosaErrorValidator.evValidateErrorIfs( + if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, @@ -4300,7 +4335,8 @@ class TosaTestGen: input_list=input_list, output_list=output_list, num_operands=num_operands, - ) + ): + return None attr = ts.TosaSerializerAttribute() attr.AxisAttribute(axis) @@ -4321,7 +4357,7 @@ class TosaTestGen: num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) - TosaErrorValidator.evValidateErrorIfs( + if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, @@ -4335,7 +4371,8 @@ class TosaTestGen: input_list=input_list, output_list=output_list, num_operands=num_operands, - ) + ): + return None self.ser.addOperator(op['op'], input_list, output_list, attr) @@ -4351,7 +4388,7 @@ class TosaTestGen: num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) - TosaErrorValidator.evValidateErrorIfs( + if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, @@ -4366,7 +4403,8 @@ class TosaTestGen: input_list=input_list, output_list=output_list, num_operands=num_operands, - ) + ): + return None attr = ts.TosaSerializerAttribute() attr.SliceAttribute(start, size) @@ -4384,7 +4422,7 @@ class TosaTestGen: num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) - TosaErrorValidator.evValidateErrorIfs( + if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, @@ -4397,7 +4435,8 @@ class TosaTestGen: input_list=input_list, output_list=output_list, num_operands=num_operands, - ) + ): + return None attr = ts.TosaSerializerAttribute() attr.TileAttribute(multiples) @@ -4428,7 +4467,7 @@ class TosaTestGen: num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) - TosaErrorValidator.evValidateErrorIfs( + if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, @@ -4441,7 +4480,8 @@ class TosaTestGen: input_list=input_list, output_list=output_list, num_operands=num_operands, - ) + ): + return None self.ser.addOperator(op['op'], input_list, output_list) @@ -4468,7 +4508,7 @@ class TosaTestGen: num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) - TosaErrorValidator.evValidateErrorIfs( + if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, @@ -4481,7 +4521,8 @@ class TosaTestGen: input_list=input_list, output_list=output_list, num_operands=num_operands, - ) + ): + return None self.ser.addOperator(op['op'], input_list, output_list) @@ -4527,7 +4568,7 @@ class TosaTestGen: num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) - TosaErrorValidator.evValidateErrorIfs( + if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, @@ -4546,7 +4587,8 @@ class TosaTestGen: output_list=output_list, result_tensor=result_tens, num_operands=num_operands, - ) + ): + return None attr = ts.TosaSerializerAttribute() @@ -4580,7 +4622,7 @@ class TosaTestGen: num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) - TosaErrorValidator.evValidateErrorIfs( + if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, @@ -4593,7 +4635,8 @@ class TosaTestGen: input_list=input_list, output_list=output_list, num_operands=num_operands, - ) + ): + return None self.ser.addOperator(op['op'], input_list, output_list) return result_tens @@ -4671,7 +4714,7 @@ class TosaTestGen: input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) qinfo = (input_zp, output_zp) - TosaErrorValidator.evValidateErrorIfs( + if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, @@ -4686,7 +4729,8 @@ class TosaTestGen: output_list=output_list, result_tensor=result_tens, num_operands=num_operands, - ) + ): + return None attr = ts.TosaSerializerAttribute() attr.RescaleAttribute( @@ -4750,13 +4794,14 @@ class TosaTestGen: else_tens = self.ser.addConst(out_shape, DType.INT32, else_arr) self.ser.addOutputTensor(else_tens) - TosaErrorValidator.evValidateErrorIfs( + if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, op=op, basicBlocks=self.ser.basicBlocks - ) + ): + return None return result_tens @@ -4814,7 +4859,7 @@ class TosaTestGen: tens = self.ser.addOutput(a.shape, a.dtype) self.ser.addOperator(op, [a.name, b.name], [tens.name]) - TosaErrorValidator.evValidateErrorIfs( + if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, @@ -4822,7 +4867,8 @@ class TosaTestGen: a=a, b=b, basicBlocks=self.ser.basicBlocks - ) + ): + return None return result_tens @@ -4917,13 +4963,14 @@ class TosaTestGen: self.ser.addOutputTensor(a) self.ser.addOutputTensor(acc_body_out) - TosaErrorValidator.evValidateErrorIfs( + if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, op=op, basicBlocks=self.ser.basicBlocks - ) + ): + return None return acc_out @@ -5156,11 +5203,12 @@ class TosaTestGen: print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n") raise e - if resultName is None: - print("Invalid ERROR_IF tests created") - - # Save the serialized test - self.serialize("test") + if resultName: + # The test is valid, serialize it + self.serialize("test") + else: + # The test is not valid + print(f"Invalid ERROR_IF test created: {opName} {testStr}") def generate_tensors(self, op, dtypeList, shapeList, testArgs, error_name=None): -- cgit v1.2.1