aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthew Haddon <matthew.haddon@arm.com>2021-09-24 14:42:13 +0100
committerMatthew Haddon <matthew.haddon@arm.com>2021-10-07 17:25:28 +0100
commiteacff9ae50b645ec9a293fd58082bacfdbe1e868 (patch)
treed1eea2fbd7ea584e82b962204d909ed905f32d19
parent693ba9ed076e3b9e95e484a27b087352d2bac157 (diff)
downloadreference_model-eacff9ae50b645ec9a293fd58082bacfdbe1e868.tar.gz
Add negative testing support to (most) EW Binary Ops
* Negative testing support for the following operators: ADD, BITWISE_AND, BITWISE_OR, BITWISE_XOR, INTDIV, LOGICAL_AND, LOGICAL_LEFT_SHIFT, LOGICAL_RIGHT_SHIFT, LOGICAL_OR, LOGICAL_XOR, MAXIMUM, MINIMUM, POW, SUB Change-Id: I2271f00b0b619604e864e36e4a4f987f1b2a37d4 Signed-off-by: Matthew Haddon <matthew.haddon@arm.com>
-rw-r--r--verif/tosa_error_if.py1
-rw-r--r--verif/tosa_test_gen.py148
2 files changed, 125 insertions, 24 deletions
diff --git a/verif/tosa_error_if.py b/verif/tosa_error_if.py
index 94648d3..c28591d 100644
--- a/verif/tosa_error_if.py
+++ b/verif/tosa_error_if.py
@@ -29,4 +29,5 @@ class ErrorIf(object):
WrongRank = "WrongRank"
BatchMismatch = "BatchMismatch"
ChannelMismatch = "ChannelMismatch"
+ RankMismatch = "RankMismatch"
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index 3cd1d69..f5f7fff 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -158,6 +158,12 @@ class TosaTensorGen:
for i in range(pl + const):
shape_list.append(shape.copy())
+ if error_name == ErrorIf.RankMismatch:
+ if rank == 1 and i != 1:
+ shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
+ elif i != 1:
+ shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))
+
return shape_list
@staticmethod
@@ -221,6 +227,13 @@ class TosaTensorGen:
for i in range(pl + const):
shape_bcast = shape.copy()
+ if error_name == ErrorIf.RankMismatch:
+ bcast_idx = -1 # Turn off broadcast because we are not testing it
+ if rank == 1 and i != 1:
+ shape_bcast = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
+ elif i != 1:
+ shape_bcast = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))
+
# If the chosen input, pick a random index to broadcast
if i == bcast_idx:
fuzz_idx = testGen.randInt(0, rank)
@@ -1224,16 +1237,21 @@ class TosaErrorValidator:
if check:
input_dtype = kwargs['input_dtype']
output_dtype = kwargs['output_dtype']
- mode = kwargs['mode']
-
- if (
- (mode == ResizeMode.NEAREST and input_dtype == DType.INT8 and output_dtype != DType.INT8) or
- (mode == ResizeMode.NEAREST and input_dtype == DType.INT16 and output_dtype != DType.INT16) or
- (mode == ResizeMode.BILINEAR and input_dtype == DType.INT8 and output_dtype != DType.INT32) or
- (mode == ResizeMode.BILINEAR and input_dtype == DType.INT16 and output_dtype != DType.INT48) or
- (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
- ):
- error_result = True
+ op = kwargs['op']
+
+ if op['op'] == Op.RESIZE:
+ mode = kwargs['mode']
+ if (
+ (mode == ResizeMode.NEAREST and input_dtype == DType.INT8 and output_dtype != DType.INT8) or
+ (mode == ResizeMode.NEAREST and input_dtype == DType.INT16 and output_dtype != DType.INT16) or
+ (mode == ResizeMode.BILINEAR and input_dtype == DType.INT8 and output_dtype != DType.INT32) or
+ (mode == ResizeMode.BILINEAR and input_dtype == DType.INT16 and output_dtype != DType.INT48) or
+ (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
+ ):
+ error_result = True
+ else:
+ if output_dtype != input_dtype:
+ error_result = True
info_dict = {
"error_name": error_name,
@@ -1605,6 +1623,29 @@ class TosaErrorValidator:
return info_dict
+ @staticmethod
+ def evRankMismatch(check=False, **kwargs):
+ error_name = ErrorIf.RankMismatch
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Input Rank does not match output rank"
+
+ if check:
+ input1_shape = kwargs['input1'].shape
+ input2_shape = kwargs['input2'].shape
+ output_shape = kwargs['result_tensor'].shape
+ if (len(input1_shape) != len(output_shape)) or (len(input2_shape) != len(output_shape)):
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs
+ }
+ return info_dict
+
+
class TosaInvalidValidator:
@staticmethod
@@ -1917,9 +1958,33 @@ class TosaTestGen:
self.ser.addOperator(op['op'], [a.name], [result_tens.name], None, qinfo)
return result_tens
- def build_binary_broadcast(self, op, a, b):
- result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)
- self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name])
+ def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
+ result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)
+
+
+ # Invalidate Input/Output list for error if checks.
+ input_list = [a.name, b.name]
+ output_list = [result_tens.name]
+ pCount, cCount = op["operands"]
+ num_operands = pCount + cCount
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+
+ TosaErrorValidator.evValidateErrorIfs(
+ self.ser,
+ validator_fcns,
+ error_name,
+ op=op,
+ input1 = a,
+ input2 = b,
+ input_dtype = a.dtype,
+ output_dtype = result_tens.dtype,
+ result_tensor = result_tens,
+ input_list=input_list,
+ output_list=output_list,
+ num_operands=num_operands,
+ )
+
+ self.ser.addOperator(op['op'], input_list, output_list)
return result_tens
def build_binary_nonbroadcast(self, op, a, b):
@@ -1928,7 +1993,7 @@ class TosaTestGen:
return result_tens
def build_arithmetic_right_shift(self, op, a, b, round):
- result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)
+ result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b)
attr = ts.TosaSerializerAttribute()
attr.ArithmeticRightShiftAttribute(round)
@@ -1937,7 +2002,7 @@ class TosaTestGen:
return result_tens
def build_mul(self, op, a, b, shift):
- result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)
+ result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b)
# Special for multiply:
# Force the result to INT32 for INT types
@@ -2716,7 +2781,7 @@ class TosaTestGen:
# Build the random tensor operands and the test
tens = []
- tens = self.generate_tensors(op, dtypeList, shapeList, testArgs)
+ tens = self.generate_tensors(op, dtypeList, shapeList, testArgs, error_name)
if qgen is not None:
qinfo = qgen(self, op, dtype_or_dtypeList)
@@ -2749,17 +2814,16 @@ class TosaTestGen:
self.serialize("test")
- def generate_tensors(self, op, dtypeList, shapeList, testArgs):
+ def generate_tensors(self, op, dtypeList, shapeList, testArgs, error_name=None):
pCount, cCount = op["operands"]
tens = []
- if (op["op"] == Op.ADD or op["op"] == Op.SUB) and dtypeList[0] == DType.INT32:
+ if (op["op"] == Op.ADD or op["op"] == Op.SUB) and dtypeList[0] == DType.INT32 and error_name == None:
# Make sure the operation does not cause value saturation - where
# the number wraps due to limited number of bits to store the answer
assert (
pCount == 2 and cCount == 0
), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
-
placeholders = []
add = (op["op"] == Op.ADD)
a_arr = self.getRandTensor(shapeList[0], dtypeList[0])
@@ -2840,7 +2904,7 @@ class TosaTestGen:
self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
)
tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
- elif op["op"] == Op.INTDIV:
+ elif op["op"] == Op.INTDIV and error_name == None:
assert (
pCount == 2 and cCount == 0
), "Op.INTDIV must have 2 placeholders, 0 consts"
@@ -3186,6 +3250,8 @@ class TosaTestGen:
"operands": (2, 0),
"build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_FI32,
+ "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
},
"arithmetic_right_shift": {
"op": Op.ARITHMETIC_RIGHT_SHIFT,
@@ -3202,66 +3268,88 @@ class TosaTestGen:
"operands": (2, 0),
"build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_INT,
+ "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
},
"bitwise_or": {
"op": Op.BITWISE_OR,
"operands": (2, 0),
"build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_INT,
+ "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
},
"bitwise_xor": {
"op": Op.BITWISE_XOR,
"operands": (2, 0),
"build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_INT,
+ "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
},
"intdiv": {
"op": Op.INTDIV,
"operands": (2, 0),
"build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
"types": [DType.INT32],
+ "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
},
"logical_and": {
"op": Op.LOGICAL_AND,
"operands": (2, 0),
"build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_BOOL,
+ "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
},
"logical_left_shift": {
"op": Op.LOGICAL_LEFT_SHIFT,
"operands": (2, 0),
"build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_INT,
+ "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
},
"logical_right_shift": {
"op": Op.LOGICAL_RIGHT_SHIFT,
"operands": (2, 0),
"build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_INT,
+ "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
},
"logical_or": {
"op": Op.LOGICAL_OR,
"operands": (2, 0),
"build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_BOOL,
+ "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
},
"logical_xor": {
"op": Op.LOGICAL_XOR,
"operands": (2, 0),
"build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_BOOL,
+ "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
},
"maximum": {
"op": Op.MAXIMUM,
"operands": (2, 0),
"build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_FI32,
+ "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
},
"minimum": {
"op": Op.MINIMUM,
"operands": (2, 0),
"build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_FI32,
+ "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
},
"mul": {
"op": Op.MUL,
@@ -3274,12 +3362,16 @@ class TosaTestGen:
"operands": (2, 0),
"build_fcn": (build_binary_broadcast, TosaTensorGen.tgBasic, None),
"types": TYPE_FP,
+ "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
},
"sub": {
"op": Op.SUB,
"operands": (2, 0),
"build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_FI32,
+ "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
},
"table": {
"op": Op.TABLE,
@@ -3577,18 +3669,26 @@ class OutputShaper:
# These methods return arguments that can be used for
# creating a new output tensor
@staticmethod
- def binaryBroadcastOp(ser, a, b):
- assert len(a.shape) == len(b.shape)
+ def binaryBroadcastOp(ser, rng, a, b, error_name=None):
+ if error_name != ErrorIf.RankMismatch:
+ assert len(a.shape) == len(b.shape)
assert a.dtype == b.dtype
shape = []
for i in range(len(a.shape)):
- if a.shape[i] == 1:
+ if a.shape[i] == 1 and error_name == None:
shape.append(b.shape[i])
else:
shape.append(a.shape[i])
- return ser.addOutput(shape, a.dtype)
+ if error_name == ErrorIf.WrongOutputType:
+ all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
+ wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
+ outputDType = rng.choice(wrong_dtypes)
+ else:
+ outputDType = a.dtype
+
+ return ser.addOutput(shape, outputDType)
@staticmethod
def binaryNonBroadcastOp(ser, a, b):