diff options
author | Matthew Haddon <matthew.haddon@arm.com> | 2021-10-15 16:30:48 +0100 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2021-11-09 15:11:29 +0000 |
commit | 01c359de3ad4fc617ec7726fd7103749f6d3933f (patch) | |
tree | ff79f88dd3040238f58f500f23182da0b8f41a5f /verif/tosa_test_gen.py | |
parent | 630c17c5b46aed13edebc60321fcee5659c688bb (diff) | |
download | reference_model-01c359de3ad4fc617ec7726fd7103749f6d3933f.tar.gz |
Fix Transpose WrongRank test and add new test for Concat
* Transpose WrongRank tests now use ranks 7, 8
* Concat ERROR_IF checks now test for inaccurate summation
of output shape tensor dimension
Change-Id: If32f43a4dbd872d0ef7625fa3d4969c863a11b8c
Signed-off-by: Matthew Haddon <matthew.haddon@arm.com>
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Signed-off-by: Les Bell <les.bell@arm.com>
Diffstat (limited to 'verif/tosa_test_gen.py')
-rw-r--r-- | verif/tosa_test_gen.py | 48 |
1 files changed, 45 insertions, 3 deletions
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py index 4e944ea..80ccff3 100644 --- a/verif/tosa_test_gen.py +++ b/verif/tosa_test_gen.py @@ -1629,6 +1629,8 @@ class TosaErrorValidator: # Set minimum incorrect rank to 3 to avoid index error if op['op'] in [Op.RESIZE]: incorrect_ranks = [3, 5] + if op['op'] in [Op.TRANSPOSE]: + incorrect_ranks = [7, 8] error_name = ErrorIf.WrongRank param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None} @@ -2714,6 +2716,44 @@ class TosaErrorValidator: return info_dict @staticmethod + def evConcatShapeSumMismatch(check=False, **kwargs): + error_name = ErrorIf.ConcatShapeSumMismatch + param_reqs = {"rank": [2,4], "dtype": None, "shape": None} + error_result = False + error_reason = "Sum of dimensions on axis not equal to output dimension" + + if check: + inputs = kwargs['inputs'] + input_shape = kwargs['input_shape'] + output_shape = kwargs['output_shape'] + axis = kwargs['axis'] + + # Ensure rank is valid before checking dims. + valid_params = True + for input in inputs: + if len(input.shape) != len(input_shape): + valid_params = False + if axis < 0 or axis > len(input_shape): + valid_params = False + + if valid_params: + axis_dim_sum = 0 + for input in inputs: + axis_dim_sum += input.shape[axis] + + if axis_dim_sum != output_shape[axis]: + error_result = True + + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs + } + return info_dict + + @staticmethod def evInputListThenGraphMismatch(check=False, **kwargs): error_name = ErrorIf.CondIfInputListThenGraphMismatch param_reqs = {"rank": None, "dtype": None, "shape": None} @@ -5647,7 +5687,8 @@ class TosaTestGen: "build_fcn": (build_concat, TosaTensorGen.tgConcat, TosaArgGen.agAxis), "types": TYPE_FIB, "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evConcatInputRankMismatch, - TosaErrorValidator.evConcatInputDimMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongOutputList) + TosaErrorValidator.evConcatShapeSumMismatch, TosaErrorValidator.evConcatInputDimMismatch, TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongOutputList) }, "pad": { "op": Op.PAD, @@ -6173,12 +6214,13 @@ class OutputShaper: error_name == ErrorIf.ConcatInputRankMismatch # unable to concat tensors along an invalid axis or error_name in [ErrorIf.AxisLargerRank, ErrorIf.AxisSmallerZero] - # unable to concat tensors of different dimensions - or error_name == ErrorIf.ConcatInputDimMismatch ): for tensor in remaining_inputs: output_shape[axis] += tensor.shape[axis] + if error_name == ErrorIf.ConcatShapeSumMismatch: + output_shape[axis] += rng.integers(5, 10) + if error_name == ErrorIf.WrongOutputType: all_dtypes = {DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT} wrong_dtypes = list(all_dtypes - set([input1.dtype])) |