author     Louis Verhaard <louis.verhaard@arm.com>   2021-02-09 16:08:26 +0100
committer  Louis Verhaard <louis.verhaard@arm.com>   2021-02-17 16:29:20 +0100
commit     1a92f78e14f31f1423824228deb0628b7a9a9071 (patch)
tree       4e0de9a2e6b5b7d9159b25cbfead4a625c134a3c
parent     8d0f4890aa0ceae92a33ebb789701ff644a6fcaa (diff)
download   ethos-u-vela-1a92f78e14f31f1423824228deb0628b7a9a9071.tar.gz
MLBEDSW-4022: support PAD followed by pool operator
PAD followed by max/average pool is run on NPU if NPU padding can be used.
Average pool is converted to depthwise.

Change-Id: Icc3652e6d9ecff5ac3dc7d92080313d90c245404
Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
-rw-r--r--  SUPPORTED_OPS.md                              |  5
-rw-r--r--  ethosu/vela/graph_optimiser.py                | 44
-rw-r--r--  ethosu/vela/high_level_command_to_npu_op.py   |  3
-rw-r--r--  ethosu/vela/operation.py                      |  4
-rw-r--r--  ethosu/vela/softmax.py                        | 12
-rw-r--r--  ethosu/vela/supported_operators.py            | 37
-rw-r--r--  ethosu/vela/test/test_graph_optimiser.py      | 49
-rw-r--r--  ethosu/vela/test/test_supported_operators.py  | 58
8 files changed, 177 insertions(+), 35 deletions(-)
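
The core of this change is the observation that an average pool is numerically a depthwise convolution whose weights are all 1/(kernel width * kernel height). Below is a minimal numpy sketch of that equivalence (not part of the patch; names are illustrative), which is what the graph optimiser change further down relies on when it rewrites PAD + AVERAGE_POOL_2D into a depthwise convolution with explicit padding:

# Minimal numpy sketch (not Vela code): average pool == depthwise conv with
# uniform weights 1/(k_h * k_w), shown here for a single-channel HxW map with
# VALID padding and stride 1.
import numpy as np

def avg_pool_valid(x, k_h, k_w):
    """Plain average pool, VALID padding, stride 1."""
    h, w = x.shape
    out = np.empty((h - k_h + 1, w - k_w + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            out[i, j] = x[i:i + k_h, j:j + k_w].mean()
    return out

def depthwise_uniform(x, k_h, k_w):
    """Depthwise conv with every weight set to 1/(k_h * k_w), VALID padding, stride 1."""
    weights = np.full((k_h, k_w), 1.0 / (k_h * k_w))
    h, w = x.shape
    out = np.empty((h - k_h + 1, w - k_w + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            out[i, j] = np.sum(x[i:i + k_h, j:j + k_w] * weights)
    return out

x = np.random.rand(8, 9)
assert np.allclose(avg_pool_valid(x, 3, 3), depthwise_uniform(x, 3, 3))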
diff --git a/SUPPORTED_OPS.md b/SUPPORTED_OPS.md
index dfa24d0..20134cc 100644
--- a/SUPPORTED_OPS.md
+++ b/SUPPORTED_OPS.md
@@ -1,7 +1,7 @@
# Supported Ops
This file was automatically generated by Vela using the `--supported-ops-report` parameter.
-Vela version: `2.0.2.dev49+gda756aa`
+Vela version: `2.0.2.dev69+g83e3bb3.d20210212`
This file complies with
[**Gitiles Markdown syntax**](https://github.com/google/gitiles/blob/master/Documentation/markdown.md)
@@ -63,6 +63,7 @@ This is a list of constraints that all NPU operators must satisfy in order to be
- Input(s), Output and Weight tensors with quantization scales must be finite
- Per-axis quantization is only supported for the following op types: CONV_2D, DEPTHWISE_CONV_2D, TRANSPOSE_CONV
- The fused activation function (if present) must be one of type: LOGISTIC, RELU, RELU6, RELU_N1_TO_1, TANH
+- If a fused activation function is present, the Output tensor must be one of type: int16, int8, uint8
- Input and Output tensors must have quantization scales that fit within float32 precision
## ABS Constraints
@@ -221,7 +222,7 @@ This is a list of constraints that the PAD operator must satisfy in order to be
- The pad tensor can only pad width and height
- Pad tensor must be of type: int32, int64
- The padding tensor must be constant
-- Must be followed by one of the following operator types: CONV_2D, DEPTHWISE_CONV_2D
+- Must be followed by one of the following operator types: AVERAGE_POOL_2D, CONV_2D, DEPTHWISE_CONV_2D, MAX_POOL_2D
- Padding must be at most kernel size divided by 2
## RESIZE_BILINEAR Constraints
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index f5006c6..e1ceb9f 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -26,6 +26,7 @@ from . import fp_math
from . import lut
from . import rewrite_graph
from . import scaling
+from .api import NpuRoundingMode
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
@@ -46,6 +47,7 @@ from .tensor import check_quantized_tens_scaling_equal
from .tensor import create_const_tensor
from .tensor import QuantizationParameters
from .tensor import Tensor
+from .tensor import TensorPurpose
from .tflite_mapping import optype_to_builtintype
passthrough_nodes = (Op.Identity,)
@@ -1174,19 +1176,55 @@ def fuse_activation_function_with_prev(op, arch, nng):
return op
-def optimise_pad(op, arch, nng):
+def optimise_pad(op: Operation, arch, nng):
"""
Converts tens1 -> PAD -> tens2 -> CONV to tens1 -> CONV
if both operations can be run on the NPU.
"""
if (
- (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op())
+ (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op() or op.type.is_pool_op())
and op.run_on_npu
and op.attrs["padding"] == Padding.VALID
):
pad_op = op.ifm.ops[0]
if pad_op.type != Op.Pad or not pad_op.run_on_npu:
return op
+ if op.type.is_avgpool_op():
+ # Average pool is converted to depthwise, because NPU average pool + same padding
+ # has a special implementation that is different from PAD followed by average pool with
+ # valid padding.
+ k_w, k_h = op.kernel.width, op.kernel.height
+ ifm = op.ifm
+ # Remember other inputs
+ other_inputs = op.inputs[1:]
+ # Create a weight tensor, all weights are set to 1/(kernel width * kernel height)
+ quantization = QuantizationParameters(0.0, 255.0)
+ quantization.scale_f32 = 1.0 / (k_w * k_h)
+ quantization.zero_point = 0
+ shape = [k_h, k_w, 1, op.ofm.shape[-1]]
+ weights = np.full(shape, 1)
+
+ weight_tens = create_const_tensor(
+ op.name + "_weights",
+ shape,
+ op.ifm.dtype,
+ weights,
+ np.uint8,
+ purpose=TensorPurpose.Weights,
+ quantization=quantization,
+ )
+ weight_tens.quant_values = weights
+ op.type = Op.DepthwiseConv2DBias
+ op.inputs = []
+ op.add_input_tensor(ifm)
+ op.add_input_tensor(weight_tens)
+ # Add bias tensor, all biases set to 0
+ op.inputs.append(None)
+ fixup_bias_tensors(op, arch, nng)
+ # Add other inputs
+ op.inputs.extend(other_inputs)
+ op.rounding_mode = NpuRoundingMode.NATURAL
+
# Bypass the PAD operator
op.set_input_tensor(pad_op.ifm, 0)
# Adjust the padding attributes of the convolution operator
@@ -1231,7 +1269,7 @@ def fixup_bias_tensors(op, arch, nng):
bias_values = [0] * nr_biases
bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], DataType.int32, bias_values)
bias_tensor.quant_values = bias_tensor.values
- op.set_input_tensor(bias_tensor, -1)
+ op.set_input_tensor(bias_tensor, op.type.info.indices.biases[0])
return op
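
For reference, the way a PAD operator's constant input maps onto the explicit padding of its consumer can be sketched as follows. This helper is hypothetical (it is not part of the diff); the (top, left, bottom, right) ordering is taken from the explicit_padding assertion in the new test further down:

# Hypothetical helper: derive explicit padding from a TFLite PAD operator's
# NHWC pad values tensor.
def explicit_padding_from_pad_values(pad_values):
    # pad_values layout:
    # [[batch_before, batch_after],
    #  [top, bottom],
    #  [left, right],
    #  [depth_before, depth_after]]
    (_, _), (top, bottom), (left, right), (_, _) = pad_values
    return (top, left, bottom, right)

# Example matching test_optimise_pad_followed_by_avg_pool below:
assert explicit_padding_from_pad_values([[0, 0], [2, 1], [1, 1], [0, 0]]) == (2, 1, 1, 1)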
diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py
index b5e7b4b..1059e6e 100644
--- a/ethosu/vela/high_level_command_to_npu_op.py
+++ b/ethosu/vela/high_level_command_to_npu_op.py
@@ -133,7 +133,8 @@ def get_rounding_mode(op: Operation, fused_quantize: bool) -> NpuRoundingMode:
and op.kernel.elements_wh() == 1
):
rounding_mode = NpuRoundingMode.NATURAL
- rounding_mode = op.attrs.get("rounding_mode", rounding_mode)
+ if op.rounding_mode is not None:
+ rounding_mode = op.rounding_mode
return rounding_mode
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index e4d11be..967d30b 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -25,6 +25,7 @@ from typing import Optional
from typing import Tuple
from typing import TYPE_CHECKING
+from .api import NpuRoundingMode
from .errors import VelaError
from .numeric_util import full_shape
from .shape4d import Shape4D
@@ -420,6 +421,7 @@ class Operation:
"ofm_shapes",
"rescale",
"read_offsets",
+ "rounding_mode",
)
def __init__(self, op_type: Op, name: str):
@@ -448,6 +450,7 @@ class Operation:
# (which overrides the ofm tensor's scale)
self.rescale = None
self.read_offsets: List[Shape4D] = [None, None] # offset for [ifm, ifm2]
+ self.rounding_mode: Optional[NpuRoundingMode] = None
def clone(self, suffix="_clone"):
res = Operation(self.type, self.name + suffix)
@@ -464,6 +467,7 @@ class Operation:
res.scheduled_pass = self.scheduled_pass
res.op_index = None # not relevant as not part of input network
res.read_offsets = list(self.read_offsets)
+ res.rounding_mode = self.rounding_mode
return res
diff --git a/ethosu/vela/softmax.py b/ethosu/vela/softmax.py
index 4418f01..520ec23 100644
--- a/ethosu/vela/softmax.py
+++ b/ethosu/vela/softmax.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
+# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
#
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
@@ -287,11 +287,9 @@ class SoftMax:
shift = create_const_tensor(
f"{name}_const", [1, 1, 1, 1], DataType.int32, [12], np.int32, quantization=no_scale_quant
)
- rescaled_exp = add_op_get_ofm(
- create_shr(
- name, ifm_exp, shift, no_scale_quant, activation, attrs={"rounding_mode": NpuRoundingMode.NATURAL},
- )
- )
+ shr_op = create_shr(name, ifm_exp, shift, no_scale_quant, activation)
+ shr_op.rounding_mode = NpuRoundingMode.NATURAL
+ rescaled_exp = add_op_get_ofm(shr_op)
# PASS 3 - Reduce sum
sum_of_exp = add_op_get_ofm(
@@ -421,7 +419,7 @@ class SoftMax:
# PASS 30 - SHR
shr30_op = Operation(Op.SHR, f"{self.op.name}_shr{pass_number}")
- shr30_op.attrs["rounding_mode"] = NpuRoundingMode.NATURAL
+ shr30_op.rounding_mode = NpuRoundingMode.NATURAL
shr30_op.add_input_tensor(scaled_exp)
shr30_op.add_input_tensor(right_shift)
shr30_op.set_output_tensor(ofm)
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 8446ec2..84432c7 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -119,7 +119,7 @@ class SupportedOperators:
filter_height_range = (1, 256)
filter_product_range = (1, 256 * 256)
# Supported consumers
- supported_pad_consumers = convolution_ops | depthwise_convolution_ops
+ supported_pad_consumers = convolution_ops | depthwise_convolution_ops | pooling_ops
def __init__(self):
# Setup the generic constraints. Note: the order matters
@@ -878,18 +878,29 @@ class SupportedOperators:
# which makes it impossible to calculate kernel size, hence use cached _kernel for those operators
k = cons.kernel if cons.inputs else cons._kernel
k_w, k_h = k.dilated_wh()
- if left > k_w // 2:
- return False, f"Left padding is {left}, kernel width is {k_w}"
- if right > k_w // 2:
- return False, f"Right padding is {right}, kernel width is {k_w}"
- if top > k_h // 2:
- return False, f"Top padding is {top}, kernel height is {k_h}"
- if bottom > k_h // 2:
- return False, f"Bottom padding is {bottom}, kernel height is {k_h}"
- if not SupportedOperators.__leading_pad_ok(top, k.stride.y, k_h):
- return False, f"Top padding is {top}, must be {k_h // 2} or multiple of {k.stride.y}"
- if not SupportedOperators.__leading_pad_ok(left, k.stride.x, k_w):
- return False, f"Left padding is {left}, must be {k_w // 2} or multiple of {k.stride.x}"
+ if cons.type.is_avgpool_op():
+ # For average pool, padding works differently on the NPU; more restrictions apply
+ for name, pad, k_size in (
+ ("Left", left, k_w),
+ ("Right", right, k_w),
+ ("Top", top, k_h),
+ ("Bottom", bottom, k_h),
+ ):
+ if pad not in (0, k_size // 2):
+ return False, f"{name} padding is {pad}, only 0 or {k_size // 2} are supported"
+ else:
+ if left > k_w // 2:
+ return False, f"Left padding is {left}, kernel width is {k_w}"
+ if right > k_w // 2:
+ return False, f"Right padding is {right}, kernel width is {k_w}"
+ if top > k_h // 2:
+ return False, f"Top padding is {top}, kernel height is {k_h}"
+ if bottom > k_h // 2:
+ return False, f"Bottom padding is {bottom}, kernel height is {k_h}"
+ if not SupportedOperators.__leading_pad_ok(top, k.stride.y, k_h):
+ return False, f"Top padding is {top}, must be {k_h // 2} or multiple of {k.stride.y}"
+ if not SupportedOperators.__leading_pad_ok(left, k.stride.x, k_w):
+ return False, f"Left padding is {left}, must be {k_w // 2} or multiple of {k.stride.x}"
return True, "Pad size is ok"
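
The average-pool branch above boils down to a simple rule: with the NPU's SAME-style padding, each side of the input may be padded by either 0 or exactly half the corresponding kernel dimension. A standalone sketch of just that check (simplified; not the actual SupportedOperators code):

# Simplified sketch of the average-pool pad constraint: each pad must be
# 0 or k_size // 2 for the corresponding kernel dimension.
def avg_pool_pad_ok(k_w, k_h, top, left, bottom, right):
    checks = (("Top", top, k_h), ("Bottom", bottom, k_h),
              ("Left", left, k_w), ("Right", right, k_w))
    for name, pad, k_size in checks:
        if pad not in (0, k_size // 2):
            return False, f"{name} padding is {pad}, only 0 or {k_size // 2} are supported"
    return True, "Pad size is ok"

# Mirrors entries from pad_avg_pool_test_data below:
assert avg_pool_pad_ok(3, 3, 1, 1, 1, 1)[0] is True   # ((3, 3), (1, 1, 1, 1), True)
assert avg_pool_pad_ok(4, 4, 1, 2, 2, 2)[0] is False  # ((4, 4), (1, 2, 2, 2), False)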
@staticmethod
diff --git a/ethosu/vela/test/test_graph_optimiser.py b/ethosu/vela/test/test_graph_optimiser.py
index 40b8cd5..285b3ac 100644
--- a/ethosu/vela/test/test_graph_optimiser.py
+++ b/ethosu/vela/test/test_graph_optimiser.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
+# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -157,6 +157,53 @@ def test_optimise_pad():
assert pad_op not in op.ifm.ops
+def test_optimise_pad_followed_by_avg_pool():
+ """
+ Tests that the PAD operator is bypassed when followed by an average pool operator,
+ and that the average pool is converted to a depthwise
+ """
+ # Create Pad operation followed by AvgPool
+ quant = testutil.default_quant_params()
+ in_tens = Tensor([1, 76, 75, 64], DataType.uint8, "input")
+ in_tens.quantization = quant
+ pad_input = create_const_tensor("pad_input", [4, 2], DataType.int32, [[0, 0], [2, 1], [1, 1], [0, 0]])
+ temp_tens = Tensor([1, 79, 77, 64], DataType.uint8, "pad_out")
+ temp_tens.quantization = quant.clone()
+ out_tens = Tensor([1, 76, 75, 64], DataType.uint8, "output")
+ out_tens.quantization = quant.clone()
+
+ pad_op = testutil.create_op(Op.Pad, [in_tens, pad_input], temp_tens)
+ attrs = {
+ "padding": Padding.VALID,
+ "ksize": [1, 5, 3, 1],
+ "stride_w": 2,
+ "stride_h": 2,
+ "dilation_w_factor": 1,
+ "dilation_h_factor": 1,
+ }
+ attrs["strides"] = (1, attrs["stride_h"], attrs["stride_w"], 1)
+ pad_op.run_on_npu = True
+ conv2d_op = testutil.create_op(Op.AvgPool, [temp_tens], out_tens, attrs)
+ conv2d_op.run_on_npu = True
+ nng = Graph()
+ sg = testutil.create_subgraph([pad_op, conv2d_op])
+ nng.subgraphs.append(sg)
+ arch = testutil.create_arch()
+
+ optimise_pad(conv2d_op, arch, nng)
+
+ op = sg.output_tensors[0].ops[0]
+ assert op.type == Op.DepthwiseConv2DBias
+ assert op.attrs["padding"] == Padding.EXPLICIT
+ assert op.attrs["explicit_padding"] == (2, 1, 1, 1)
+ assert op.ifm.shape == [1, 76, 75, 64]
+ assert pad_op not in op.ifm.ops
+ # Check that bias and weight tensors have been added
+ assert op.bias.shape == [64]
+ assert op.weights.shape == [5, 3, 1, 64]
+
+
def test_remove_reshape():
"""
Tests that the expected reshape are removed in graph_optimisation
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py
index 3e9724d..6401d29 100644
--- a/ethosu/vela/test/test_supported_operators.py
+++ b/ethosu/vela/test/test_supported_operators.py
@@ -609,14 +609,7 @@ def test_constraint_pad_consumer():
op_consumer = testutil.create_op_with_quant_tensors(Op.ConcatTFLite, [1, 1, 1, 4], [1, 1, 1, 8])
op.ofm.consumer_list = [op_consumer]
assert not support.is_operator_supported(op)
- op_consumer = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
- op_consumer.attrs = {
- "stride_w": 2,
- "stride_h": 2,
- "filter_width": 2,
- "filter_height": 2,
- "padding": Padding.VALID,
- }
+ op_consumer = testutil.create_elemwise_op(Op.Add, "op", [1, 3, 3, 1], [1, 3, 3, 1], [1, 3, 3, 1])
op.ofm.consumer_list = [op_consumer]
assert not support.is_operator_supported(op)
@@ -655,6 +648,55 @@ def test_constraint_leading_pad_size(top, left, kernel_size, expected):
assert support.is_operator_supported(op) == expected
+pad_avg_pool_test_data = [
+ ((3, 3), (1, 1, 1, 1), True),
+ ((2, 4), (1, 2, 1, 2), True),
+ ((5, 3), (2, 1, 2, 1), True),
+ ((5, 3), (0, 1, 2, 1), True),
+ ((5, 3), (2, 0, 2, 1), True),
+ ((5, 3), (2, 1, 0, 1), True),
+ ((5, 3), (2, 1, 0, 1), True),
+ ((4, 4), (2, 2, 2, 2), True),
+ ((4, 4), (1, 2, 2, 2), False),
+ ((4, 4), (2, 1, 2, 2), False),
+ ((4, 4), (2, 2, 1, 2), False),
+ ((4, 4), (2, 2, 2, 1), False),
+]
+
+
+@pytest.mark.parametrize("k_size, padding, expected", pad_avg_pool_test_data)
+def test_pad_followed_by_avg_pool(k_size, padding, expected):
+ # Tests PAD followed by AvgPool
+ k_w, k_h = k_size
+ top, left, bottom, right = padding
+ pad_values = [[0, 0], [top, bottom], [left, right], [0, 0]]
+ dtype = DataType.int8
+ qp = testutil.default_quant_params()
+ in_shape = [1, 15, 17, 8]
+ out_shape = [1, in_shape[1] + top + bottom, in_shape[2] + left + right, in_shape[3]]
+ in0 = Tensor(in_shape, dtype, "in")
+ in0.quantization = qp
+ pad_tensor = create_const_tensor(
+ name="pad", shape=list(np.shape(pad_values)), values=pad_values, dtype=DataType.int32
+ )
+ out = Tensor(out_shape, dtype, "out")
+ out.quantization = qp.clone()
+ op = testutil.create_op(Op.Pad, [in0, pad_tensor], out)
+ pool_out_tens = Tensor(in_shape, dtype, "output")
+ pool_out_tens.quantization = qp.clone()
+ attrs = {
+ "padding": Padding.VALID,
+ "ksize": [1, k_w, k_h, 1],
+ "stride_w": 1,
+ "stride_h": 1,
+ "dilation_w_factor": 1,
+ "dilation_h_factor": 1,
+ }
+ pool_op = testutil.create_op(Op.AvgPool, [out], pool_out_tens, attrs)
+ pool_op.add_input_tensor(out)
+ assert support.is_operator_supported(op) == expected
+
+
def create_strided_slice():
# Creates a valid strided slice operator with some valid inputs/outputs
op = create_strided_slice_op([1, 10, 10, 10], [1, 5, 5, 10], [127, 2, 2, 0], [0, 7, -3, 0])