diff options
-rw-r--r-- | verif/tosa_error_if.py | 1 | ||||
-rw-r--r-- | verif/tosa_test_gen.py | 48 |
2 files changed, 46 insertions, 3 deletions
diff --git a/verif/tosa_error_if.py b/verif/tosa_error_if.py index f0e752f..c3a9068 100644 --- a/verif/tosa_error_if.py +++ b/verif/tosa_error_if.py @@ -56,6 +56,7 @@ class ErrorIf(object): MaxSmallerMin = "MaxSmallerMin" ConcatInputRankMismatch = "ConcatInputRankMismatch" ConcatInputDimMismatch = "ConcatInputDimMismatch" + ConcatShapeSumMismatch = "ConcatShapeSumMismatch" CondIfInputListThenGraphMismatch = "CondIfInputListThenGraphMismatch" CondIfInputListElseGraphMismatch = "CondIfInputListElseGraphMismatch" CondIfOutputListThenGraphMismatch = "CondIfOutputListThenGraphMismatch" diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py index 4e944ea..80ccff3 100644 --- a/verif/tosa_test_gen.py +++ b/verif/tosa_test_gen.py @@ -1629,6 +1629,8 @@ class TosaErrorValidator: # Set minimum incorrect rank to 3 to avoid index error if op['op'] in [Op.RESIZE]: incorrect_ranks = [3, 5] + if op['op'] in [Op.TRANSPOSE]: + incorrect_ranks = [7, 8] error_name = ErrorIf.WrongRank param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None} @@ -2714,6 +2716,44 @@ class TosaErrorValidator: return info_dict @staticmethod + def evConcatShapeSumMismatch(check=False, **kwargs): + error_name = ErrorIf.ConcatShapeSumMismatch + param_reqs = {"rank": [2,4], "dtype": None, "shape": None} + error_result = False + error_reason = "Sum of dimensions on axis not equal to output dimension" + + if check: + inputs = kwargs['inputs'] + input_shape = kwargs['input_shape'] + output_shape = kwargs['output_shape'] + axis = kwargs['axis'] + + # Ensure rank is valid before checking dims. + valid_params = True + for input in inputs: + if len(input.shape) != len(input_shape): + valid_params = False + if axis < 0 or axis > len(input_shape): + valid_params = False + + if valid_params: + axis_dim_sum = 0 + for input in inputs: + axis_dim_sum += input.shape[axis] + + if axis_dim_sum != output_shape[axis]: + error_result = True + + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs + } + return info_dict + + @staticmethod def evInputListThenGraphMismatch(check=False, **kwargs): error_name = ErrorIf.CondIfInputListThenGraphMismatch param_reqs = {"rank": None, "dtype": None, "shape": None} @@ -5647,7 +5687,8 @@ class TosaTestGen: "build_fcn": (build_concat, TosaTensorGen.tgConcat, TosaArgGen.agAxis), "types": TYPE_FIB, "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evConcatInputRankMismatch, - TosaErrorValidator.evConcatInputDimMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongOutputList) + TosaErrorValidator.evConcatShapeSumMismatch, TosaErrorValidator.evConcatInputDimMismatch, TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongOutputList) }, "pad": { "op": Op.PAD, @@ -6173,12 +6214,13 @@ class OutputShaper: error_name == ErrorIf.ConcatInputRankMismatch # unable to concat tensors along an invalid axis or error_name in [ErrorIf.AxisLargerRank, ErrorIf.AxisSmallerZero] - # unable to concat tensors of different dimensions - or error_name == ErrorIf.ConcatInputDimMismatch ): for tensor in remaining_inputs: output_shape[axis] += tensor.shape[axis] + if error_name == ErrorIf.ConcatShapeSumMismatch: + output_shape[axis] += rng.integers(5, 10) + if error_name == ErrorIf.WrongOutputType: all_dtypes = {DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT} wrong_dtypes = list(all_dtypes - set([input1.dtype])) |