From 01c359de3ad4fc617ec7726fd7103749f6d3933f Mon Sep 17 00:00:00 2001 From: Matthew Haddon Date: Fri, 15 Oct 2021 16:30:48 +0100 Subject: Fix Transpose WrongRank test and add new test for Concat * Transpose WrongRank tests now use ranks 7, 8 * Concat ERROR_IF checks now test for inaccurate summation of output shape tensor dimension Change-Id: If32f43a4dbd872d0ef7625fa3d4969c863a11b8c Signed-off-by: Matthew Haddon Signed-off-by: Jeremy Johnson Signed-off-by: Les Bell --- verif/tosa_error_if.py | 1 + verif/tosa_test_gen.py | 48 +++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/verif/tosa_error_if.py b/verif/tosa_error_if.py index f0e752f..c3a9068 100644 --- a/verif/tosa_error_if.py +++ b/verif/tosa_error_if.py @@ -56,6 +56,7 @@ class ErrorIf(object): MaxSmallerMin = "MaxSmallerMin" ConcatInputRankMismatch = "ConcatInputRankMismatch" ConcatInputDimMismatch = "ConcatInputDimMismatch" + ConcatShapeSumMismatch = "ConcatShapeSumMismatch" CondIfInputListThenGraphMismatch = "CondIfInputListThenGraphMismatch" CondIfInputListElseGraphMismatch = "CondIfInputListElseGraphMismatch" CondIfOutputListThenGraphMismatch = "CondIfOutputListThenGraphMismatch" diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py index 4e944ea..80ccff3 100644 --- a/verif/tosa_test_gen.py +++ b/verif/tosa_test_gen.py @@ -1629,6 +1629,8 @@ class TosaErrorValidator: # Set minimum incorrect rank to 3 to avoid index error if op['op'] in [Op.RESIZE]: incorrect_ranks = [3, 5] + if op['op'] in [Op.TRANSPOSE]: + incorrect_ranks = [7, 8] error_name = ErrorIf.WrongRank param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None} @@ -2713,6 +2715,44 @@ class TosaErrorValidator: } return info_dict + @staticmethod + def evConcatShapeSumMismatch(check=False, **kwargs): + error_name = ErrorIf.ConcatShapeSumMismatch + param_reqs = {"rank": [2,4], "dtype": None, "shape": None} + error_result = False + error_reason = "Sum of dimensions on axis not equal to output dimension" + + if check: + inputs = kwargs['inputs'] + input_shape = kwargs['input_shape'] + output_shape = kwargs['output_shape'] + axis = kwargs['axis'] + + # Ensure rank is valid before checking dims. + valid_params = True + for input in inputs: + if len(input.shape) != len(input_shape): + valid_params = False + if axis < 0 or axis > len(input_shape): + valid_params = False + + if valid_params: + axis_dim_sum = 0 + for input in inputs: + axis_dim_sum += input.shape[axis] + + if axis_dim_sum != output_shape[axis]: + error_result = True + + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs + } + return info_dict + @staticmethod def evInputListThenGraphMismatch(check=False, **kwargs): error_name = ErrorIf.CondIfInputListThenGraphMismatch @@ -5647,7 +5687,8 @@ class TosaTestGen: "build_fcn": (build_concat, TosaTensorGen.tgConcat, TosaArgGen.agAxis), "types": TYPE_FIB, "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evConcatInputRankMismatch, - TosaErrorValidator.evConcatInputDimMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongOutputList) + TosaErrorValidator.evConcatShapeSumMismatch, TosaErrorValidator.evConcatInputDimMismatch, TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongOutputList) }, "pad": { "op": Op.PAD, @@ -6173,12 +6214,13 @@ class OutputShaper: error_name == ErrorIf.ConcatInputRankMismatch # unable to concat tensors along an invalid axis or error_name in [ErrorIf.AxisLargerRank, ErrorIf.AxisSmallerZero] - # unable to concat tensors of different dimensions - or error_name == ErrorIf.ConcatInputDimMismatch ): for tensor in remaining_inputs: output_shape[axis] += tensor.shape[axis] + if error_name == ErrorIf.ConcatShapeSumMismatch: + output_shape[axis] += rng.integers(5, 10) + if error_name == ErrorIf.WrongOutputType: all_dtypes = {DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT} wrong_dtypes = list(all_dtypes - set([input1.dtype])) -- cgit v1.2.1