aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2022-08-24 17:09:09 +0100
committerEric Kunze <eric.kunze@arm.com>2022-08-25 15:27:20 +0000
commitd32c6da91647dd09d3f22483ded8157941a5ade9 (patch)
tree45c944b94bc62ae9117bb32bb368d675ebaa7fe4
parentd511f9e604c3e2b915d6f6b7a4975b23ac06041d (diff)
downloadreference_model-d32c6da91647dd09d3f22483ded8157941a5ade9.tar.gz
Add PAD ERROR_IF test for output shape
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: I25a13540734fa30c0c21b46708dfabbec8c4d1e5
-rw-r--r--verif/generator/tosa_error_if.py25
-rw-r--r--verif/generator/tosa_test_gen.py5
2 files changed, 30 insertions, 0 deletions
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index 1651d95..e4e60b7 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -43,6 +43,7 @@ class ErrorIf(object):
DilationSmallerOne = "DilationSmallerOne"
PadSmallerZero = "PadSmallerZero"
PadLargerEqualKernel = "PadLargerEqualKernel"
+ PadOutputShapeMismatch = "PadOutputShapeMismatch"
PoolingOutputShapeMismatch = "PoolingOutputShapeMismatch"
PoolingOutputShapeNonInteger = "PoolingOutputShapeNonInteger"
ConvOutputShapeMismatch = "ConvOutputShapeMismatch"
@@ -1279,6 +1280,30 @@ class TosaErrorValidator:
return info_dict
@staticmethod
+ def evPadOutputShapeMismatch(check=False, **kwargs):
+ error_name = ErrorIf.PadOutputShapeMismatch
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Pad output shape mismatch for requested padding"
+
+ if check:
+ pad = kwargs["pad"]
+ input_shape = kwargs["input_shape"]
+ output_shape = kwargs["output_shape"]
+ for dim, padding in enumerate(pad):
+ expected_size = input_shape[dim] + padding[0] + padding[1]
+ if expected_size != output_shape[dim]:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
def checkPoolingParams(kernel, stride, pad):
return (
min(kernel) >= 1
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index eeb0ac7..53d38dd 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -3428,6 +3428,7 @@ class TosaTestGen:
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evPadSmallerZero,
+ TosaErrorValidator.evPadOutputShapeMismatch,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
@@ -4262,6 +4263,10 @@ class OutputShaper:
for i in range(len(output_shape)):
output_shape[i] = padding[i][0] + padding[i][1] + output_shape[i]
+ if error_name == ErrorIf.PadOutputShapeMismatch:
+ bad_dim = rng.choice(range(len(output_shape)))
+ output_shape[bad_dim] -= rng.choice([1, 2])
+
# Fix negative output shape if error_if test causes it
if error_name == ErrorIf.PadSmallerZero and min(output_shape) < 1:
output_shape = [i if i >= 1 else 1 for i in output_shape]