From d32c6da91647dd09d3f22483ded8157941a5ade9 Mon Sep 17 00:00:00 2001 From: Jeremy Johnson Date: Wed, 24 Aug 2022 17:09:09 +0100 Subject: Add PAD ERROR_IF test for output shape Signed-off-by: Jeremy Johnson Change-Id: I25a13540734fa30c0c21b46708dfabbec8c4d1e5 --- verif/generator/tosa_error_if.py | 25 +++++++++++++++++++++++++ verif/generator/tosa_test_gen.py | 5 +++++ 2 files changed, 30 insertions(+) diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py index 1651d95..e4e60b7 100644 --- a/verif/generator/tosa_error_if.py +++ b/verif/generator/tosa_error_if.py @@ -43,6 +43,7 @@ class ErrorIf(object): DilationSmallerOne = "DilationSmallerOne" PadSmallerZero = "PadSmallerZero" PadLargerEqualKernel = "PadLargerEqualKernel" + PadOutputShapeMismatch = "PadOutputShapeMismatch" PoolingOutputShapeMismatch = "PoolingOutputShapeMismatch" PoolingOutputShapeNonInteger = "PoolingOutputShapeNonInteger" ConvOutputShapeMismatch = "ConvOutputShapeMismatch" @@ -1278,6 +1279,30 @@ class TosaErrorValidator: } return info_dict + @staticmethod + def evPadOutputShapeMismatch(check=False, **kwargs): + error_name = ErrorIf.PadOutputShapeMismatch + param_reqs = {"rank": None, "dtype": None, "shape": None} + error_result = False + error_reason = "Pad output shape mismatch for requested padding" + + if check: + pad = kwargs["pad"] + input_shape = kwargs["input_shape"] + output_shape = kwargs["output_shape"] + for dim, padding in enumerate(pad): + expected_size = input_shape[dim] + padding[0] + padding[1] + if expected_size != output_shape[dim]: + error_result = True + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs, + } + return info_dict + @staticmethod def checkPoolingParams(kernel, stride, pad): return ( diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index eeb0ac7..53d38dd 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -3428,6 +3428,7 @@ class TosaTestGen: "error_if_validators": ( TosaErrorValidator.evWrongInputType, TosaErrorValidator.evPadSmallerZero, + TosaErrorValidator.evPadOutputShapeMismatch, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, @@ -4262,6 +4263,10 @@ class OutputShaper: for i in range(len(output_shape)): output_shape[i] = padding[i][0] + padding[i][1] + output_shape[i] + if error_name == ErrorIf.PadOutputShapeMismatch: + bad_dim = rng.choice(range(len(output_shape))) + output_shape[bad_dim] -= rng.choice([1, 2]) + # Fix negative output shape if error_if test causes it if error_name == ErrorIf.PadSmallerZero and min(output_shape) < 1: output_shape = [i if i >= 1 else 1 for i in output_shape] -- cgit v1.2.1