author     Johan Alfven <johan.alfven@arm.com>    2023-04-13 18:54:47 +0200
committer  Johan Alfven <johan.alfven@arm.com>    2023-04-19 12:26:19 +0200
commit     7b3008a905d2a5122e21f945db7d2a2132473c53 (patch)
tree       824b7a1b6b4ed4e5e382901e60331ff8e2159d5f
parent     0ac0804e76e098695ee2b8a9e24e2f0a1efc324f (diff)
MLBEDSW-7487: Updated implementation for the Mean op
- Latest reference has changed implementation for the Mean op and
  now only contains one variant.
- Updated Vela implementation to match reference. The full sum is
  first calculated and then divided by the number of elements.
- Removed the avg pool variant and test case.
- Updated SUPPORTED_OPS.md

Change-Id: I4275e36e3697fa837f119f2cefd7c0ff94231605
Signed-off-by: Johan Alfven <johan.alfven@arm.com>
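A minimal NumPy sketch of the two-stage lowering described above: a full sum via a unit-weight depthwise convolution with a -zero_point * h * w bias, followed by a single multiply that folds 1/(h * w) and the IFM-to-OFM rescale. This is an illustration only, not Vela code; the shapes, scales and zero points are made-up values, and the multiply is done in floating point here where the NPU uses a quantized multiplier plus shift (see the diff below).

```python
import numpy as np

ifm = np.random.randint(0, 256, size=(1, 4, 4, 8), dtype=np.uint8)  # example uint8 IFM
ifm_scale, ifm_zp = 0.5, 128   # assumed IFM quantization
ofm_scale, ofm_zp = 0.25, 128  # assumed OFM quantization
h, w = ifm.shape[1], ifm.shape[2]

# Stage 1: depthwise conv with unit weights == full sum over H and W; the bias
# -ifm_zp * h * w removes the input zero point from the accumulated sum.
acc = ifm.astype(np.int64).sum(axis=(1, 2)) + (-ifm_zp * h * w)

# Stage 2: one multiply that applies both 1/(h * w) and ifm_scale/ofm_scale,
# then re-adds the output zero point.
mean = np.round(acc * (ifm_scale / ofm_scale) / (h * w)) + ofm_zp
ofm = np.clip(mean, 0, 255).astype(np.uint8)

# Cross-check against a plain floating-point mean of the offset-corrected input
ref = np.round((ifm.astype(np.float64) - ifm_zp).mean(axis=(1, 2)) * ifm_scale / ofm_scale) + ofm_zp
assert np.array_equal(ofm, np.clip(ref, 0, 255).astype(np.uint8))
```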
-rw-r--r--   SUPPORTED_OPS.md                                       10
-rw-r--r--   ethosu/vela/test/test_tflite_supported_operators.py     7
-rw-r--r--   ethosu/vela/tflite_graph_optimiser.py                   92
-rw-r--r--   ethosu/vela/tflite_supported_operators.py               31
4 files changed, 65 insertions, 75 deletions
diff --git a/SUPPORTED_OPS.md b/SUPPORTED_OPS.md
index f641d3f2..a870c5aa 100644
--- a/SUPPORTED_OPS.md
+++ b/SUPPORTED_OPS.md
@@ -1,7 +1,7 @@
# Supported Ops
This file was automatically generated by Vela using the `--supported-ops-report` parameter.
-Vela version: `3.7.1.dev15+g2b5f66e`
+Vela version: `3.7.1.dev16+g1f9a4df.d20230417`
This file complies with
[**Gitiles Markdown syntax**](https://github.com/google/gitiles/blob/master/Documentation/markdown.md)
@@ -221,13 +221,9 @@ This is a list of constraints that the MEAN operator must satisfy in order to be
- IFM must be int8 or uint8
- Input tensor must be at least 2D
- Axis indices must correspond to height and width axes
-- Product of height and width must be no greater than 65536
-- Product of height and width must be no greater than 4096 when:
-        - IFM and OFM have different scale or zero point; or
-        - 'keep_dims' is True
+- Product of height and width must be no greater than 4096
 - For single axis averages across the height dimension:
-        - IFM height must be no greater than 256 if the IFM and OFM scale and zero point match; otherwise
-        - IFM height must be no greater than 64 if the IFM and OFM scale or zero point do not match
+        - IFM height must be no greater than 64
### TFLite MINIMUM Constraints
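As a worked example of the unified limit: an IFM of shape [1, 64, 64, 16] reduced over axes [1, 2] gives h * w = 64 * 64 = 4096 and is accepted, while the [1, 200, 200, 16] shape used by the removed avg-pool test gives 40000 and is now rejected regardless of quantization or keep_dims.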
diff --git a/ethosu/vela/test/test_tflite_supported_operators.py b/ethosu/vela/test/test_tflite_supported_operators.py
index 04f10e9a..74dd3bf2 100644
--- a/ethosu/vela/test/test_tflite_supported_operators.py
+++ b/ethosu/vela/test/test_tflite_supported_operators.py
@@ -618,13 +618,6 @@ def test_mean_hw_product():
    assert not support.is_operator_supported(op)
-def test_mean_hw_product_avgpool():
-    op = create_mean([1, 200, 200, 16], [1, 16], [1, 2], DataType.uint8, {"keep_dims": False})
-    assert support.is_operator_supported(op)
-    op = create_mean([1, 200, 200, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
-    assert not support.is_operator_supported(op)
-
-
def test_lstm_support():
    # Test valid configuration
    op = testutil.create_lstm_op(3, 12, 24, 20, DataType.int8)
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 478d0189..393a8323 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -45,6 +45,7 @@ from .lstm import Lstm
from .numeric_util import clamp_sigmoid
from .numeric_util import full_shape
from .numeric_util import round_away_zero
+from .numeric_util import round_down_log2
from .operation import create_activation_function
from .operation import ExplicitScaling
from .operation import NpuBlockType
@@ -1827,22 +1828,7 @@ def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
        # Set IFM/OFM shapes after changing op type
        op.set_ifm_ofm_shapes()
-        weight_scale, bias = 1, 0
        ofmq, ifmq = op.ofm.quantization, inp.quantization
-        if ifmq.is_scaling_equal(ofmq):
-            # Here we can just use a simple AvgPool with truncating rounding,
-            # as we're emulating simple integer division.
-            op.rounding_mode = NpuRoundingMode.TRUNCATE
-            op.type = Op.AvgPool
-            op.attrs.update({"ksize": (1, h, w, 1), "filter_height": h, "filter_width": w})
-        else:
-            op.rounding_mode = NpuRoundingMode.NATURAL
-            weight_scale = 1 / (h * w)
-            # Input zero point is adjusted after mean calculation, so we emulate that with a bias
-            bias = -ifmq.zero_point * h * w
-            fiq = ifmq.clone()
-            fiq.zero_point = 0
-            op.forced_input_quantization = fiq
        # Change dimensions to 4
        def extend_dims(dim, in_shape):
@@ -1867,28 +1853,18 @@ def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
        # If height is greater than max kernel height, reshape from HxW to 1x(HxW)
        weight_shape = None
-        if (h > 64 and op.type == Op.DepthwiseConv2DBias) or (h > 256 and op.type == Op.AvgPool):
+        if h > 64:
            # This can only happen and be done for multiple axes, and
-            # h * w <= 256 for DepthwiseConv2DBias
-            # h * w <= 4096 for AvgPool
+            # h * w <= 4096 for DepthwiseConv2DBias
            # which is checked in supported ops
            shape = [shape[0], 1, h * w, shape[3]]
            op.ifm_shapes[0] = Shape4D(shape)
            weight_shape = [1, h * w, shape[3], shape[0]]
-            if h > 256 and op.type == Op.AvgPool:
-                op.attrs.update({"ksize": (1, 1, h * w, 1), "filter_height": 1, "filter_width": h * w})
-        # If the AvgPool version is used, we don't need to do anything else
-        if op.type == Op.AvgPool:
-            DebugDatabase.add_optimised(op, op)
-            return op
-
-        # Make unit weight tensor quantization
-        weight_quant = ifmq.clone()
-        weight_quant.min = 0
-        weight_quant.max = 255
-        weight_quant.scale_f32 = weight_scale
-        weight_quant.zero_point = 0
+        op.rounding_mode = NpuRoundingMode.NATURAL
+        identity_quant = QuantizationParameters(scale_f32=1.0, zero_point=0)
+        op.forced_input_quantization = identity_quant
+        op.forced_output_quantization = identity_quant
        if weight_shape is None:
            # Set weight shape to [H,W,C,B]
@@ -1901,17 +1877,65 @@ def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
                weight_shape,
                inp.dtype,
                np.ones(weight_shape),
-                quantization=weight_quant,
+                quantization=identity_quant,
            ),
            1,
        )
        op.weights.values = np.reshape(op.inputs[1].values, weight_shape)
-        # Add bias tensor
+        # Input zero point is adjusted after the sum calculation, so we emulate that with a bias
+        bias = -ifmq.zero_point * h * w
        bias_shape = [shape[-1]]
-        op.inputs.append(create_const_tensor("bias", bias_shape, DataType.int32, np.ones(bias_shape) * bias))
+        op.inputs.append(create_const_tensor(op.name + "_bias", bias_shape, DataType.int32, np.ones(bias_shape) * bias))
        DebugDatabase.add_optimised(op, op)
+        # Multiply sum with 1/num_elements_in_axis to get the mean
+        intermediate = op.ofm.clone(suffix="_intermediate", set_unique=True)
+        intermediate.dtype = DataType.int32
+        mul_op = Operation(Op.Mul, op.name + "_mul")
+        mul_op.add_input_tensor(intermediate)
+        mul_op.set_output_tensor(op.ofm)
+        mul_op.forced_input_quantization = identity_quant
+
+        # The multiplier is calculated in the same way as in the reference,
+        # clamping the shift value at the price of some precision loss.
+        num_elements_in_axis = int(h * w)
+        output_multiplier, output_shift_vela = quantise_scale(np.double(ifmq.scale_f32) / np.double(ofmq.scale_f32))
+
+        # Convert to reference representation shift value
+        output_shift = 31 - output_shift_vela
+
+        # Reference calculation
+        # round_down_log2 same as 63 - CountLeadingZeros(num_elements_in_axis)
+        shift = round_down_log2(num_elements_in_axis)
+        shift = min(shift, 32)
+        shift = min(shift, 31 + output_shift)
+        output_multiplier = (output_multiplier << shift) // num_elements_in_axis
+        output_shift = output_shift - shift
+
+        # Convert to vela representation shift
+        output_shift_vela = 31 - output_shift
+
+        # For int32 scaling is not supported so instead multiply with the scale
+        # intermediate * scale -> round and shift.
+        scalar = create_const_tensor(
+            op.name + "_scalar", [1, 1, 1, 1], DataType.int32, [output_multiplier], quantization=identity_quant
+        )
+        mul_op.add_input_tensor(scalar)
+        mul_op.set_ifm_ofm_shapes()
+
+        # Reference using TFL rounding for the multiply
+        mul_op.rounding_mode = NpuRoundingMode.TFL
+
+        # Need to use explicit scaling to get the wanted shift
+        mul_op.explicit_scaling = ExplicitScaling(False, [output_shift_vela], [1])
+
+        mul_op.activation = op.activation
+        op.activation = None
+        op.set_output_tensor(intermediate)
+        op.set_ifm_ofm_shapes()
+        DebugDatabase.add_optimised(op, mul_op)
+
    return op
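The folding of 1/num_elements_in_axis into the quantized multiplier can be checked in isolation. The sketch below is standalone and illustrative only: the quantise_scale and round_down_log2 stand-ins are frexp/log2-based approximations assumed to mirror the Vela helpers, and the scales are made-up example values. It verifies that the folded (multiplier, shift) pair represents (ifm_scale / ofm_scale) / (h * w).

```python
import math

def quantise_scale(scale):
    # Q31 significand plus a "Vela" shift, i.e. scale ~= multiplier / 2**shift
    significand, exponent = math.frexp(scale)
    return int(round(significand * (1 << 31))), 31 - exponent

def round_down_log2(x):
    # Largest n such that 2**n <= x
    return int(math.floor(math.log2(x)))

ifm_scale, ofm_scale = 0.023528, 0.016817  # assumed example quantization scales
h, w = 50, 40
num_elements_in_axis = h * w

output_multiplier, output_shift_vela = quantise_scale(ifm_scale / ofm_scale)
output_shift = 31 - output_shift_vela  # convert to the reference representation

# Fold 1/num_elements_in_axis into the multiplier, clamping the shift
shift = round_down_log2(num_elements_in_axis)
shift = min(shift, 32)
shift = min(shift, 31 + output_shift)
output_multiplier = (output_multiplier << shift) // num_elements_in_axis
output_shift = output_shift - shift

# The folded pair now encodes (ifm_scale / ofm_scale) / (h * w)
folded = output_multiplier * 2.0 ** (output_shift - 31)
print(folded, (ifm_scale / ofm_scale) / num_elements_in_axis)  # approximately equal
```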
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 457c35eb..9ace3219 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -191,7 +191,6 @@ class TFLiteSupportedOperators:
    filter_height_range = (1, 256)
    filter_product_range = (1, 256 * 256)
    mean_kernel_product = 64 * 64
-    mean_kernel_product_avgpool = 256 * 256
    def __init__(self):
        # Setup the generic constraints. Note: the order matters
@@ -309,7 +308,6 @@ class TFLiteSupportedOperators:
        self.specific_constraints[Op.Pad].append(TFLiteSupportedOperators.constraint_pad_type)
        # Mean specific checks:
-        self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_width_product_avgpool)
        self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_width_product)
        self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_single_axis)
@@ -809,26 +807,9 @@ class TFLiteSupportedOperators:
return valid, f"Op has ifm_shape={ifm_shape} and ifm2_shape={ifm2_shape}"
@classmethod
- @docstring_format_args([mean_kernel_product_avgpool])
- def constraint_mean_height_width_product_avgpool(cls, op):
- """Product of height and width must be no greater than {}"""
- shape = op.inputs[0].shape
- hi = 0 if len(shape) < 4 else 1
- h, w = shape[hi : hi + 2]
- max_prod = cls.mean_kernel_product_avgpool
- return h * w <= max_prod, f"Product of height and width is {h * w}"
-
- @classmethod
@docstring_format_args([mean_kernel_product])
def constraint_mean_height_width_product(cls, op):
- """Product of height and width must be no greater than {} when:
- IFM and OFM have different scale or zero point; or
- 'keep_dims' is True"""
- ifmq, ofmq = op.ifm.quantization, op.ofm.quantization
- keep_dims = op.attrs.get("keep_dims")
- # doesn't apply, size is checked by constraint_mean_height_width_product_avgpool
- if not keep_dims and ifmq.scale_f32 == ofmq.scale_f32 and ifmq.zero_point == ofmq.zero_point:
- return True, ""
+ """Product of height and width must be no greater than {}"""
shape = op.inputs[0].shape
hi = 0 if len(shape) < 4 else 1
h, w = shape[hi : hi + 2]
@@ -836,11 +817,10 @@ class TFLiteSupportedOperators:
        return h * w <= max_prod, f"Product of height and width is {h * w}"
    @classmethod
-    @docstring_format_args([filter_height_range[1], dilated_height_range[1]])
+    @docstring_format_args([dilated_height_range[1]])
    def constraint_mean_height_single_axis(cls, op):
        """For single axis averages across the height dimension:
-        IFM height must be no greater than {} if the IFM and OFM scale and zero point match; otherwise
-        IFM height must be no greater than {} if the IFM and OFM scale or zero point do not match"""
+        IFM height must be no greater than {}"""
        inp, axis = op.inputs
        if axis.shape == [] or axis.shape[0] == 1:  # single axis
            axis = int(axis.values) if len(axis.shape) == 0 else int(axis.values[0])
@@ -859,10 +839,7 @@ class TFLiteSupportedOperators:
        h = shape[axis]
        ifm, ofm = op.get_ifm_ofm()
-        if check_quantized_tens_scaling_equal(ifm, ofm):
-            return h <= cls.filter_height_range[1], f"Height is {h}, IFM and OFM quantizations match"
-        else:
-            return h <= cls.dilated_height_range[1], f"Height is {h}, IFM and OFM quantizations do not match"
+        return h <= cls.dilated_height_range[1], f"Height is {h}"
    @staticmethod
    def constraint_reshape_shape_constant(op):