aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthew Haddon <matthew.haddon@arm.com>2021-10-15 16:30:48 +0100
committerEric Kunze <eric.kunze@arm.com>2021-11-09 15:11:29 +0000
commit01c359de3ad4fc617ec7726fd7103749f6d3933f (patch)
treeff79f88dd3040238f58f500f23182da0b8f41a5f
parent630c17c5b46aed13edebc60321fcee5659c688bb (diff)
downloadreference_model-01c359de3ad4fc617ec7726fd7103749f6d3933f.tar.gz
Fix Transpose WrongRank test and add new test for Concat
* Transpose WrongRank tests now use ranks 7, 8 * Concat ERROR_IF checks now test for inaccurate summation of output shape tensor dimension Change-Id: If32f43a4dbd872d0ef7625fa3d4969c863a11b8c Signed-off-by: Matthew Haddon <matthew.haddon@arm.com> Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Signed-off-by: Les Bell <les.bell@arm.com>
-rw-r--r--verif/tosa_error_if.py1
-rw-r--r--verif/tosa_test_gen.py48
2 files changed, 46 insertions, 3 deletions
diff --git a/verif/tosa_error_if.py b/verif/tosa_error_if.py
index f0e752f..c3a9068 100644
--- a/verif/tosa_error_if.py
+++ b/verif/tosa_error_if.py
@@ -56,6 +56,7 @@ class ErrorIf(object):
MaxSmallerMin = "MaxSmallerMin"
ConcatInputRankMismatch = "ConcatInputRankMismatch"
ConcatInputDimMismatch = "ConcatInputDimMismatch"
+ ConcatShapeSumMismatch = "ConcatShapeSumMismatch"
CondIfInputListThenGraphMismatch = "CondIfInputListThenGraphMismatch"
CondIfInputListElseGraphMismatch = "CondIfInputListElseGraphMismatch"
CondIfOutputListThenGraphMismatch = "CondIfOutputListThenGraphMismatch"
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index 4e944ea..80ccff3 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -1629,6 +1629,8 @@ class TosaErrorValidator:
# Set minimum incorrect rank to 3 to avoid index error
if op['op'] in [Op.RESIZE]:
incorrect_ranks = [3, 5]
+ if op['op'] in [Op.TRANSPOSE]:
+ incorrect_ranks = [7, 8]
error_name = ErrorIf.WrongRank
param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None}
@@ -2714,6 +2716,44 @@ class TosaErrorValidator:
return info_dict
@staticmethod
+ def evConcatShapeSumMismatch(check=False, **kwargs):
+ error_name = ErrorIf.ConcatShapeSumMismatch
+ param_reqs = {"rank": [2,4], "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Sum of dimensions on axis not equal to output dimension"
+
+ if check:
+ inputs = kwargs['inputs']
+ input_shape = kwargs['input_shape']
+ output_shape = kwargs['output_shape']
+ axis = kwargs['axis']
+
+ # Ensure rank is valid before checking dims.
+ valid_params = True
+ for input in inputs:
+ if len(input.shape) != len(input_shape):
+ valid_params = False
+ if axis < 0 or axis > len(input_shape):
+ valid_params = False
+
+ if valid_params:
+ axis_dim_sum = 0
+ for input in inputs:
+ axis_dim_sum += input.shape[axis]
+
+ if axis_dim_sum != output_shape[axis]:
+ error_result = True
+
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs
+ }
+ return info_dict
+
+ @staticmethod
def evInputListThenGraphMismatch(check=False, **kwargs):
error_name = ErrorIf.CondIfInputListThenGraphMismatch
param_reqs = {"rank": None, "dtype": None, "shape": None}
@@ -5647,7 +5687,8 @@ class TosaTestGen:
"build_fcn": (build_concat, TosaTensorGen.tgConcat, TosaArgGen.agAxis),
"types": TYPE_FIB,
"error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evConcatInputRankMismatch,
- TosaErrorValidator.evConcatInputDimMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongOutputList)
+ TosaErrorValidator.evConcatShapeSumMismatch, TosaErrorValidator.evConcatInputDimMismatch, TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongOutputList)
},
"pad": {
"op": Op.PAD,
@@ -6173,12 +6214,13 @@ class OutputShaper:
error_name == ErrorIf.ConcatInputRankMismatch
# unable to concat tensors along an invalid axis
or error_name in [ErrorIf.AxisLargerRank, ErrorIf.AxisSmallerZero]
- # unable to concat tensors of different dimensions
- or error_name == ErrorIf.ConcatInputDimMismatch
):
for tensor in remaining_inputs:
output_shape[axis] += tensor.shape[axis]
+ if error_name == ErrorIf.ConcatShapeSumMismatch:
+ output_shape[axis] += rng.integers(5, 10)
+
if error_name == ErrorIf.WrongOutputType:
all_dtypes = {DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT}
wrong_dtypes = list(all_dtypes - set([input1.dtype]))