aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: Rickard Bolin <rickard.bolin@arm.com> 2022-07-04 16:19:16 +0000
committer: Rickard Bolin <rickard.bolin@arm.com> 2022-09-23 09:13:20 +0000
commit: fea1516f94cfcbd801124e3fdc4b5f5c4526e15b (patch)
tree: 92b991244ef535d652d0bb6e875e9e3a289257f5
parent: cc219be4ec175645e8457da80d5effbf4324943b (diff)
download: ethos-u-vela-fea1516f94cfcbd801124e3fdc4b5f5c4526e15b.tar.gz
MLBEDSW-6686: Resize bilinear HPC with tile padding
- Added support for Resize Bilinear with half pixel centers for int8 and uint8.
- Utilizes the new "TILE" padding mode.
- Utilizes ofm stride multipliers and modified tile base offsets to write OFMs interleaved.

Signed-off-by: Rickard Bolin <rickard.bolin@arm.com>
Change-Id: I37fa77c022a368f05fda0ead75d8696c9205f833
-rw-r--r-- SUPPORTED_OPS.md 19
-rw-r--r-- ethosu/vela/graph_optimiser_util.py 8
-rw-r--r-- ethosu/vela/high_level_command_to_npu_op.py 24
-rw-r--r-- ethosu/vela/operation.py 23
-rw-r--r-- ethosu/vela/tensor.py 17
-rw-r--r-- ethosu/vela/test/test_tflite_supported_operators.py 9
-rw-r--r-- ethosu/vela/tflite_graph_optimiser.py 141
-rw-r--r-- ethosu/vela/tflite_supported_operators.py 38
-rw-r--r-- ethosu/vela/weight_compressor.py 8
9 files changed, 251 insertions, 36 deletions
diff --git a/SUPPORTED_OPS.md b/SUPPORTED_OPS.md
index 6a92e829..36b403ad 100644
--- a/SUPPORTED_OPS.md
+++ b/SUPPORTED_OPS.md
@@ -1,7 +1,7 @@
# Supported Ops
This file was automatically generated by Vela using the `--supported-ops-report` parameter.
-Vela version: `3.5.0`
+Vela version: `3.5.1.dev14+gc22ad76.d20220921`
This file complies with
[**Gitiles Markdown syntax**](https://github.com/google/gitiles/blob/master/Documentation/markdown.md)
@@ -36,6 +36,7 @@ Please check the supported operator list for your chosen runtime for further inf
| MUL | [Generic](#tflite-generic-constraints), [Specific](#tflite-mul-constraints) |
| PACK | [Generic](#tflite-generic-constraints) |
| PAD | [Generic](#tflite-generic-constraints), [Specific](#tflite-pad-constraints) |
+| PRELU | [Generic](#tflite-generic-constraints) |
| QUANTIZE | [Generic](#tflite-generic-constraints) |
| RELU | [Generic](#tflite-generic-constraints) |
| RELU6 | [Generic](#tflite-generic-constraints) |
@@ -116,7 +117,6 @@ This is a list of constraints that the CONCATENATION operator must satisfy in or
- Axis attribute must be in the range [0, <ofm_dimensions>)
- All Input dimensionalities must match OFM dimensionality
- All Input dimensions must match OFM dimension in all axes except the one defined by the axis attribute
-- All Input dimensions must match OFM dimension in all axes except the one defined by the axis attribute
- The size of the OFM axis must match the sum of all IFM axis defined by the axis attribute
### TFLite CONV_2D Constraints
@@ -184,7 +184,6 @@ This is a list of constraints that the LEAKY_RELU operator must satisfy in order
- At least one Input's shape must match the OFM's shape
- IFM and OFM data types must match
-- Alpha only allowed to be negative if IFM is int8 or uint8
- Batch size must be 1 for Input tensors with more than 2 dimensions
### TFLite MAXIMUM Constraints
@@ -268,6 +267,7 @@ This is a list of constraints that the RESHAPE operator must satisfy in order to
- Input and output quantisation must match.
- Shape must be constant
+- Reshape on NPU not supported before MEAN operator
### TFLite RESIZE_BILINEAR Constraints
@@ -276,11 +276,12 @@ This is a list of constraints that the RESIZE_BILINEAR operator must satisfy in
- The width and height of the IFM and OFM must match one of the following criteria:
IFM W and H must both be 1
IFM must match OFM
- OFM W and H must be equal and OFM W-1 and H-1 must be 2x/4x/8x IFM W-1 and H-1, if align_corners is True
- OFM W and H must be equal and OFM W and H must be 2x/4x/8x IFM W and H, if align_corners is False
+ W and H scaling must be equal and OFM W-1 and H-1 must be 2x/4x/8x IFM W-1 and H-1, if align_corners is True
+ W and H scaling must be equal and OFM W and H must be 2x/4x/8x IFM W and H, if align_corners is False
- The size tensor must match the output tensor shape
- Both align_corners and half_pixel_centers can't be True
-- half_pixel_centers are not supported
+- Half_pixel_centers are only supported for resize bilinear with IFM dtype int8 or uint8
+- Half_pixel_centers for resize bilinear requires that OFM W and H is 2x IFM W and H
### TFLite RESIZE_NEAREST_NEIGHBOR Constraints
@@ -289,11 +290,11 @@ This is a list of constraints that the RESIZE_NEAREST_NEIGHBOR operator must sat
- The width and height of the IFM and OFM must match one of the following criteria:
IFM W and H must both be 1
IFM must match OFM
- OFM W and H must be equal and OFM W-1 and H-1 must be 2x/4x/8x IFM W-1 and H-1, if align_corners is True
- OFM W and H must be equal and OFM W and H must be 2x/4x/8x IFM W and H, if align_corners is False
+ W and H scaling must be equal and OFM W-1 and H-1 must be 2x/4x/8x IFM W-1 and H-1, if align_corners is True
+ W and H scaling must be equal and OFM W and H must be 2x/4x/8x IFM W and H, if align_corners is False
- The size tensor must match the output tensor shape
- Both align_corners and half_pixel_centers can't be True
-- half_pixel_centers are not supported
+- Half_pixel_centers are only supported for resize bilinear with IFM dtype int8 or uint8
### TFLite SOFTMAX Constraints
diff --git a/ethosu/vela/graph_optimiser_util.py b/ethosu/vela/graph_optimiser_util.py
index 5e7e1127..b33851a8 100644
--- a/ethosu/vela/graph_optimiser_util.py
+++ b/ethosu/vela/graph_optimiser_util.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
+# Copyright (C) 2021-2022 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -111,6 +111,12 @@ def check_format_restrictions(tens, arch):
if _avoid_nhcwb16_for_shapes(tens):
return
+ # Resize bilinear half pixel center implementation requires OFM with linear format to
+ # allow stride modification in H/W dimensions.
+ for op in tens.ops:
+ if op.original_type == Op.ResizeBilinear and op.type == Op.DepthwiseConv2DBias:
+ return
+
for op in tens.consumer_list:
if op.type == Op.ReduceSum and (
tens.dtype == DataType.int32 or arch.accelerator_config == Accelerator.Ethos_U65_512
diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py
index 6246b37e..7923e371 100644
--- a/ethosu/vela/high_level_command_to_npu_op.py
+++ b/ethosu/vela/high_level_command_to_npu_op.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
+# Copyright (C) 2020-2022 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -189,6 +189,7 @@ def create_padding(cmd: NpuStripe, primary_op: Operation, npu_op: NpuBlockOperat
dtype=cmd.ifm_tensor.dtype,
)
top, left, bottom, right = 0, 0, 0, 0
+
return NpuPadding(top=top, left=left, bottom=bottom, right=right)
@@ -297,6 +298,10 @@ def use_zero_point_0(ps, tens: Tensor, is_ifm_tensor: bool) -> bool:
"""Checks if quantization should use 0 as zero point"""
if tens.dtype == DataType.int32 and is_ifm_tensor:
return True
+ # Force zero point to 0 for ResizeBilinear when converting to a DepthwiseConv since the reference kernel
+ # will ignore the zero point.
+ if ps.primary_op.original_type == Op.ResizeBilinear and ps.primary_op.type == Op.DepthwiseConv2DBias:
+ return True
if ps.primary_op.type not in (Op.AvgPool, Op.CLZ, Op.SHL) and not ps.primary_op.type.is_resize_op():
return False
if ps.primary_op.type == Op.AvgPool and ps.primary_op.explicit_scaling:
@@ -352,6 +357,7 @@ def create_feature_map(
box: Box,
arch: ArchitectureFeatures,
op_shape4D: Shape4D,
+ tile_base_offsets: List[int],
stride_multiplier: Optional[List[int]] = None,
) -> NpuFeatureMap:
"""Creates feature map with common fields populated"""
@@ -380,6 +386,8 @@ def create_feature_map(
box.start_coord, box.end_coord, strides, op_shape4D
)
+ for idx, offset in enumerate(tile_base_offsets):
+ addresses[idx] += offset
fm.tiles = NpuTileBox(
height_0=height_0, height_1=height_1, width_0=width_0, addresses=[int(addr) for addr in addresses]
)
@@ -475,12 +483,14 @@ def set_common_op_fields(npu_op: NpuBlockOperation, cmd: NpuStripe, arch: Archit
ifm_width = cmd.ps.ifm_shapes[0].width
ifm_depth = get_ifm_depth(op.type.npu_block_type, cmd.ifm_box, cmd.ofm_box)
- npu_op.ifm = create_feature_map(cmd.ifm_tensor, cmd.ifm_box, arch, ps.ifm_shapes[0])
+ npu_op.ifm = create_feature_map(cmd.ifm_tensor, cmd.ifm_box, arch, ps.ifm_shapes[0], op.tile_base_offsets_ifm[0])
npu_op.ifm.shape = NpuShape3D(height=ifm_height, width=ifm_width, depth=ifm_depth)
npu_op.ifm.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm_tensor)
out_block = cmd.ofm_box.get_block()
- npu_op.ofm = create_feature_map(cmd.ofm_tensor, cmd.ofm_box, arch, ps.ofm_shapes[0], op.ofm_stride_multiplier)
+ npu_op.ofm = create_feature_map(
+ cmd.ofm_tensor, cmd.ofm_box, arch, ps.ofm_shapes[0], op.tile_base_offsets_ofm, op.ofm_stride_multiplier
+ )
npu_op.ofm.shape = NpuShape3D(height=out_block.height, width=out_block.width, depth=out_block.depth)
npu_op.ofm.quantization = get_ofm_quantization(ps, cmd.ofm_tensor)
@@ -559,7 +569,13 @@ def create_npu_elementwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> Npu
cmd.ifm_box, cmd.ifm2_box = cmd.ifm2_box, cmd.ifm_box
ps.ifm_shapes[0], ps.ifm_shapes[1] = ps.ifm_shapes[1], ps.ifm_shapes[0]
npu_op.reversed_operands = True
- npu_op.ifm2 = create_feature_map(cmd.ifm2_tensor, cmd.ifm2_box, arch, ps.ifm_shapes[1])
+ npu_op.ifm2 = create_feature_map(
+ cmd.ifm2_tensor,
+ cmd.ifm2_box,
+ arch,
+ ps.ifm_shapes[1],
+ op.tile_base_offsets_ifm[1],
+ )
npu_op.ifm2.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm2_tensor)
if cmd.ifm2_tensor.shape == []:
# scalar
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index e1622049..af2205cd 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
+# Copyright (C) 2020-2022 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -474,7 +474,7 @@ class Operation:
__slots__ = (
"type",
- "original_type",
+ "_original_type",
"name",
"op_index",
"attrs",
@@ -501,12 +501,14 @@ class Operation:
"write_offset",
"write_shape",
"ifm_resampling_mode",
+ "tile_base_offsets_ifm",
+ "tile_base_offsets_ofm",
"ofm_stride_multiplier",
)
def __init__(self, op_type: Op, name: str):
self.type = op_type
- self.original_type = op_type
+ self._original_type = op_type # the original type of the operation. once set this shouldn't be changed
self.name = name
self.attrs: Dict[str, Any] = {}
self.inputs: List[Optional[Tensor]] = []
@@ -546,6 +548,10 @@ class Operation:
# write_offset 0,9,0,0, write_shape 1,1,8,1
self.write_shape: Optional[Shape4D] = None
self.ifm_resampling_mode: resampling_mode = resampling_mode.NONE
+ # ifm (nhwc), ifm2 (nhwc)
+ self.tile_base_offsets_ifm: List[List[int]] = [[0, 0, 0, 0], [0, 0, 0, 0]]
+ # ofm (nhwc)
+ self.tile_base_offsets_ofm: List[int] = [0, 0, 0, 0]
# For interleaved/sparse outputs - stride is multiplied with the stride factor of the corresponding axis
# Order is [C, H, W] - default is no multiplication
self.ofm_stride_multiplier: List[int] = [1, 1, 1]
@@ -553,6 +559,9 @@ class Operation:
def clone(self, suffix="_clone"):
res = Operation(self.type, self.name + suffix)
+ # maintain the original type, in cases where the type was changed to something different
+ res._original_type = self._original_type
+
res.attrs = dict(self.attrs)
res.inputs = list(self.inputs)
res.outputs = list(self.outputs)
@@ -567,11 +576,15 @@ class Operation:
res.op_index = None # not relevant as not part of input network
res.read_offsets = list(self.read_offsets)
res.read_shapes = list(self.read_shapes)
+ res.write_offset = Shape4D(*self.write_offset) if self.write_offset else None
+ res.write_shape = Shape4D(*self.write_shape) if self.write_shape else None
res.rounding_mode = self.rounding_mode
res.explicit_scaling = self.explicit_scaling
res.low_precision_scaling = self.low_precision_scaling
res.rescale = self.rescale
res.ifm_resampling_mode = self.ifm_resampling_mode
+ res.tile_base_offsets_ifm = [_ifm.copy() for _ifm in self.tile_base_offsets_ifm]
+ res.tile_base_offsets_ofm = self.tile_base_offsets_ofm.copy()
res.ofm_stride_multiplier = self.ofm_stride_multiplier.copy()
return res
@@ -581,6 +594,10 @@ class Operation:
__repr__ = __str__
+ @property
+ def original_type(self):
+ return self._original_type
+
def get_kernel_size(self):
weights = self.weights
if weights and self.type.npu_block_type in (NpuBlockType.ConvolutionDepthWise, NpuBlockType.ConvolutionMxN):
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 99970317..9fbd454c 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
+# Copyright (C) 2020-2022 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -213,6 +213,7 @@ class QuantizationParameters:
"max",
"num_bits",
"narrow_range",
+ "next_after",
"scale_f32",
"zero_point",
"quant_min",
@@ -233,6 +234,10 @@ class QuantizationParameters:
self.num_bits = num_bits
self.narrow_range = narrow_range
+ # Use the 'next after' float value of scale_f32 when converting to scale and shift. It can be combined with
+ # natural rounding to perform rounding away from zero. This only affects the ofm scale and bias tensor, it has
+ # no effect on global scaling i.e. the ofm_scale register
+ self.next_after = False
self.scale_f32: Union[float, np.ndarray, None] = None
self.zero_point: Union[int, np.ndarray, None] = None
self.quant_min: Optional[float] = None
@@ -240,12 +245,9 @@ class QuantizationParameters:
self.quant_dim: Optional[int] = None
def __str__(self):
- return "<nng.QuantizationParameters min=%s max=%s, num_bits=%s, scale=%s, zero_point=%s>" % (
- self.min,
- self.max,
- self.num_bits,
- self.scale_f32,
- self.zero_point,
+ return (
+ f"<nng.QuantizationParameters min={self.min}, max={self.max}, num_bits={self.num_bits}, "
+ f"scale={self.scale_f32}, zero_point={self.zero_point}, next={self.next_after}>"
)
__repr__ = __str__
@@ -258,6 +260,7 @@ class QuantizationParameters:
res.num_bits = self.num_bits
res.narrow_range = self.narrow_range
+ res.next_after = self.next_after
res.scale_f32 = self.scale_f32
res.zero_point = self.zero_point
res.quant_min = self.quant_min
diff --git a/ethosu/vela/test/test_tflite_supported_operators.py b/ethosu/vela/test/test_tflite_supported_operators.py
index 89c27997..3872bdc8 100644
--- a/ethosu/vela/test/test_tflite_supported_operators.py
+++ b/ethosu/vela/test/test_tflite_supported_operators.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
+# Copyright (C) 2020-2022 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -383,11 +383,14 @@ def test_constraint_resize_attrs():
def test_constraint_resize_half_pixel_centers():
for resize_op in Op.op_set(Op.is_resize_op):
- # Invalid case - half-pixel centers (not supported)
+ # Half-pixel centers are only supported for resize bilinear
op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 8, 8, 8])
op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [8, 8], np.int32))
op.attrs["half_pixel_centers"] = True
- assert not support.is_operator_supported(op)
+ if resize_op == Op.ResizeBilinear:
+ assert support.is_operator_supported(op)
+ else:
+ assert not support.is_operator_supported(op)
def test_constraint_concat_pass():
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 6b454e3d..27513d3d 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
+# Copyright (C) 2020-2022 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -464,6 +464,143 @@ def convert_resize_to_upscale_and_average_pool(op):
return op
+def convert_resizebilinear_to_depthwise_convolutions(op, half_pixel_centers=True):
+ def _compute_interpolation_values(index, input_size, output_size):
+ scale = input_size / output_size
+ scaled_value = (index + 0.5 * half_pixel_centers) * scale - 0.5 * half_pixel_centers
+ lower_bound = max(np.floor(scaled_value), 0)
+
+ return scaled_value, lower_bound
+
+ def _compute_kernels(input_height, input_width, output_height, output_width):
+ kernels = []
+ for y in (1, 2):
+ for x in (1, 2):
+ sv_h, lb_h = _compute_interpolation_values(y, input_height, output_height)
+ sv_w, lb_w = _compute_interpolation_values(x, input_width, output_width)
+
+ # Interpolation values calculated for (x, y) = ([1, 2], [1, 2]) will always generalize to the whole
+ # input for upscale = 2 and input sizes >= 2x2 and be in the correct order for going left-to-right,
+ # top-to-bottom - same as the depthwise convolution strides across each tile
+ kernel = np.zeros((2, 2))
+ kernel[1, 1] = (1 - (sv_h - lb_h)) * (1 - (sv_w - lb_w))
+ kernel[0, 1] = (sv_h - lb_h) * (1 - (sv_w - lb_w))
+ kernel[1, 0] = (1 - (sv_h - lb_h)) * (sv_w - lb_w)
+ kernel[0, 0] = (sv_h - lb_h) * (sv_w - lb_w)
+ kernel *= 16
+ kernels.append(kernel)
+
+ return kernels
+
+ def _build_convolutions(op, kernels):
+ dw_op_attrs = {
+ "padding": Padding.TILE,
+ "stride_h": 1,
+ "stride_w": 1,
+ "strides": (1, 1, 1, 1),
+ "depth_multiplier": 1,
+ "channel_multiplier": 1,
+ "dilation_h_factor": 1,
+ "dilation_w_factor": 1,
+ "dilation": (1, 1, 1, 1),
+ }
+ ifm = op.ifm
+ ofm = op.ofm
+ ofm.ops = []
+ elem_size = 2 if ofm.dtype == DataType.int16 else 1
+
+ n, h, w, c = ifm.shape
+ _, _, ow, _ = ofm.shape
+
+ intermediate_tens = Tensor(ifm.shape, ifm.dtype, "intermediate_tens")
+ intermediate_tens.quantization = op.outputs[0].quantization.clone()
+ avgpool_op = op
+ avgpool_op.name = "rb_init_avgpool"
+ avgpool_op.type = Op.AvgPool
+ avgpool_op.attrs["padding"] = Padding.VALID
+ avgpool_op.attrs["stride_w"] = 1
+ avgpool_op.attrs["stride_h"] = 1
+ avgpool_op.attrs["filter_width"] = 1
+ avgpool_op.attrs["filter_height"] = 1
+ avgpool_op.attrs["strides"] = [1, 1, 1, 1]
+ avgpool_op.attrs["ksize"] = [1, 1, 1, 1]
+
+ avgpool_op.add_input_tensor(ifm)
+ avgpool_op.set_output_tensor(intermediate_tens)
+ avgpool_op.set_ifm_ofm_shapes()
+
+ dw_conv = Operation(Op.DepthwiseConv2DBias, "depthwise_conv")
+ dw_conv._original_type = Op.ResizeBilinear
+ dw_conv.write_shape = Shape4D(n, h, w, c)
+ dw_conv.write_offset = Shape4D(0, 0, 0, 0)
+
+ # Set the output rounding mode. Resize bilinear requires rounding away from zero. Therefore, we need to
+ # adjust the accumulated value by a "small" amount before applying natural rounding. The "small" amount
+ # should be big enough to cause a x.5 to be rounded correctly but small enough not to cause smaller
+ # values to be incorrectly rounded
+ ofm.quantization.next_after = True
+ dw_conv.rounding_mode = NpuRoundingMode.NATURAL
+
+ # Double height and width stride to write the output of each of the four depthwise convolutions below
+ # interleaved with each other when combined with OFM tile base offsets.
+ dw_conv.ofm_stride_multiplier = [1, 2, 2] # C/H/W
+
+ # Choose tile padding direction - pad by 1 with edge values in two directions.
+ # For example, TL (top left) will pad top and left in H/W-plane in all channels.
+ directions = [[1, 1, 0, 0], [1, 0, 0, 1], [0, 1, 1, 0], [0, 0, 1, 1]] # TL, TR, BL, BR
+ for i in (0, 1):
+ for j in (0, 1):
+ index = i * 2 + j
+ dw_conv.name = f"depthwise_conv_{index}"
+ dw_op_attrs["explicit_padding"] = directions[index]
+ dw_conv.attrs.update(dw_op_attrs)
+
+ # This will offset the start of the write by modifying the Tile 0 base address
+ dw_conv.tile_base_offsets_ofm[0] = (i * ow + j) * c * elem_size
+
+ ofm.ops.append(dw_conv)
+ dw_conv.outputs = [ofm]
+
+ kernel = kernels[index]
+ shape = [2, 2, 1, c]
+ kernel = np.dstack([kernel] * c)
+
+ quant = QuantizationParameters()
+ quant.zero_point = 0
+ quant.scale_f32 = 1.0 / 16
+
+ dw_conv.inputs = []
+ dw_conv.add_input_tensor(intermediate_tens)
+ dw_conv.add_input_tensor(
+ create_const_tensor(
+ "weights",
+ shape,
+ intermediate_tens.dtype,
+ np.array(kernel).reshape(shape),
+ value_dtype=np.int8,
+ quantization=quant,
+ ),
+ )
+
+ # set up the bias tensor by assigning None and then calling the fix-up function to create a suitable tensor.
+ # need to append the bias tensor as resize ops only have 2 inputs
+ assert len(dw_conv.inputs) == 2
+ dw_conv.inputs.append(None)
+ fixup_bias_tensors(dw_conv, None, None)
+
+ dw_conv.set_ifm_ofm_shapes()
+ dw_conv = dw_conv.clone(f"_{index}")
+ return op
+
+ _, input_height, input_width, _ = op.ifm.shape
+ _, output_height, output_width, _ = op.ofm.shape
+
+ kernels = _compute_kernels(input_height, input_width, output_height, output_width)
+ op = _build_convolutions(op, kernels)
+
+ return op
+
+
def fixup_resize(op, arch, nng):
if op.type.is_resize_op() and op.run_on_npu:
if op.ifm_shapes[0] == op.ofm_shapes[0]:
@@ -472,6 +609,8 @@ def fixup_resize(op, arch, nng):
op.type = Op.Identity
elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1:
convert_resize_1x1_to_add(op)
+ elif op.type == Op.ResizeBilinear and op.attrs.get("half_pixel_centers", False):
+ convert_resizebilinear_to_depthwise_convolutions(op)
else:
convert_resize_to_upscale_and_average_pool(op)
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index be86e9a3..9aa174de 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
+# Copyright (C) 2020-2022 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -255,6 +255,11 @@ class TFLiteSupportedOperators:
self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_resize_attrs)
self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_resize_half_pixel_centers)
+ # Resize Bilinear specific checks:
+ self.specific_constraints[Op.ResizeBilinear].append(
+ TFLiteSupportedOperators.constraint_resizebi_half_pixel_centers_dims
+ )
+
# Vector Product specific checks:
for op_type in TFLiteSupportedOperators.fc_vector_products:
self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_weights_type)
@@ -602,8 +607,8 @@ class TFLiteSupportedOperators:
"""The width and height of the IFM and OFM must match one of the following criteria:
IFM W and H must both be 1
IFM must match OFM
- OFM W and H must be equal and OFM W-1 and H-1 must be 2x/4x/8x IFM W-1 and H-1, if align_corners is True
- OFM W and H must be equal and OFM W and H must be 2x/4x/8x IFM W and H, if align_corners is False"""
+ W and H scaling must be equal and OFM W-1 and H-1 must be 2x/4x/8x IFM W-1 and H-1, if align_corners is True
+ W and H scaling must be equal and OFM W and H must be 2x/4x/8x IFM W and H, if align_corners is False"""
# Easier to start with False condition as very few cases result in a supported resize
valid = False
ifm_shape = op.ifm.shape
@@ -661,11 +666,30 @@ class TFLiteSupportedOperators:
@staticmethod
def constraint_resize_half_pixel_centers(op):
- "half_pixel_centers are not supported"
- valid = True
- if op.attrs.get("half_pixel_centers", False):
+ """Half_pixel_centers are only supported for resize bilinear with IFM dtype int8 or uint8"""
+ valid = op.ifm.dtype in (DataType.int8, DataType.uint8)
+ half_pixel_centers = op.attrs.get("half_pixel_centers", False)
+ if half_pixel_centers and op.type != Op.ResizeBilinear:
+ valid = False
+ return valid, f"Op type={op.type}, ifm dtype={op.ifm.dtype} and half_pixel_centers={half_pixel_centers}"
+
+ @staticmethod
+ def constraint_resizebi_half_pixel_centers_dims(op):
+ """Half_pixel_centers for resize bilinear requires that OFM W and H is 2x IFM W and H"""
+ half_pixel_centers = op.attrs.get("half_pixel_centers", False)
+ if not half_pixel_centers:
+ valid = True
+ elif len(op.ifm.shape) >= 3:
+ ifm_h, ifm_w = op.ifm.shape[-3:-1]
+ ofm_h, ofm_w = op.ofm.shape[-3:-1]
+ valid = ofm_h / ifm_h == 2 and ofm_w / ifm_w == 2
+ else:
+ # Unexpected IFM shape
valid = False
- return valid, f"Op has half_pixel_centers set to {not valid}."
+ return (
+ valid,
+ f"Op has ifm_shape={op.ifm.shape}, ofm_shape={op.ofm.shape} and half_pixel_centers={half_pixel_centers}",
+ )
@staticmethod
def constraint_pad_shape(op):
diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py
index db225fb6..6f9467ec 100644
--- a/ethosu/vela/weight_compressor.py
+++ b/ethosu/vela/weight_compressor.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
+# Copyright (C) 2020-2022 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -281,6 +281,12 @@ def _prepare_scale_and_bias(arch, tens, rescale_for_faf, explicit_scaling):
else:
quantised_scales = [quantise_scale(scale) for scale in scales]
+ # Check the output quantisation to see if the scale value needs increasing to the next one
+ if first_consumer_op.get_output_quantization().next_after:
+ for i, quant_scale in enumerate(quantised_scales):
+ q_scale, q_shift = quant_scale
+ quantised_scales[i] = (q_scale + 1, q_shift)
+
# If only 1 quantised scale is used, repeat that value for the length of the biases
if len(quantised_scales) == 1:
quantised_scales = [quantised_scales[0]] * len(biases)