diff options
author | Matthew Haddon <matthew.haddon@arm.com> | 2021-09-22 11:24:37 +0100 |
---|---|---|
committer | Matthew Haddon <matthew.haddon@arm.com> | 2021-10-07 17:24:19 +0100 |
commit | 693ba9ed076e3b9e95e484a27b087352d2bac157 (patch) | |
tree | a32cd47cc4facc00bf821e6618e5cbbd5f51a833 /verif | |
parent | 1c00b71c44c80b2433e1837af204317636a69f95 (diff) | |
download | reference_model-693ba9ed076e3b9e95e484a27b087352d2bac157.tar.gz |
Add ERROR_IF checks for mismatched batch/channel
Change-Id: I7c670c5f9b97a18a6f586b16f31bc9fc301f6bc3
Signed-off-by: Matthew Haddon <matthew.haddon@arm.com>
Diffstat (limited to 'verif')
-rw-r--r-- | verif/tosa_error_if.py | 2 | ||||
-rw-r--r-- | verif/tosa_test_gen.py | 70 |
2 files changed, 68 insertions, 4 deletions
diff --git a/verif/tosa_error_if.py b/verif/tosa_error_if.py index 58595d3..94648d3 100644 --- a/verif/tosa_error_if.py +++ b/verif/tosa_error_if.py @@ -27,4 +27,6 @@ class ErrorIf(object): WrongInputList = "WrongInputList" WrongOutputList = "WrongOutputList" WrongRank = "WrongRank" + BatchMismatch = "BatchMismatch" + ChannelMismatch = "ChannelMismatch" diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py index 2c13172..3cd1d69 100644 --- a/verif/tosa_test_gen.py +++ b/verif/tosa_test_gen.py @@ -1347,6 +1347,59 @@ class TosaErrorValidator: return info_dict @staticmethod + def evBatchMismatch(check=False, **kwargs): + error_name = ErrorIf.BatchMismatch + param_reqs = {"rank": [4,4], "dtype": None, "shape": None} + error_result = False + error_reason = "Input batch size not equal to output batch size" + + assert 'op' in kwargs + op = kwargs['op'] + rmin, rmax = op['rank'] + rank_range = range(rmin, rmax + 1) + + if check: + input_shape = kwargs['input_shape'].shape + output_shape = kwargs['result_tensor'].shape # Note this is just (N, OH, OW, C) + + if (len(input_shape) in rank_range) and (input_shape[0] != output_shape[0]): + error_result = True + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs + } + return info_dict + + @staticmethod + def evChannelMismatch(check=False, **kwargs): + error_name = ErrorIf.ChannelMismatch + param_reqs = {"rank": [4,4], "dtype": None, "shape": None} + error_result = False + error_reason = "Input channel size not equal to output channel size" + + assert 'op' in kwargs + op = kwargs['op'] + rmin, rmax = op['rank'] + rank_range = range(rmin, rmax + 1) + + if check: + input_shape = kwargs['input_shape'].shape + output_shape = kwargs['result_tensor'].shape # Note this is just (N, OH, OW, C) + if (len(input_shape) in rank_range) and (input_shape[3] != output_shape[3]): + error_result = True + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs + } + return info_dict + + @staticmethod def evStrideSmallerEqualZero(check=False, **kwargs): error_name = ErrorIf.StrideSmallerEqualZero param_reqs = {"rank": None, "dtype": None, "shape": None} @@ -2195,6 +2248,7 @@ class TosaTestGen: ): result_tens = OutputShaper.resizeOp( self.ser, + self.rng, input, mode, stride, @@ -2232,6 +2286,7 @@ class TosaTestGen: stride_fp=stride_fp, input_list=input_list, output_list=output_list, + result_tensor=result_tens, num_operands=num_operands, ) @@ -3457,7 +3512,8 @@ class TosaTestGen: "error_if_validators": (TosaErrorValidator.evMaxDimExceeded, TosaErrorValidator.evStrideSmallerEqualZero, TosaErrorValidator.evStrideLargerDimension, TosaErrorValidator.evStrideLargerEqualMax, TosaErrorValidator.evOffsetSmallerEqualMin, TosaErrorValidator.evOffsetLargerEqualMax, TosaErrorValidator.evShiftNotZero, TosaErrorValidator.evShiftSmallerOne, TosaErrorValidator.evShiftLargerEleven, TosaErrorValidator.evWrongInputType, - TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evBatchMismatch, TosaErrorValidator.evChannelMismatch) }, # Type conversion "cast": { @@ -3862,7 +3918,8 @@ class OutputShaper: @staticmethod def resizeOp( - ser, + serializer, + rng, input, mode, stride, @@ -3878,9 +3935,14 @@ class OutputShaper: if error_name == ErrorIf.WrongRank: output_dims = [input.shape[0], output_dims[0], output_dims[0], input.shape[0]] else: - output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]] + if error_name == ErrorIf.BatchMismatch: + output_dims = [input.shape[0] + rng.integers(1, 10), output_dims[0], output_dims[1], input.shape[3]] + elif error_name == ErrorIf.ChannelMismatch: + output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3] + rng.integers(1, 10)] + else: + output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]] - return ser.addOutput(output_dims, output_dtype) + return serializer.addOutput(output_dims, output_dtype) @staticmethod def typeConversionOp(ser, val, out_dtype): |