aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLes Bell <les.bell@arm.com>2021-11-24 10:28:21 +0000
committerLes Bell <les.bell@arm.com>2021-11-24 10:30:03 +0000
commit729b035e447dda1f16df1aaee64b47c394e24f96 (patch)
tree082cd224ef9ab09f8096017ff8d87e32287d99fe
parent3ca02a7521b89deccbe4a2e851bb7d66fcfc93e8 (diff)
downloadreference_model-729b035e447dda1f16df1aaee64b47c394e24f96.tar.gz
Do not generate tests that fail validation checks
Change-Id: I33237ebfd946b9ec91352c2b0dc6298cc113cd77 Signed-off-by: Les Bell <les.bell@arm.com>
-rw-r--r--verif/tosa_test_gen.py198
1 files changed, 123 insertions, 75 deletions
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index 22886d6..655cdfc 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -1565,7 +1565,17 @@ class TosaErrorValidator:
@staticmethod
def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
- # Check ERROR_IF statements
+ """Check ERROR_IF statements are caught and set the expected result.
+
+ Args:
+ serializer: the serializer to set the expected result in
+ validator_fcns: a sequence of validator functions to verify the result
+ error_name: the name of the ERROR_IF condition to check for
+ kwargs: keyword arguments for the validator functions
+ Returns:
+ True if the result matches the expected result; otherwise False
+ """
+ overall_result = True
for val_fcn in validator_fcns:
val_result = val_fcn(True, **kwargs)
validator_name = val_result['error_name']
@@ -1574,6 +1584,7 @@ class TosaErrorValidator:
# expect an error IFF the error_name and validator_name match
expected_result = error_result == (error_name == validator_name)
+ overall_result &= expected_result
if expected_result and error_result:
serializer.setExpectedReturnCode(2, error_reason)
@@ -1591,6 +1602,8 @@ class TosaErrorValidator:
v = valueToName(DType, v)
print(f' {k} = {v}')
+ return overall_result
+
@staticmethod
def evWrongInputType(check=False, **kwargs):
error_result = False
@@ -3447,7 +3460,7 @@ class TosaTestGen:
elif t == DType.BOOL:
return 1
else:
- raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
+ raise Exception(f"Unknown dtype, cannot determine width: {t}")
# Argument generators
# Returns a list of tuples (stringDescriptor, [build_fcn_arg_list])
@@ -3481,7 +3494,7 @@ class TosaTestGen:
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
- TosaErrorValidator.evValidateErrorIfs(
+ if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
@@ -3493,7 +3506,8 @@ class TosaTestGen:
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
- )
+ ):
+ return None
self.ser.addOperator(op['op'], input_list, output_list, None, qinfo)
return result_tens
@@ -3509,7 +3523,7 @@ class TosaTestGen:
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
- TosaErrorValidator.evValidateErrorIfs(
+ if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
@@ -3522,7 +3536,8 @@ class TosaTestGen:
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
- )
+ ):
+ return None
self.ser.addOperator(op['op'], input_list, output_list)
return result_tens
@@ -3542,7 +3557,7 @@ class TosaTestGen:
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
- TosaErrorValidator.evValidateErrorIfs(
+ if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
@@ -3555,7 +3570,8 @@ class TosaTestGen:
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
- )
+ ):
+ return None
attr = ts.TosaSerializerAttribute()
attr.ArithmeticRightShiftAttribute(round)
@@ -3582,7 +3598,7 @@ class TosaTestGen:
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
- TosaErrorValidator.evValidateErrorIfs(
+ if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
@@ -3595,7 +3611,8 @@ class TosaTestGen:
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
- )
+ ):
+ return None
attr = ts.TosaSerializerAttribute()
attr.MulAttribute(shift)
@@ -3616,7 +3633,7 @@ class TosaTestGen:
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
- TosaErrorValidator.evValidateErrorIfs(
+ if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
@@ -3628,7 +3645,8 @@ class TosaTestGen:
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
- )
+ ):
+ return None
self.ser.addOperator(op['op'], input_list, output_list, attr)
@@ -3644,7 +3662,7 @@ class TosaTestGen:
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
- TosaErrorValidator.evValidateErrorIfs(
+ if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
@@ -3659,7 +3677,8 @@ class TosaTestGen:
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
- )
+ ):
+ return None
self.ser.addOperator(op['op'], input_list, output_list,)
return result_tens
@@ -3674,7 +3693,7 @@ class TosaTestGen:
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
- TosaErrorValidator.evValidateErrorIfs(
+ if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
@@ -3689,7 +3708,8 @@ class TosaTestGen:
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
- )
+ ):
+ return None
self.ser.addOperator(op['op'], input_list, output_list,)
return result_tens
@@ -3704,7 +3724,7 @@ class TosaTestGen:
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
- TosaErrorValidator.evValidateErrorIfs(
+ if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
@@ -3718,7 +3738,8 @@ class TosaTestGen:
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
- )
+ ):
+ return None
attr = ts.TosaSerializerAttribute()
attr.AxisAttribute(axis)
@@ -3744,7 +3765,7 @@ class TosaTestGen:
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
- TosaErrorValidator.evValidateErrorIfs(
+ if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
@@ -3761,7 +3782,8 @@ class TosaTestGen:
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
- )
+ ):
+ return None
attr = ts.TosaSerializerAttribute()
attr.PoolAttribute(kernel, stride, pad)
@@ -3788,7 +3810,7 @@ class TosaTestGen:
num_operands = sum(op["operands"])
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
- TosaErrorValidator.evValidateErrorIfs(
+ if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
@@ -3804,7 +3826,8 @@ class TosaTestGen:
stride=strides,
dilation=dilations,
input_shape=ifm.shape,
- )
+ ):
+ return None
attr = ts.TosaSerializerAttribute()
attr.ConvAttribute(padding, strides, dilations)
@@ -3833,7 +3856,7 @@ class TosaTestGen:
num_operands = sum(op["operands"])
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
- TosaErrorValidator.evValidateErrorIfs(
+ if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
@@ -3849,7 +3872,8 @@ class TosaTestGen:
stride=strides,
dilation=dilations,
input_shape=ifm.shape,
- )
+ ):
+ return None
attr = ts.TosaSerializerAttribute()
attr.ConvAttribute(padding, strides, dilations)
@@ -3878,7 +3902,7 @@ class TosaTestGen:
num_operands = sum(op["operands"])
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
- TosaErrorValidator.evValidateErrorIfs(
+ if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
@@ -3894,7 +3918,8 @@ class TosaTestGen:
stride=stride,
dilation=dilation,
input_shape=ifm.shape,
- )
+ ):
+ return None
attr = ts.TosaSerializerAttribute()
attr.TransposeConvAttribute(outpad, stride, dilation, output_shape)
@@ -3924,7 +3949,7 @@ class TosaTestGen:
num_operands = sum(op["operands"])
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
- TosaErrorValidator.evValidateErrorIfs(
+ if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
@@ -3940,7 +3965,8 @@ class TosaTestGen:
stride=strides,
dilation=dilations,
input_shape=ifm.shape,
- )
+ ):
+ return None
attr = ts.TosaSerializerAttribute()
attr.ConvAttribute(padding, strides, dilations)
@@ -3960,7 +3986,7 @@ class TosaTestGen:
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
- TosaErrorValidator.evValidateErrorIfs(
+ if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
@@ -3975,7 +4001,8 @@ class TosaTestGen:
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
- )
+ ):
+ return None
self.ser.addOperator(
op['op'], input_list, output_list, None, qinfo
@@ -3992,7 +4019,7 @@ class TosaTestGen:
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
- TosaErrorValidator.evValidateErrorIfs(
+ if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
@@ -4008,7 +4035,8 @@ class TosaTestGen:
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
- )
+ ):
+ return None
self.ser.addOperator(op['op'], input_list, output_list, None, qinfo)
return result_tens
@@ -4023,7 +4051,7 @@ class TosaTestGen:
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
- TosaErrorValidator.evValidateErrorIfs(
+ if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
@@ -4037,7 +4065,8 @@ class TosaTestGen:
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
- )
+ ):
+ return None
attr = ts.TosaSerializerAttribute()
attr.AxisAttribute(axis)
@@ -4067,7 +4096,7 @@ class TosaTestGen:
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
- TosaErrorValidator.evValidateErrorIfs(
+ if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
@@ -4082,7 +4111,8 @@ class TosaTestGen:
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
- )
+ ):
+ return None
attr = ts.TosaSerializerAttribute()
if a.dtype == DType.FLOAT:
@@ -4119,7 +4149,7 @@ class TosaTestGen:
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
- TosaErrorValidator.evValidateErrorIfs(
+ if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
@@ -4132,7 +4162,8 @@ class TosaTestGen:
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
- )
+ ):
+ return None
self.ser.addOperator(op['op'], input_list, output_list)
return result_tens
@@ -4147,7 +4178,7 @@ class TosaTestGen:
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
- TosaErrorValidator.evValidateErrorIfs(
+ if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
@@ -4160,7 +4191,8 @@ class TosaTestGen:
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
- )
+ ):
+ return None
self.ser.addOperator(op['op'], input_list, output_list)
return result_tens
@@ -4186,7 +4218,7 @@ class TosaTestGen:
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
- TosaErrorValidator.evValidateErrorIfs(
+ if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
@@ -4201,7 +4233,8 @@ class TosaTestGen:
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
- )
+ ):
+ return None
attr = ts.TosaSerializerAttribute()
attr.AxisAttribute(axis)
@@ -4223,7 +4256,7 @@ class TosaTestGen:
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
- TosaErrorValidator.evValidateErrorIfs(
+ if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
@@ -4238,7 +4271,8 @@ class TosaTestGen:
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
- )
+ ):
+ return None
self.ser.addOperator(
op['op'], input_list, output_list, attr, qinfo
@@ -4255,7 +4289,7 @@ class TosaTestGen:
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
- TosaErrorValidator.evValidateErrorIfs(
+ if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
@@ -4268,7 +4302,8 @@ class TosaTestGen:
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
- )
+ ):
+ return None
attr = ts.TosaSerializerAttribute()
attr.ReshapeAttribute(newShape)
@@ -4286,7 +4321,7 @@ class TosaTestGen:
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
- TosaErrorValidator.evValidateErrorIfs(
+ if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
@@ -4300,7 +4335,8 @@ class TosaTestGen:
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
- )
+ ):
+ return None
attr = ts.TosaSerializerAttribute()
attr.AxisAttribute(axis)
@@ -4321,7 +4357,7 @@ class TosaTestGen:
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
- TosaErrorValidator.evValidateErrorIfs(
+ if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
@@ -4335,7 +4371,8 @@ class TosaTestGen:
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
- )
+ ):
+ return None
self.ser.addOperator(op['op'], input_list, output_list, attr)
@@ -4351,7 +4388,7 @@ class TosaTestGen:
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
- TosaErrorValidator.evValidateErrorIfs(
+ if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
@@ -4366,7 +4403,8 @@ class TosaTestGen:
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
- )
+ ):
+ return None
attr = ts.TosaSerializerAttribute()
attr.SliceAttribute(start, size)
@@ -4384,7 +4422,7 @@ class TosaTestGen:
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
- TosaErrorValidator.evValidateErrorIfs(
+ if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
@@ -4397,7 +4435,8 @@ class TosaTestGen:
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
- )
+ ):
+ return None
attr = ts.TosaSerializerAttribute()
attr.TileAttribute(multiples)
@@ -4428,7 +4467,7 @@ class TosaTestGen:
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
- TosaErrorValidator.evValidateErrorIfs(
+ if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
@@ -4441,7 +4480,8 @@ class TosaTestGen:
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
- )
+ ):
+ return None
self.ser.addOperator(op['op'], input_list, output_list)
@@ -4468,7 +4508,7 @@ class TosaTestGen:
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
- TosaErrorValidator.evValidateErrorIfs(
+ if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
@@ -4481,7 +4521,8 @@ class TosaTestGen:
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
- )
+ ):
+ return None
self.ser.addOperator(op['op'], input_list, output_list)
@@ -4527,7 +4568,7 @@ class TosaTestGen:
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
- TosaErrorValidator.evValidateErrorIfs(
+ if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
@@ -4546,7 +4587,8 @@ class TosaTestGen:
output_list=output_list,
result_tensor=result_tens,
num_operands=num_operands,
- )
+ ):
+ return None
attr = ts.TosaSerializerAttribute()
@@ -4580,7 +4622,7 @@ class TosaTestGen:
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
- TosaErrorValidator.evValidateErrorIfs(
+ if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
@@ -4593,7 +4635,8 @@ class TosaTestGen:
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
- )
+ ):
+ return None
self.ser.addOperator(op['op'], input_list, output_list)
return result_tens
@@ -4671,7 +4714,7 @@ class TosaTestGen:
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
qinfo = (input_zp, output_zp)
- TosaErrorValidator.evValidateErrorIfs(
+ if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
@@ -4686,7 +4729,8 @@ class TosaTestGen:
output_list=output_list,
result_tensor=result_tens,
num_operands=num_operands,
- )
+ ):
+ return None
attr = ts.TosaSerializerAttribute()
attr.RescaleAttribute(
@@ -4750,13 +4794,14 @@ class TosaTestGen:
else_tens = self.ser.addConst(out_shape, DType.INT32, else_arr)
self.ser.addOutputTensor(else_tens)
- TosaErrorValidator.evValidateErrorIfs(
+ if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
basicBlocks=self.ser.basicBlocks
- )
+ ):
+ return None
return result_tens
@@ -4814,7 +4859,7 @@ class TosaTestGen:
tens = self.ser.addOutput(a.shape, a.dtype)
self.ser.addOperator(op, [a.name, b.name], [tens.name])
- TosaErrorValidator.evValidateErrorIfs(
+ if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
@@ -4822,7 +4867,8 @@ class TosaTestGen:
a=a,
b=b,
basicBlocks=self.ser.basicBlocks
- )
+ ):
+ return None
return result_tens
@@ -4917,13 +4963,14 @@ class TosaTestGen:
self.ser.addOutputTensor(a)
self.ser.addOutputTensor(acc_body_out)
- TosaErrorValidator.evValidateErrorIfs(
+ if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
basicBlocks=self.ser.basicBlocks
- )
+ ):
+ return None
return acc_out
@@ -5156,11 +5203,12 @@ class TosaTestGen:
print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
raise e
- if resultName is None:
- print("Invalid ERROR_IF tests created")
-
- # Save the serialized test
- self.serialize("test")
+ if resultName:
+ # The test is valid, serialize it
+ self.serialize("test")
+ else:
+ # The test is not valid
+ print(f"Invalid ERROR_IF test created: {opName} {testStr}")
def generate_tensors(self, op, dtypeList, shapeList, testArgs, error_name=None):