From ebf4af6a45c60d3f75ccd6019612a7f8b6552d72 Mon Sep 17 00:00:00 2001 From: Louis Verhaard Date: Wed, 27 Jan 2021 15:57:57 +0100 Subject: MLBEDSW-3903: Bug fix PAD operator - Added checks for unsupported pad sizes in PAD operator - Bug fix right pad/bottom pad calculation when replacing PAD operator by hardware padding Change-Id: Ib84be711277d987052f14352ab386e0e0b774987 Signed-off-by: Louis Verhaard --- ethosu/vela/graph_optimiser.py | 41 ++++++++++++++++++---------- ethosu/vela/operation.py | 7 ++++- ethosu/vela/supported_operators.py | 36 +++++++++++++++++++++++- ethosu/vela/test/test_graph_optimiser.py | 34 +++++++++++++++++++++++ ethosu/vela/test/test_supported_operators.py | 40 +++++++++++++++++++++++++-- 5 files changed, 139 insertions(+), 19 deletions(-) diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py index 7755cc3b..2d47c262 100644 --- a/ethosu/vela/graph_optimiser.py +++ b/ethosu/vela/graph_optimiser.py @@ -1,4 +1,4 @@ -# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved. +# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved. # # SPDX-License-Identifier: Apache-2.0 # @@ -18,6 +18,7 @@ # split into two parts optimise_graph_a and optimise_graph_b. import math import uuid +from typing import Tuple import numpy as np @@ -183,9 +184,26 @@ def needed_total_padding(input_size, stride, filter_size): return total_padding -def calc_padding_and_skirt(padding_type, kernel_size, stride, input_shape, explicit_padding): - ypad = needed_total_padding(int(input_shape.height), int(stride[1]), int(kernel_size[0])) - xpad = needed_total_padding(int(input_shape.width), int(stride[2]), int(kernel_size[1])) +def calc_explicit_padding(input_size, stride, filter_size, pad_before, pad_after) -> Tuple[int, int]: + """ + Based on explicit padding provided in a PAD operation, returns the corresponding hardware padding + that provides equivalent results. + """ + total_padding = needed_total_padding(input_size, stride, filter_size) + # The top/left padding can be taken as is from the PAD + output_pad_before = pad_before + # The bottom/right padding might need downward adjustment depending on stride/input size + output_pad_after = pad_after + while output_pad_after > 0 and output_pad_after % stride != (total_padding - pad_before) % stride: + output_pad_after -= 1 + return output_pad_before, output_pad_after + + +def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding): + k_w, k_h = kernel.dilated_wh() + s_x, s_y = kernel.stride + ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h)) + xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w)) if padding_type == Padding.SAME: left_pad = (xpad + 0) // 2 right_pad = (xpad + 1) // 2 @@ -198,10 +216,9 @@ def calc_padding_and_skirt(padding_type, kernel_size, stride, input_shape, expli bottom_pad = 0 elif padding_type == Padding.EXPLICIT: # Padding is specified in a PAD operator which has been bypassed. - # The top and left padding are taken from the PAD; bottom and right are calculated. - top_pad, left_pad, _, _ = explicit_padding - bottom_pad = ypad - top_pad - right_pad = xpad - left_pad + top, left, bottom, right = explicit_padding + top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom)) + left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right)) else: raise UnsupportedFeatureError(f"Unknown padding") padding = (top_pad, left_pad, bottom_pad, right_pad) @@ -495,14 +512,8 @@ def add_padding_fields(op, arch, nng): op.attrs["padding"], kernel_size, op.attrs["strides"], input_shape, upscaling_factor ) else: - dilation_h, dilation_w = op.get_dilation_h_w() - dilated_kernel_size = [dilation_h * (kernel_size[0] - 1) + 1, dilation_w * (kernel_size[1] - 1) + 1] padding, skirt = calc_padding_and_skirt( - op.attrs["padding"], - dilated_kernel_size, - op.attrs["strides"], - input_shape, - op.attrs.get("explicit_padding"), + op.attrs["padding"], op.kernel, input_shape, op.attrs.get("explicit_padding"), ) op.attrs["explicit_padding"] = padding diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py index 73953cec..963d9e69 100644 --- a/ethosu/vela/operation.py +++ b/ethosu/vela/operation.py @@ -1,4 +1,4 @@ -# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved. +# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved. # # SPDX-License-Identifier: Apache-2.0 # @@ -22,6 +22,7 @@ from typing import Any from typing import Dict from typing import List from typing import Optional +from typing import Tuple from typing import TYPE_CHECKING from .errors import VelaError @@ -68,6 +69,10 @@ class Kernel: def area_height(self) -> int: return (self.height - 1) * self.dilation.y + 1 + def dilated_wh(self) -> Tuple[int, int]: + """Returns the dilated kernel width/height""" + return self.dilation.x * (self.width - 1) + 1, self.dilation.y * (self.height - 1) + 1 + def __str__(self): return f"w={self.width}, h={self.height}, stride={tuple(self.stride)}, dilation={tuple(self.dilation)}" diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py index 99a4ba10..505d4d16 100644 --- a/ethosu/vela/supported_operators.py +++ b/ethosu/vela/supported_operators.py @@ -1,4 +1,4 @@ -# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved. +# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved. # # SPDX-License-Identifier: Apache-2.0 # @@ -260,6 +260,7 @@ class SupportedOperators: self.specific_constraints[Op.Pad].append(SupportedOperators.constraint_pad_type) self.specific_constraints[Op.Pad].append(SupportedOperators.constraint_pad_constant) self.specific_constraints[Op.Pad].append(SupportedOperators.constraint_pad_ofm) + self.specific_constraints[Op.Pad].append(SupportedOperators.constraint_pad_size) # HardSwish specific checks: self.specific_constraints[Op.HardSwish].append(SupportedOperators.constraint_input_8bit) @@ -843,6 +844,39 @@ class SupportedOperators: valid = len(unsupported_consumers) == 0 return valid, f"PAD operator is followed by: {_optype_formatter(unsupported_consumers)+none_string}" + @staticmethod + def __leading_pad_ok(leading_pad, stride, kernel_size): + # If kernel size // 2 > stride, then (left, top) padding must be a multiple of stride, + # otherwise replacing PAD by hardware padding would iterate the wrong IFM rows/columns + max_size = kernel_size // 2 + return leading_pad == max_size or max_size <= stride or leading_pad % stride == 0 + + @staticmethod + def constraint_pad_size(op): + "Padding must be at most kernel size divided by 2" + if SupportedOperators.constraint_pad_ofm(op)[0]: + padding = op.inputs[1].values # 4x2 tensor, first dimension is N, H, W, C + top, left, bottom, right = (padding[1][0], padding[2][0], padding[1][1], padding[2][1]) + for cons in op.ofm.consumers(): + if cons is not None: + # Note: pre-order graph traversal removes inputs of operators that are in traversal, + # which makes it impossible to calculate kernel size, hence use cached _kernel for those operators + k = cons.kernel if cons.inputs else cons._kernel + k_w, k_h = k.dilated_wh() + if left > k_w // 2: + return False, f"Left padding is {left}, kernel width is {k_w}" + if right > k_w // 2: + return False, f"Right padding is {right}, kernel width is {k_w}" + if top > k_h // 2: + return False, f"Top padding is {top}, kernel height is {k_h}" + if bottom > k_h // 2: + return False, f"Bottom padding is {bottom}, kernel height is {k_h}" + if not SupportedOperators.__leading_pad_ok(top, k.stride.y, k_h): + return False, f"Top padding is {top}, must be {k_h // 2} or multiple of {k.stride.y}" + if not SupportedOperators.__leading_pad_ok(left, k.stride.x, k_w): + return False, f"Left padding is {left}, must be {k_w // 2} or multiple of {k.stride.x}" + return True, "Pad size is ok" + @staticmethod def constraint_stridedslice_inputs_const(op): "Begin, End and Stride Input tensors must be constant" diff --git a/ethosu/vela/test/test_graph_optimiser.py b/ethosu/vela/test/test_graph_optimiser.py index 55980e3d..4281d314 100644 --- a/ethosu/vela/test/test_graph_optimiser.py +++ b/ethosu/vela/test/test_graph_optimiser.py @@ -17,8 +17,10 @@ # Description: # Unit tests for graph_optimiser import numpy as np +import pytest from ethosu.vela.data_type import DataType +from ethosu.vela.graph_optimiser import calc_explicit_padding from ethosu.vela.graph_optimiser import convert_batched_fc_shape from ethosu.vela.graph_optimiser import optimise_graph_a from ethosu.vela.graph_optimiser import optimise_pad @@ -82,6 +84,38 @@ def test_convert_batched_fc(): assert conv_op.ifm.shape == conv_op.ofm.shape +explicit_padding_test_data = [ + # Kernel size 2 + [(17, 1, 2, 1, 1), (1, 1)], + [(18, 1, 2, 0, 1), (0, 1)], + [(18, 1, 2, 1, 0), (1, 0)], + # Kernel size 3 + [(18, 2, 3, 1, 1), (1, 0)], + [(25, 2, 3, 1, 1), (1, 1)], + # Kernel size 4 + [(18, 1, 4, 1, 2), (1, 2)], + [(18, 1, 4, 2, 1), (2, 1)], + [(19, 1, 4, 2, 2), (2, 2)], + # Kernel size 5 + [(19, 1, 5, 1, 2), (1, 2)], + [(19, 1, 5, 0, 2), (0, 2)], + [(19, 1, 5, 1, 0), (1, 0)], + # Kernel size 21 + [(41, 2, 21, 8, 10), (8, 10)], + [(41, 3, 21, 10, 10), (10, 9)], + [(42, 3, 21, 10, 10), (10, 8)], + [(42, 3, 21, 9, 10), (9, 9)], + [(41, 3, 21, 10, 6), (10, 6)], +] + + +@pytest.mark.parametrize("test_input, expected_result", explicit_padding_test_data) +def test_calc_explicit_padding(test_input, expected_result): + input_size, stride, filter_size, explicit_pad_before, explicit_pad_after = test_input + before, after = calc_explicit_padding(input_size, stride, filter_size, explicit_pad_before, explicit_pad_after) + assert (before, after) == expected_result + + def test_optimise_pad(): """ Tests that the PAD operator is bypassed when followed by a convolution operator, diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py index 5c01027d..5f64dd9d 100644 --- a/ethosu/vela/test/test_supported_operators.py +++ b/ethosu/vela/test/test_supported_operators.py @@ -1,4 +1,4 @@ -# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved. +# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved. # # SPDX-License-Identifier: Apache-2.0 # @@ -17,6 +17,7 @@ # Description: # Unit tests for support_operators import numpy as np +import pytest from ethosu.vela.data_type import DataType from ethosu.vela.operation import ActivationFunction @@ -525,6 +526,7 @@ def create_pad_op( out_dtype=DataType.int8, pad_dtype=DataType.int32, pad_setting=Padding.VALID, + kernel_size=3, ): qp = testutil.default_quant_params() in0 = Tensor(in_shape, in_dtype, "in") @@ -535,7 +537,7 @@ def create_pad_op( op = testutil.create_op(Op.Pad, [in0, pad_tensor], out) conv_out_tens = Tensor(in_shape, in_dtype, "output") conv_out_tens.quantization = qp.clone() - weight_tens = Tensor(in_shape, in_dtype, "weights") + weight_tens = Tensor([kernel_size, kernel_size, in_shape[-1], out_shape[-1]], in_dtype, "weights") weight_tens.values = np.zeros(weight_tens.shape) weight_tens.quant_values = np.zeros(weight_tens.shape, np.int8) weight_tens.quantization = qp.clone() @@ -609,6 +611,40 @@ def test_constraint_pad_consumer(): assert not support.is_operator_supported(op) +pad_invalid_size_test_data = [ + (2, 1, 1, 1), + (1, 2, 1, 1), + (1, 1, 2, 1), + (1, 1, 1, 2), +] + + +@pytest.mark.parametrize("top, left, bottom, right", pad_invalid_size_test_data) +def test_constraint_pad_size(top, left, bottom, right): + # Tests PAD operator with a padding that is too high to be handled by the NPU + out_shape = [1, 11 + left + right, 11 + top + bottom, 1] + padding = [[0, 0], [top, bottom], [left, right], [0, 0]] + op = create_pad_op(in_shape=[1, 11, 11, 1], out_shape=out_shape, padding=padding,) + assert not support.is_operator_supported(op) + + +leading_pad_test_data = [ + (2, 2, 11, True), + (1, 2, 11, False), + (2, 1, 11, False), + (5, 2, 11, True), +] + + +@pytest.mark.parametrize("top, left, kernel_size, expected", leading_pad_test_data) +def test_constraint_leading_pad_size(top, left, kernel_size, expected): + # Tests PAD operator with big kernel size; top and left pad must be multiple of stride + out_shape = [1, 11 + left, 11 + top, 1] + padding = [[0, 0], [top, 0], [left, 0], [0, 0]] + op = create_pad_op(in_shape=[1, 11, 11, 1], out_shape=out_shape, padding=padding, kernel_size=kernel_size) + assert support.is_operator_supported(op) == expected + + def create_strided_slice(): # Creates a valid strided slice operator with some valid inputs/outputs op = create_strided_slice_op([1, 10, 10, 10], [1, 5, 5, 10], [127, 2, 2, 0], [0, 7, -3, 0]) -- cgit v1.2.1