aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthew Haddon <matthew.haddon@arm.com>2021-09-22 11:24:37 +0100
committerMatthew Haddon <matthew.haddon@arm.com>2021-10-07 17:24:19 +0100
commit693ba9ed076e3b9e95e484a27b087352d2bac157 (patch)
treea32cd47cc4facc00bf821e6618e5cbbd5f51a833
parent1c00b71c44c80b2433e1837af204317636a69f95 (diff)
downloadreference_model-693ba9ed076e3b9e95e484a27b087352d2bac157.tar.gz
Add ERROR_IF checks for mismatched batch/channel
Change-Id: I7c670c5f9b97a18a6f586b16f31bc9fc301f6bc3 Signed-off-by: Matthew Haddon <matthew.haddon@arm.com>
-rw-r--r--verif/tosa_error_if.py2
-rw-r--r--verif/tosa_test_gen.py70
2 files changed, 68 insertions, 4 deletions
diff --git a/verif/tosa_error_if.py b/verif/tosa_error_if.py
index 58595d3..94648d3 100644
--- a/verif/tosa_error_if.py
+++ b/verif/tosa_error_if.py
@@ -27,4 +27,6 @@ class ErrorIf(object):
WrongInputList = "WrongInputList"
WrongOutputList = "WrongOutputList"
WrongRank = "WrongRank"
+ BatchMismatch = "BatchMismatch"
+ ChannelMismatch = "ChannelMismatch"
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index 2c13172..3cd1d69 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -1347,6 +1347,59 @@ class TosaErrorValidator:
return info_dict
@staticmethod
+ def evBatchMismatch(check=False, **kwargs):
+ error_name = ErrorIf.BatchMismatch
+ param_reqs = {"rank": [4,4], "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Input batch size not equal to output batch size"
+
+ assert 'op' in kwargs
+ op = kwargs['op']
+ rmin, rmax = op['rank']
+ rank_range = range(rmin, rmax + 1)
+
+ if check:
+ input_shape = kwargs['input_shape'].shape
+ output_shape = kwargs['result_tensor'].shape # Note this is just (N, OH, OW, C)
+
+ if (len(input_shape) in rank_range) and (input_shape[0] != output_shape[0]):
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs
+ }
+ return info_dict
+
+ @staticmethod
+ def evChannelMismatch(check=False, **kwargs):
+ error_name = ErrorIf.ChannelMismatch
+ param_reqs = {"rank": [4,4], "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Input channel size not equal to output channel size"
+
+ assert 'op' in kwargs
+ op = kwargs['op']
+ rmin, rmax = op['rank']
+ rank_range = range(rmin, rmax + 1)
+
+ if check:
+ input_shape = kwargs['input_shape'].shape
+ output_shape = kwargs['result_tensor'].shape # Note this is just (N, OH, OW, C)
+ if (len(input_shape) in rank_range) and (input_shape[3] != output_shape[3]):
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs
+ }
+ return info_dict
+
+ @staticmethod
def evStrideSmallerEqualZero(check=False, **kwargs):
error_name = ErrorIf.StrideSmallerEqualZero
param_reqs = {"rank": None, "dtype": None, "shape": None}
@@ -2195,6 +2248,7 @@ class TosaTestGen:
):
result_tens = OutputShaper.resizeOp(
self.ser,
+ self.rng,
input,
mode,
stride,
@@ -2232,6 +2286,7 @@ class TosaTestGen:
stride_fp=stride_fp,
input_list=input_list,
output_list=output_list,
+ result_tensor=result_tens,
num_operands=num_operands,
)
@@ -3457,7 +3512,8 @@ class TosaTestGen:
"error_if_validators": (TosaErrorValidator.evMaxDimExceeded, TosaErrorValidator.evStrideSmallerEqualZero, TosaErrorValidator.evStrideLargerDimension,
TosaErrorValidator.evStrideLargerEqualMax, TosaErrorValidator.evOffsetSmallerEqualMin, TosaErrorValidator.evOffsetLargerEqualMax,
TosaErrorValidator.evShiftNotZero, TosaErrorValidator.evShiftSmallerOne, TosaErrorValidator.evShiftLargerEleven, TosaErrorValidator.evWrongInputType,
- TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evBatchMismatch, TosaErrorValidator.evChannelMismatch)
},
# Type conversion
"cast": {
@@ -3862,7 +3918,8 @@ class OutputShaper:
@staticmethod
def resizeOp(
- ser,
+ serializer,
+ rng,
input,
mode,
stride,
@@ -3878,9 +3935,14 @@ class OutputShaper:
if error_name == ErrorIf.WrongRank:
output_dims = [input.shape[0], output_dims[0], output_dims[0], input.shape[0]]
else:
- output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]]
+ if error_name == ErrorIf.BatchMismatch:
+ output_dims = [input.shape[0] + rng.integers(1, 10), output_dims[0], output_dims[1], input.shape[3]]
+ elif error_name == ErrorIf.ChannelMismatch:
+ output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3] + rng.integers(1, 10)]
+ else:
+ output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]]
- return ser.addOutput(output_dims, output_dtype)
+ return serializer.addOutput(output_dims, output_dtype)
@staticmethod
def typeConversionOp(ser, val, out_dtype):