aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2021-11-08 18:10:51 +0000
committerJeremy Johnson <jeremy.johnson@arm.com>2021-11-11 08:43:51 +0000
commit7e9ac9ab74ff2a793e226abf86d2543d1421d3c9 (patch)
treec1fa80639d5d21eefdd1cd13138d1fb91c7222b5
parent1c3c847a4368817e2c9e3af66d5deb4c67993cbc (diff)
downloadreference_model-7e9ac9ab74ff2a793e226abf86d2543d1421d3c9.tar.gz
Add Broadcast DimensionMismatch errors
Add RankMismatch and DimensionMismatch support for SELECT Update RankMismatch ops to also support DimensionMismatch Update POW op to have proper broadcast testing A few other broadcastable ops missing Rank/Dimension testing Change-Id: I6566f45a7a0db4f9f008456ea7a8e23d4192f4f9 Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
-rw-r--r--verif/tosa_error_if.py1
-rw-r--r--verif/tosa_test_gen.py132
2 files changed, 92 insertions, 41 deletions
diff --git a/verif/tosa_error_if.py b/verif/tosa_error_if.py
index c3a9068..eb67ea8 100644
--- a/verif/tosa_error_if.py
+++ b/verif/tosa_error_if.py
@@ -30,6 +30,7 @@ class ErrorIf(object):
BatchMismatch = "BatchMismatch"
ChannelMismatch = "ChannelMismatch"
RankMismatch = "RankMismatch"
+ DimensionMismatch = "DimensionMismatch"
InputZeroPointNotZero = "InputZeroPointNotZero"
WeightZeroPointNotZero = "WeightZeroPointNotZero"
OutputZeroPointNotZero = "OutputZeroPointNotZero"
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index 80ccff3..db44328 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -270,21 +270,26 @@ class TosaTensorGen:
shape_list = []
# Choose one of the inputs to broadcast
- bcast_idx = testGen.randInt(0, pl + const)
+ # Note: Simplifies OutputShaper code if we don't change first shape for errors
+ bcast_idx = testGen.randInt(0 if error_name == None else 1, pl + const)
for i in range(pl + const):
shape_bcast = shape.copy()
- if error_name == ErrorIf.RankMismatch:
- bcast_idx = -1 # Turn off broadcast because we are not testing it
- if rank == 1 and i != 1:
- shape_bcast = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
- elif i != 1:
- shape_bcast = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))
-
# If the chosen input, pick a random index to broadcast
if i == bcast_idx:
fuzz_idx = testGen.randInt(0, rank)
- shape_bcast[fuzz_idx] = 1
+ if error_name == ErrorIf.DimensionMismatch:
+ shape_bcast[fuzz_idx] += 1
+ elif error_name == ErrorIf.RankMismatch:
+ # Add one rank to the shape (or more for rank of 1)
+ extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1
+ shape_bcast = np.concatenate((shape_bcast, testGen.makeShape(extra_ranks)))
+ if rank != 1:
+ # Either keep the extra rank, or remove it
+ new_len = testGen.rng.choice([-2, len(shape_bcast)])
+ shape_bcast = shape_bcast[:new_len]
+ else:
+ shape_bcast[fuzz_idx] = 1
shape_list.append(shape_bcast)
@@ -2001,8 +2006,14 @@ class TosaErrorValidator:
if check:
input1_shape = kwargs['input1'].shape
input2_shape = kwargs['input2'].shape
+ # In case of SELECT op
+ input3_shape = kwargs['input3'].shape if 'input3' in kwargs else input2_shape
output_shape = kwargs['result_tensor'].shape
- if (len(input1_shape) != len(output_shape)) or (len(input2_shape) != len(output_shape)):
+ if (
+ (len(input1_shape) != len(output_shape)) or
+ (len(input2_shape) != len(output_shape)) or
+ (len(input3_shape) != len(output_shape))
+ ):
error_result = True
info_dict = {
@@ -2014,6 +2025,35 @@ class TosaErrorValidator:
return info_dict
@staticmethod
+ def evDimensionMismatch(check=False, **kwargs):
+ error_name = ErrorIf.DimensionMismatch
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Input Dimensions do not match output"
+
+ if check:
+ input1_shape = kwargs['input1'].shape
+ input2_shape = kwargs['input2'].shape
+ # In case of SELECT op
+ input3_shape = kwargs['input3'].shape if 'input3' in kwargs else input2_shape
+ output_shape = kwargs['result_tensor'].shape
+ for i in range(min(len(input1_shape), len(input2_shape), len(input3_shape))):
+ if (
+ (input1_shape[i] != 1 and input1_shape[i] != output_shape[i]) or
+ (input2_shape[i] != 1 and input2_shape[i] != output_shape[i]) or
+ (input3_shape[i] != 1 and input3_shape[i] != output_shape[i])
+ ):
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs
+ }
+ return info_dict
+
+ @staticmethod
def evInputZeroPointNotZero(check=False, **kwargs):
op = kwargs['op']
inputDtypes = op['types'].copy()
@@ -3492,6 +3532,9 @@ class TosaTestGen:
validator_fcns,
error_name,
op=op,
+ input1 = cond,
+ input2 = a,
+ input3 = b,
input_shape = a.shape,
input_dtype = a.dtype,
output_dtype = result_tens.dtype,
@@ -3519,6 +3562,8 @@ class TosaTestGen:
validator_fcns,
error_name,
op=op,
+ input1 = a,
+ input2 = b,
input_shape = a.shape,
input_dtype = a.dtype,
output_shape = result_tens.shape,
@@ -5019,7 +5064,7 @@ class TosaTestGen:
)
tens.extend(placeholders)
- elif op["op"] == Op.MUL:
+ elif op["op"] == Op.MUL and error_name == None:
assert (
pCount == 2 and cCount == 0
), "Op.MUL must have 2 placeholders, 0 consts"
@@ -5363,7 +5408,7 @@ class TosaTestGen:
"build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_FI32,
"error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
},
"arithmetic_right_shift": {
"op": Op.ARITHMETIC_RIGHT_SHIFT,
@@ -5374,8 +5419,8 @@ class TosaTestGen:
TosaArgGen.agArithmeticRightShift,
),
"types": TYPE_INT,
- "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
- TosaErrorValidator.evWrongOutputList)
+ "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
},
"bitwise_and": {
"op": Op.BITWISE_AND,
@@ -5383,7 +5428,7 @@ class TosaTestGen:
"build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_INT,
"error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
},
"bitwise_or": {
"op": Op.BITWISE_OR,
@@ -5391,7 +5436,7 @@ class TosaTestGen:
"build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_INT,
"error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
},
"bitwise_xor": {
"op": Op.BITWISE_XOR,
@@ -5399,7 +5444,7 @@ class TosaTestGen:
"build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_INT,
"error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
},
"intdiv": {
"op": Op.INTDIV,
@@ -5407,7 +5452,7 @@ class TosaTestGen:
"build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
"types": [DType.INT32],
"error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
},
"logical_and": {
"op": Op.LOGICAL_AND,
@@ -5415,7 +5460,7 @@ class TosaTestGen:
"build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_BOOL,
"error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
},
"logical_left_shift": {
"op": Op.LOGICAL_LEFT_SHIFT,
@@ -5423,7 +5468,7 @@ class TosaTestGen:
"build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_INT,
"error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
},
"logical_right_shift": {
"op": Op.LOGICAL_RIGHT_SHIFT,
@@ -5431,7 +5476,7 @@ class TosaTestGen:
"build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_INT,
"error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
},
"logical_or": {
"op": Op.LOGICAL_OR,
@@ -5439,7 +5484,7 @@ class TosaTestGen:
"build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_BOOL,
"error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
},
"logical_xor": {
"op": Op.LOGICAL_XOR,
@@ -5447,7 +5492,7 @@ class TosaTestGen:
"build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_BOOL,
"error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
},
"maximum": {
"op": Op.MAXIMUM,
@@ -5455,7 +5500,7 @@ class TosaTestGen:
"build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_FI32,
"error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
},
"minimum": {
"op": Op.MINIMUM,
@@ -5463,7 +5508,7 @@ class TosaTestGen:
"build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_FI32,
"error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
},
"mul": {
"op": Op.MUL,
@@ -5471,15 +5516,15 @@ class TosaTestGen:
"build_fcn": (build_mul, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agMul),
"types": TYPE_INT_FP,
"error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
- TosaErrorValidator.evWrongOutputList)
+ TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evRankMismatch, TosaErrorValidator.evDimensionMismatch)
},
"pow": {
"op": Op.POW,
"operands": (2, 0),
- "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBasic, None),
+ "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_FP,
"error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
},
"sub": {
"op": Op.SUB,
@@ -5487,7 +5532,7 @@ class TosaTestGen:
"build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_FI32,
"error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
},
"table": {
"op": Op.TABLE,
@@ -5597,8 +5642,8 @@ class TosaTestGen:
"operands": (3, 0),
"build_fcn": (build_select, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_FIB,
- "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
},
# Comparison operators
"equal": {
@@ -5606,24 +5651,24 @@ class TosaTestGen:
"operands": (2, 0),
"build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_FI32,
- "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
},
"greater_equal": {
"op": Op.GREATER_EQUAL,
"operands": (2, 0),
"build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_FI32,
- "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
},
"greater": {
"op": Op.GREATER,
"operands": (2, 0),
"build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_FI32,
- "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
},
# Reduction operators
"reduce_all": {
@@ -5916,12 +5961,16 @@ class OutputShaper:
@staticmethod
def selectOp(ser, rng, cond, a, b, error_name=None):
- assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
+ if error_name != ErrorIf.RankMismatch:
+ assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
assert a.dtype == b.dtype
shape = []
- for i in range(len(a.shape)):
- shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
+ for i in range(len(cond.shape)):
+ if cond.shape[i] == 1 and error_name == None:
+ shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
+ else:
+ shape.append(cond.shape[i])
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
@@ -5934,7 +5983,8 @@ class OutputShaper:
@staticmethod
def binaryComparisonOp(ser, rng, a, b , error_name=None):
- assert len(a.shape) == len(b.shape)
+ if error_name != ErrorIf.RankMismatch:
+ assert len(a.shape) == len(b.shape)
assert a.dtype == b.dtype
# Do broadcast