aboutsummaryrefslogtreecommitdiff
path: root/verif/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/tosa_test_gen.py')
-rw-r--r--verif/tosa_test_gen.py48
1 files changed, 45 insertions, 3 deletions
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index 4e944ea..80ccff3 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -1629,6 +1629,8 @@ class TosaErrorValidator:
# Set minimum incorrect rank to 3 to avoid index error
if op['op'] in [Op.RESIZE]:
incorrect_ranks = [3, 5]
+ if op['op'] in [Op.TRANSPOSE]:
+ incorrect_ranks = [7, 8]
error_name = ErrorIf.WrongRank
param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None}
@@ -2714,6 +2716,44 @@ class TosaErrorValidator:
return info_dict
@staticmethod
+ def evConcatShapeSumMismatch(check=False, **kwargs):
+ error_name = ErrorIf.ConcatShapeSumMismatch
+ param_reqs = {"rank": [2,4], "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Sum of dimensions on axis not equal to output dimension"
+
+ if check:
+ inputs = kwargs['inputs']
+ input_shape = kwargs['input_shape']
+ output_shape = kwargs['output_shape']
+ axis = kwargs['axis']
+
+ # Ensure rank is valid before checking dims.
+ valid_params = True
+ for input in inputs:
+ if len(input.shape) != len(input_shape):
+ valid_params = False
+ if axis < 0 or axis > len(input_shape):
+ valid_params = False
+
+ if valid_params:
+ axis_dim_sum = 0
+ for input in inputs:
+ axis_dim_sum += input.shape[axis]
+
+ if axis_dim_sum != output_shape[axis]:
+ error_result = True
+
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs
+ }
+ return info_dict
+
+ @staticmethod
def evInputListThenGraphMismatch(check=False, **kwargs):
error_name = ErrorIf.CondIfInputListThenGraphMismatch
param_reqs = {"rank": None, "dtype": None, "shape": None}
@@ -5647,7 +5687,8 @@ class TosaTestGen:
"build_fcn": (build_concat, TosaTensorGen.tgConcat, TosaArgGen.agAxis),
"types": TYPE_FIB,
"error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evConcatInputRankMismatch,
- TosaErrorValidator.evConcatInputDimMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongOutputList)
+ TosaErrorValidator.evConcatShapeSumMismatch, TosaErrorValidator.evConcatInputDimMismatch, TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongOutputList)
},
"pad": {
"op": Op.PAD,
@@ -6173,12 +6214,13 @@ class OutputShaper:
error_name == ErrorIf.ConcatInputRankMismatch
# unable to concat tensors along an invalid axis
or error_name in [ErrorIf.AxisLargerRank, ErrorIf.AxisSmallerZero]
- # unable to concat tensors of different dimensions
- or error_name == ErrorIf.ConcatInputDimMismatch
):
for tensor in remaining_inputs:
output_shape[axis] += tensor.shape[axis]
+ if error_name == ErrorIf.ConcatShapeSumMismatch:
+ output_shape[axis] += rng.integers(5, 10)
+
if error_name == ErrorIf.WrongOutputType:
all_dtypes = {DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT}
wrong_dtypes = list(all_dtypes - set([input1.dtype]))