aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTim Hall <tim.hall@arm.com>2023-05-16 22:39:14 +0100
committertim.hall <tim.hall@arm.com>2023-05-17 11:05:57 +0000
commit5ff4cd12898f44228288a7969b52dff97be30cb2 (patch)
tree1c8068c02254d5479706e41379bbd1f8c7b33205
parent0426fe9de82e0f6b339edbd2bec78a5d322fb05f (diff)
downloadethos-u-vela-5ff4cd12898f44228288a7969b52dff97be30cb2.tar.gz
MLBEDSW-7223: Fusing Pad and AvgPool causes diff
- Fixed an issue with the fusing of PAD and AVERAGE_POOL_2D whereby the rounding away from zero didn't work because it requires the zero point to be at zero but the input padding required it to be set to the desired zero point. This affected both int8 and int16. The solution was to remove it by using the bias prior to the scaling - Refactored the rounding away from zero mode Change-Id: I8f2df69df06d2a9722315c346646e5a901cb2c3b Signed-off-by: Tim Hall <tim.hall@arm.com>
-rw-r--r--ethosu/vela/high_level_command_to_npu_op.py30
-rw-r--r--ethosu/vela/operation.py45
-rw-r--r--ethosu/vela/softmax.py6
-rw-r--r--ethosu/vela/tensor.py8
-rw-r--r--ethosu/vela/tflite_graph_optimiser.py44
-rw-r--r--ethosu/vela/tosa_graph_optimiser.py8
-rw-r--r--ethosu/vela/weight_compressor.py5
7 files changed, 107 insertions, 39 deletions
diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py
index 55b4473..9526bd5 100644
--- a/ethosu/vela/high_level_command_to_npu_op.py
+++ b/ethosu/vela/high_level_command_to_npu_op.py
@@ -62,6 +62,7 @@ from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .operation import Padding
+from .operation import RoundingMode
from .register_command_stream_generator import generate_command_stream
from .register_command_stream_util import BASE_PTR_INDEX_MEM2MEM
from .register_command_stream_util import to_npu_kernel
@@ -113,6 +114,14 @@ resampling_mode_inv_map = {
}
+rounding_mode_map = {
+ RoundingMode.TFLite: NpuRoundingMode.TFL,
+ RoundingMode.ToZero: NpuRoundingMode.TRUNCATE,
+ RoundingMode.HalfUp: NpuRoundingMode.NATURAL,
+ RoundingMode.AwayZero: NpuRoundingMode.NATURAL,
+}
+
+
def ifm_ifm2_correct_order(ifm_shape: Shape4D, ifm2_shape: Shape4D) -> bool:
if ifm_shape is None:
@@ -146,7 +155,7 @@ def get_rounding_mode(op: Operation, fused_quantize: bool) -> NpuRoundingMode:
):
rounding_mode = NpuRoundingMode.NATURAL
if op.rounding_mode is not None:
- rounding_mode = op.rounding_mode
+ rounding_mode = rounding_mode_map[op.rounding_mode]
return rounding_mode
@@ -298,10 +307,21 @@ def use_zero_point_0(ps, tens: Tensor, is_ifm_tensor: bool) -> bool:
"""Checks if quantization should use 0 as zero point"""
if tens.dtype == DataType.int32 and is_ifm_tensor:
return True
- # Force zero point to 0 for ResizeBilinear when converting to a DepthwiseConv since the reference kernel
- # will ignore the zero point.
- if ps.primary_op.original_type == Op.ResizeBilinear and ps.primary_op.type == Op.DepthwiseConv2DBias:
- return True
+ if ps.primary_op.rounding_mode == RoundingMode.AwayZero:
+ if ps.primary_op.original_type == Op.ResizeBilinear and ps.primary_op.type == Op.DepthwiseConv2DBias:
+ # Force zero point to 0 for ResizeBilinear operators converted to a DepthwiseConv with rounding away from
+ # zero. This is because the reference kernel ignores the zero points.
+ return True
+ if (
+ not is_ifm_tensor
+ and ps.primary_op.original_type == Op.AvgPool
+ and ps.primary_op.attrs.get("padding", None) == Padding.EXPLICIT
+ and ps.primary_op.type == Op.DepthwiseConv2DBias
+ ):
+ # Force zero point to 0 for the OFM of AvgPool operators that have been combined with a previous PAD
+ # operator and converted to a DepthwiseConv with rounding away from zero. This is because the zero point
+ # will already have been applied in the Bias.
+ return True
if ps.primary_op.type not in (Op.AvgPool, Op.CLZ, Op.SHL) and not ps.primary_op.type.is_resize_op():
return False
if ps.primary_op.type == Op.AvgPool and ps.primary_op.explicit_scaling:
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 161b17f..52f06cf 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -21,6 +21,7 @@ from __future__ import annotations
import copy
from collections import namedtuple
+from enum import auto
from enum import Enum
from typing import Any
from typing import Dict
@@ -29,7 +30,6 @@ from typing import Optional
from typing import Tuple
from typing import TYPE_CHECKING
-from .api import NpuRoundingMode
from .errors import VelaError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .numeric_util import full_shape
@@ -44,6 +44,13 @@ PointXY = namedtuple("PointXY", "x y")
PointXYZ = namedtuple("PointXYZ", "x y z")
+class RoundingMode(Enum):
+ TFLite = auto() # Round like TensorFlow Lite
+ ToZero = auto() # Round towards zero (truncate)
+ HalfUp = auto() # Round to nearest with x.5 rounded up towards positive infinity (natural)
+ AwayZero = auto() # Round away from zero (towards infinity)
+
+
class NpuBlockType(Enum):
Default = 0
ConvolutionMxN = 1
@@ -491,7 +498,7 @@ class Operation:
"rescale",
"read_offsets",
"read_shapes",
- "rounding_mode",
+ "_rounding_mode",
"explicit_scaling",
"write_offset",
"write_shape",
@@ -528,7 +535,7 @@ class Operation:
self.ofm_shapes: List[Shape4D] = []
self.read_offsets: List[Optional[Shape4D]] = [None, None] # offset for [ifm, ifm2]
self.read_shapes: List[Optional[Shape4D]] = [None, None] # read shape for [ifm, ifm2]
- self.rounding_mode: Optional[NpuRoundingMode] = None
+ self._rounding_mode: Optional[RoundingMode] = None
# Rescale op in TOSA supplies explicit multiplier and shift values
self.explicit_scaling: Optional[ExplicitScaling] = None
# Write offset, for operations that only produce a part of the OFM
@@ -587,6 +594,38 @@ class Operation:
return self._original_type
@property
+ def rounding_mode(self):
+ return self._rounding_mode
+
+ @rounding_mode.setter
+ def rounding_mode(self, mode: RoundingMode):
+ # All rounding modes are supported by all operators with the exception of rounding away from zero (see comment
+ # below)
+ is_supported = True
+ if mode == RoundingMode.AwayZero:
+ # Rounding away from zero does not have direct hardware support and so the compiler implements it indirectly
+ # in different ways. The exact process depends upon the operator type and not all operators are supported.
+ # Basically, rounding away from zero works by adjusting the accumulated value by a "small" amount before
+ # rounding up with the addition of a half (natural rounding). This "small" amount should be big enough to
+ # cause x.5 to be rounded correctly but small enough that smaller values are not incorrectly rounded. This
+ # is done by slightly adjusting the scale and shift on the ofm tensor using the scale and bias tensor,
+ # it has no affect on global scaling (i.e. the ofm_scale register). In addition, the zero points of the
+ # input and/or output tensors may also require forcing to zero but the exact behaviour also depends upon the
+ # corresponding optimisation performed in graph_optimisation.py where the rounding mode is set
+ is_supported = False
+ if self.original_type == Op.ResizeBilinear and self.type == Op.DepthwiseConv2DBias:
+ is_supported = True
+ if self.original_type == Op.AvgPool and self.type == Op.DepthwiseConv2DBias:
+ is_supported = True
+
+ if is_supported:
+ self._rounding_mode = mode
+ else:
+ assert (
+ False
+ ), f"Setting rounding mode = {mode} on {self.original_type} operator '{self.name}' is not supported."
+
+ @property
def type_changed(self):
return self.type != self.original_type
diff --git a/ethosu/vela/softmax.py b/ethosu/vela/softmax.py
index 5a06c1b..8f30fa1 100644
--- a/ethosu/vela/softmax.py
+++ b/ethosu/vela/softmax.py
@@ -24,13 +24,13 @@ import numpy as np
from . import fp_math
from . import scaling
-from .api import NpuRoundingMode
from .data_type import DataType
from .debug_database import DebugDatabase
from .operation import ActivationFunction
from .operation import ExplicitScaling
from .operation import Op
from .operation import Operation
+from .operation import RoundingMode
from .operation_util import create_add
from .operation_util import create_clz
from .operation_util import create_depthwise_maxpool
@@ -281,7 +281,7 @@ class SoftMax:
name = f"{self.op.name}_shr{pass_number}"
shift = create_const_tensor(f"{name}_const", [1, 1, 1, 1], DataType.int32, [12], quantization=no_scale_quant)
shr_op = create_shr(name, ifm_exp, shift, no_scale_quant, activation)
- shr_op.rounding_mode = NpuRoundingMode.NATURAL
+ shr_op.rounding_mode = RoundingMode.HalfUp
rescaled_exp = add_op_get_ofm(shr_op)
# PASS 3 - Reduce sum
@@ -443,7 +443,7 @@ class SoftMax:
# PASS 30 - SHR
shr30_op = Operation(Op.SHR, f"{self.op.name}_shr{pass_number}")
- shr30_op.rounding_mode = NpuRoundingMode.NATURAL
+ shr30_op.rounding_mode = RoundingMode.HalfUp
shr30_op.add_input_tensor(scaled_exp)
shr30_op.add_input_tensor(right_shift)
shr30_op.set_output_tensor(ofm)
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 8f68585..1e4ea11 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -215,7 +215,6 @@ class QuantizationParameters:
"max",
"num_bits",
"narrow_range",
- "next_after",
"scale_f32",
"zero_point",
"quant_min",
@@ -238,10 +237,6 @@ class QuantizationParameters:
self.num_bits = num_bits
self.narrow_range = narrow_range
- # Use the 'next after' float value of scale_f32 when converting to scale and shift. It can be combined with
- # natural rounding to perform rounding away from zero. This only affects the ofm scale and bias tensor, it has
- # no affect on global scaling i.e. the ofm_scale register
- self.next_after = False
self.scale_f32: Union[float, np.ndarray, None] = scale_f32
self.zero_point: Union[int, np.ndarray, None] = zero_point
self.quant_min: Optional[float] = None
@@ -251,7 +246,7 @@ class QuantizationParameters:
def __str__(self):
return (
f"<nng.QuantizationParameters min={self.min}, max={self.max}, num_bits={self.num_bits}, "
- f"scale={self.scale_f32}, zero_point={self.zero_point}, next={self.next_after}>"
+ f"scale={self.scale_f32}, zero_point={self.zero_point}>"
)
__repr__ = __str__
@@ -264,7 +259,6 @@ class QuantizationParameters:
res.num_bits = self.num_bits
res.narrow_range = self.narrow_range
- res.next_after = self.next_after
res.scale_f32 = self.scale_f32
res.zero_point = self.zero_point
res.quant_min = self.quant_min
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index f68e0cf..daaca8d 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -27,7 +27,6 @@ import numpy as np
from . import fp_math
from . import rewrite_graph
from . import scaling
-from .api import NpuRoundingMode
from .data_type import BaseType
from .data_type import DataType
from .debug_database import DebugDatabase
@@ -56,6 +55,7 @@ from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
from .operation import Padding
+from .operation import RoundingMode
from .operation_util import create_add_nop
from .operation_util import create_avgpool_nop
from .operation_util import create_cast_op
@@ -295,7 +295,7 @@ def convert_resize_1x1_to_add(op):
return op
-# Convert ResizeNearestNeightbor with align corners to a depthwise convolution. The IFM will already have been upscaled
+# Convert ResizeNearestNeighbor with align corners to a depthwise convolution. The IFM will already have been upscaled
# apart from the final x2 scaling which will be done as part of this operation. The kernel contains a single coefficient
# to select the appropriate nearest neighbor value
def convert_resizenn_ac_to_depthwise_conv(op, upscale_factor):
@@ -314,7 +314,7 @@ def convert_resizenn_ac_to_depthwise_conv(op, upscale_factor):
"dilation": (1, 1, 1, 1),
}
- # change resizebilinear to depthwise
+ # change ResizeNearestNeighbor to Depthwise
op.type = Op.DepthwiseConv2DBias
op.attrs.update(dw_op_attrs)
op.set_input_tensor(ifm, 0) # ifm tensor index
@@ -695,12 +695,8 @@ def convert_resizebilinear_to_depthwise_convolutions(op, half_pixel_centers=True
dw_conv.write_shape = Shape4D(n, h, w, c)
dw_conv.write_offset = Shape4D(0, 0, 0, 0)
- # Set the output rounding mode. Resize bilinear requires rounding away from zero. Therefore, we need to
- # adjust the accumulated value by a "small" amount before applying natural rounding. The "small" amount
- # should be big enough to cause a x.5 to be rounded correctly but small enough not to cause smaller
- # values to be incorrectly rounded
- ofm.quantization.next_after = True
- dw_conv.rounding_mode = NpuRoundingMode.NATURAL
+ # Resize bilinear requires rounding away from zero
+ dw_conv.rounding_mode = RoundingMode.AwayZero
# Double height and width stride to write the output of each of the four depthwise convolutions below
# interleaved with each other when combined with OFM tile base offsets.
@@ -1730,12 +1726,30 @@ def replace_pad_by_hw_pad(op: Operation, arch, nng):
op.inputs = []
op.add_input_tensor(ifm)
op.add_input_tensor(weight_tens)
- # Add bias tensor, all biases set to 0
- op.inputs.append(None)
- fixup_bias_tensors(op, arch, nng, DataType.int32)
+
+ if op.ifm.dtype == DataType.uint8:
+ op.rounding_mode = RoundingMode.HalfUp
+
+ # Add bias tensor, all biases set to 0
+ op.inputs.append(None)
+ fixup_bias_tensors(op, arch, nng, DataType.int32)
+
+ else:
+ op.rounding_mode = RoundingMode.AwayZero
+
+ # The DepthwiseConv needs to be performed with the IFM zero point set appropriately so that the correct
+ # pad values are used. However, in order to use the rounding away from zero mode the zero point needs to
+ # have been removed so that the zero point is at zero. This is done by adding a kernel sized amount of
+ # the zero point as a bias. The datatype of the bias needs to be set to int32, even for an int16 IFM,
+ # because this will cause full precision scaling to be used (see weight compression). Finally, the OFM
+ # zero point will need forcing to zero (as it has already been removed)
+ nr_biases = op.inputs[1].shape[-1]
+ bias_values = [op.ifm.quantization.zero_point * k_h * k_w] * nr_biases
+ bias_tensor = create_const_tensor(op.name + "_bias", [nr_biases], DataType.int32, bias_values)
+ op.add_input_tensor(bias_tensor)
+
# Add other inputs
op.inputs.extend(other_inputs)
- op.rounding_mode = NpuRoundingMode.NATURAL
# Bypass the PAD operator
op.set_input_tensor(pad_op.ifm, 0)
@@ -1946,7 +1960,7 @@ def convert_mean_to_depthwise_conv(op, arch, nng):
# Set weight shape to [H,W,C,B]
weight_shape = [h, w, shape[3], shape[0]]
- op.rounding_mode = NpuRoundingMode.NATURAL
+ op.rounding_mode = RoundingMode.HalfUp
identity_quant = QuantizationParameters(scale_f32=1.0, zero_point=0)
op.forced_input_quantization = identity_quant
op.forced_output_quantization = identity_quant
@@ -2016,7 +2030,7 @@ def convert_mean_to_depthwise_conv(op, arch, nng):
mul_op.set_ifm_ofm_shapes()
# Reference using TFL rounding for the multiply
- mul_op.rounding_mode = NpuRoundingMode.TFL
+ mul_op.rounding_mode = RoundingMode.TFLite
# Need to use explicit scaling to get the wanted shift
mul_op.explicit_scaling = ExplicitScaling(False, [output_shift_vela], [1])
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py
index b347414..df6b575 100644
--- a/ethosu/vela/tosa_graph_optimiser.py
+++ b/ethosu/vela/tosa_graph_optimiser.py
@@ -19,7 +19,6 @@
import numpy as np
from . import rewrite_graph
-from .api import NpuRoundingMode
from .data_type import DataType
from .debug_database import DebugDatabase
from .graph_optimiser_util import bypass_memory_only_ops
@@ -32,6 +31,7 @@ from .graph_optimiser_util import set_tensor_equivalence
from .lut import convert_to_lut
from .operation import ExplicitScaling
from .operation import Op
+from .operation import RoundingMode
from .operation_util import create_add_nop
from .operation_util import create_avgpool_nop
from .operation_util import create_pad_nop
@@ -115,7 +115,7 @@ def calc_scaling_avgpool(op, arch, nng):
multiplier.append(numerator // kernel_wh)
shift.append(30 + k)
- op.rounding_mode = NpuRoundingMode.NATURAL
+ op.rounding_mode = RoundingMode.HalfUp
op.explicit_scaling = ExplicitScaling(False, shift, multiplier)
return op
@@ -399,9 +399,9 @@ def rewrite_rescale(op, arch, nng):
explicit_scaling = ExplicitScaling(per_channel, shift, multiplier)
if double_round and scale32:
- rounding_mode = NpuRoundingMode.TFL
+ rounding_mode = RoundingMode.TFLite
else:
- rounding_mode = NpuRoundingMode.NATURAL
+ rounding_mode = RoundingMode.HalfUp
if prev_op.type.is_depthwise_conv2d_op() or prev_op.type.is_conv2d_op() or prev_op.type == Op.FullyConnected:
assert len(multiplier) == len(shift) == len(prev_op.bias.values)
diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py
index e4779bf..50ae26c 100644
--- a/ethosu/vela/weight_compressor.py
+++ b/ethosu/vela/weight_compressor.py
@@ -32,6 +32,7 @@ from .errors import UnsupportedFeatureError
from .numeric_util import round_up
from .operation import NpuBlockType
from .operation import Op
+from .operation import RoundingMode
from .scaling import quantise_scale
from .scaling import reduced_quantise_scale
from .tensor import QuantizationParameters
@@ -303,8 +304,8 @@ def _prepare_scale_and_bias(arch, tens, explicit_scaling):
else:
quantised_scales = [quantise_scale(scale) for scale in scales]
- # Check the output quantisation to see if the scale value needs increasing to the next one
- if _get_output_quantization(first_consumer_op).next_after:
+ # Rounding away from zero requires the "next after" floating point value to be set on the output quantisation
+ if first_consumer_op.rounding_mode == RoundingMode.AwayZero:
for i, quant_scale in enumerate(quantised_scales):
q_scale, q_shift = quant_scale
quantised_scales[i] = (q_scale + 1, q_shift)