From 693ba9ed076e3b9e95e484a27b087352d2bac157 Mon Sep 17 00:00:00 2001 From: Matthew Haddon Date: Wed, 22 Sep 2021 11:24:37 +0100 Subject: Add ERROR_IF checks for mismatched batch/channel Change-Id: I7c670c5f9b97a18a6f586b16f31bc9fc301f6bc3 Signed-off-by: Matthew Haddon --- verif/tosa_error_if.py | 2 ++ verif/tosa_test_gen.py | 70 +++++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 68 insertions(+), 4 deletions(-) diff --git a/verif/tosa_error_if.py b/verif/tosa_error_if.py index 58595d3..94648d3 100644 --- a/verif/tosa_error_if.py +++ b/verif/tosa_error_if.py @@ -27,4 +27,6 @@ class ErrorIf(object): WrongInputList = "WrongInputList" WrongOutputList = "WrongOutputList" WrongRank = "WrongRank" + BatchMismatch = "BatchMismatch" + ChannelMismatch = "ChannelMismatch" diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py index 2c13172..3cd1d69 100644 --- a/verif/tosa_test_gen.py +++ b/verif/tosa_test_gen.py @@ -1346,6 +1346,59 @@ class TosaErrorValidator: } return info_dict + @staticmethod + def evBatchMismatch(check=False, **kwargs): + error_name = ErrorIf.BatchMismatch + param_reqs = {"rank": [4,4], "dtype": None, "shape": None} + error_result = False + error_reason = "Input batch size not equal to output batch size" + + assert 'op' in kwargs + op = kwargs['op'] + rmin, rmax = op['rank'] + rank_range = range(rmin, rmax + 1) + + if check: + input_shape = kwargs['input_shape'].shape + output_shape = kwargs['result_tensor'].shape # Note this is just (N, OH, OW, C) + + if (len(input_shape) in rank_range) and (input_shape[0] != output_shape[0]): + error_result = True + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs + } + return info_dict + + @staticmethod + def evChannelMismatch(check=False, **kwargs): + error_name = ErrorIf.ChannelMismatch + param_reqs = {"rank": [4,4], "dtype": None, "shape": None} + error_result = False + error_reason = "Input channel size not equal to output channel size" + + assert 'op' in kwargs + op = kwargs['op'] + rmin, rmax = op['rank'] + rank_range = range(rmin, rmax + 1) + + if check: + input_shape = kwargs['input_shape'].shape + output_shape = kwargs['result_tensor'].shape # Note this is just (N, OH, OW, C) + if (len(input_shape) in rank_range) and (input_shape[3] != output_shape[3]): + error_result = True + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs + } + return info_dict + @staticmethod def evStrideSmallerEqualZero(check=False, **kwargs): error_name = ErrorIf.StrideSmallerEqualZero @@ -2195,6 +2248,7 @@ class TosaTestGen: ): result_tens = OutputShaper.resizeOp( self.ser, + self.rng, input, mode, stride, @@ -2232,6 +2286,7 @@ class TosaTestGen: stride_fp=stride_fp, input_list=input_list, output_list=output_list, + result_tensor=result_tens, num_operands=num_operands, ) @@ -3457,7 +3512,8 @@ class TosaTestGen: "error_if_validators": (TosaErrorValidator.evMaxDimExceeded, TosaErrorValidator.evStrideSmallerEqualZero, TosaErrorValidator.evStrideLargerDimension, TosaErrorValidator.evStrideLargerEqualMax, TosaErrorValidator.evOffsetSmallerEqualMin, TosaErrorValidator.evOffsetLargerEqualMax, TosaErrorValidator.evShiftNotZero, TosaErrorValidator.evShiftSmallerOne, TosaErrorValidator.evShiftLargerEleven, TosaErrorValidator.evWrongInputType, - TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evBatchMismatch, TosaErrorValidator.evChannelMismatch) }, # Type conversion "cast": { @@ -3862,7 +3918,8 @@ class OutputShaper: @staticmethod def resizeOp( - ser, + serializer, + rng, input, mode, stride, @@ -3878,9 +3935,14 @@ class OutputShaper: if error_name == ErrorIf.WrongRank: output_dims = [input.shape[0], output_dims[0], output_dims[0], input.shape[0]] else: - output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]] + if error_name == ErrorIf.BatchMismatch: + output_dims = [input.shape[0] + rng.integers(1, 10), output_dims[0], output_dims[1], input.shape[3]] + elif error_name == ErrorIf.ChannelMismatch: + output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3] + rng.integers(1, 10)] + else: + output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]] - return ser.addOutput(output_dims, output_dtype) + return serializer.addOutput(output_dims, output_dtype) @staticmethod def typeConversionOp(ser, val, out_dtype): -- cgit v1.2.1