author     Dwight Lidman <dwight.lidman@arm.com>  2021-03-26 10:53:28 +0100
committer  patrik.gustavsson <patrik.gustavsson@arm.com>  2021-04-07 10:50:39 +0000
commit     95b279f1454d58a93238851cb5ff394c7782ad32 (patch)
tree       eb2e8f4db229f0581a894084c75f44b55877a05b
parent     fe368bc231fb680ebfa48e2c35e92dec5639df5e (diff)
download   ethos-u-vela-95b279f1454d58a93238851cb5ff394c7782ad32.tar.gz
MEAN implementation changed to Average Pool
This is a small commit which changes one of the four MEAN implementations to a simpler one, using an AvgPool instead of a DepthwiseConv.

Signed-off-by: Dwight Lidman <dwight.lidman@arm.com>
Change-Id: I9e8af071e8b820796577ee4792b4812a1212602b
-rw-r--r--  SUPPORTED_OPS.md                               4
-rw-r--r--  ethosu/vela/graph_optimiser.py                25
-rw-r--r--  ethosu/vela/supported_operators.py            23
-rw-r--r--  ethosu/vela/test/test_supported_operators.py   9
4 files changed, 47 insertions(+), 14 deletions(-)
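For context, the equivalence this patch relies on: when the IFM and OFM share the same scale and zero point, MEAN over the spatial axes is just an integer average, which an AvgPool covering the full H x W window with truncating rounding (NpuRoundingMode.TRUNCATE in the diff below) computes directly. A minimal NumPy sketch of that arithmetic, illustrative only and not Vela code:

    import numpy as np

    # Quantized IFM of shape (N, H, W, C); identical IFM/OFM quantization assumed.
    ifm = np.random.randint(0, 256, size=(1, 6, 6, 3)).astype(np.int64)
    h, w = ifm.shape[1], ifm.shape[2]

    # What the AvgPool path computes: sum over the full H x W window,
    # then integer division (truncating rounding).
    avgpool_result = ifm.sum(axis=(1, 2)) // (h * w)

    # Reference: floating-point mean over H and W, truncated towards zero.
    reference = np.trunc(ifm.astype(np.float64).mean(axis=(1, 2))).astype(np.int64)

    assert np.array_equal(avgpool_result, reference)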
diff --git a/SUPPORTED_OPS.md b/SUPPORTED_OPS.md
index 1ad65c63..013cad27 100644
--- a/SUPPORTED_OPS.md
+++ b/SUPPORTED_OPS.md
@@ -199,7 +199,9 @@ This is a list of constraints that the MEAN operator must satisfy in order to be
- IFM must be int8 or uint8
- Input tensor must be at least 2D
- Axis indices must correspond to height and width axes
-- Product of height and width can be at most 4096
+- Product of height and width can be at most 65536
+- Product of height and width can be at most 4096 when IFM and OFM have different scale or zero point,
+ or keep_dims is True
- Product of IFM height and width can be at most 256 when the following are true:
IFM dimensions are 4,
Axis indices are 1 and 2,
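Taken together, the updated constraints say: every MEAN must satisfy H * W <= 65536 (the new AvgPool limit), and the original 4096 limit still applies whenever the op cannot take the AvgPool path, i.e. when the IFM and OFM quantization differ or keep_dims is set. A rough standalone sketch of that decision, ignoring the separate int8 case; this is a hypothetical helper, not the Vela constraint code:

    def mean_hw_product_ok(h, w, same_quantization, keep_dims):
        # Hypothetical illustration of the combined limits from this patch.
        product = h * w
        if product > 256 * 256:            # mean_kernel_product_avgpool
            return False                   # too large even for the AvgPool path
        if same_quantization and not keep_dims:
            return True                    # AvgPool path: the 65536 limit is enough
        return product <= 64 * 64          # DepthwiseConv path: original 4096 limit

    # e.g. a 200x200 reduction is fine on the AvgPool path but not otherwise:
    assert mean_hw_product_ok(200, 200, same_quantization=True, keep_dims=False)
    assert not mean_hw_product_ok(200, 200, same_quantization=False, keep_dims=False)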
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index bea22a23..56932dbe 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -1382,7 +1382,7 @@ def fixup_bias_tensors(op, arch, nng):
return op
-def convert_mean_to_depthwise_conv(op, arch, nng):
+def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
if op.type == Op.Mean and op.run_on_npu:
keep_dims = op.attrs.get("keep_dims", False)
inp, axis = op.inputs
@@ -1422,8 +1422,6 @@ def convert_mean_to_depthwise_conv(op, arch, nng):
)
# Change op type
op.type = Op.DepthwiseConv2DBias
- # Add None bias tensor
- op.inputs.append(None)
# Set IFM/OFM shapes after changing op type
op.set_ifm_ofm_shapes()
@@ -1509,14 +1507,11 @@ def convert_mean_to_depthwise_conv(op, arch, nng):
op.set_output_tensor(intermediate)
op.set_ifm_ofm_shapes()
elif ifmq.zero_point == ofmq.zero_point and ifmq.scale_f32 == ofmq.scale_f32:
+ # Here we can just use a simple AvgPool with truncating rounding,
+ # as we're emulating simple integer division.
op.rounding_mode = NpuRoundingMode.TRUNCATE
- weight_scale = 1 / (h * w)
- foq = ofmq.clone()
- foq.zero_point = 0
- op.forced_output_quantization = foq
- fiq = ifmq.clone()
- fiq.zero_point = 0
- op.forced_input_quantization = fiq
+ op.type = Op.AvgPool
+ op.attrs.update({"ksize": (1, h, w, 1), "filter_height": h, "filter_width": w})
else:
op.rounding_mode = NpuRoundingMode.NATURAL
weight_scale = 1 / (h * w)
@@ -1537,6 +1532,12 @@ def convert_mean_to_depthwise_conv(op, arch, nng):
shape = [shape[0], 1, h * w, shape[3]]
op.ifm_shapes[0] = Shape4D(shape)
inp.avoid_NHCWB16 = True
+ if h > 256 and op.type == Op.AvgPool:
+ op.attrs.update({"ksize": (1, 1, h * w, 1), "filter_height": 1, "filter_width": h * w})
+
+ # If the AvgPool version is used, we don't need to do anything else
+ if op.type == Op.AvgPool:
+ return op
# Make unit weight tensor quantization
weight_quant = ifmq.clone()
@@ -1561,6 +1562,8 @@ def convert_mean_to_depthwise_conv(op, arch, nng):
)
op.weights.quant_values = np.reshape(op.inputs[1].quant_values, weight_shape)
+ # Add None bias tensor
+ op.inputs.append(None)
# Add bias tensor
if bias:
bias_shape = [shape[-1]]
@@ -1643,7 +1646,7 @@ def optimise_graph_a(nng, arch, verbose_graph=False):
op_rewrite_list = [
set_tensor_equivalence,
- convert_mean_to_depthwise_conv,
+ convert_mean_to_depthwise_conv_or_avgpool,
convert_depthwise_to_conv,
convert_conv_to_fc,
convert_softmax,
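The h > 256 case added to convert_mean_to_depthwise_conv_or_avgpool above relies on the fact that averaging over the full H x W window is unaffected by flattening that window: reshaping the IFM to [N, 1, H*W, C] and pooling with a 1 x (H*W) kernel produces the same sums, which keeps the kernel height within range. A small NumPy check of that equivalence, illustrative only:

    import numpy as np

    ifm = np.random.randint(0, 256, size=(1, 300, 4, 8)).astype(np.int64)  # h = 300 > 256
    n, h, w, c = ifm.shape

    # Pooling over the full h x w window...
    pooled_hw = ifm.sum(axis=(1, 2)) // (h * w)

    # ...equals pooling over a 1 x (h * w) window after flattening the spatial axes,
    # which is what the ksize/filter_height/filter_width rewrite above expresses.
    flat = ifm.reshape(n, 1, h * w, c)
    pooled_flat = flat.sum(axis=(1, 2)) // (h * w)

    assert np.array_equal(pooled_hw, pooled_flat)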
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 777e9c70..5bf2c459 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -122,6 +122,7 @@ class SupportedOperators:
filter_product_range = (1, 256 * 256)
mean_kernel_product = 64 * 64
mean_kernel_product_int8 = 16 * 16
+ mean_kernel_product_avgpool = 256 * 256
# Supported consumers
supported_pad_consumers = convolution_ops | depthwise_convolution_ops | pooling_ops
@@ -272,6 +273,7 @@ class SupportedOperators:
self.specific_constraints[Op.Mean].append(SupportedOperators.constraint_input_8bit)
self.specific_constraints[Op.Mean].append(SupportedOperators.constraint_mean_input_dims)
self.specific_constraints[Op.Mean].append(SupportedOperators.constraint_mean_axis)
+ self.specific_constraints[Op.Mean].append(SupportedOperators.constraint_mean_height_width_product_avgpool)
self.specific_constraints[Op.Mean].append(SupportedOperators.constraint_mean_height_width_product)
self.specific_constraints[Op.Mean].append(SupportedOperators.constraint_mean_height_width_product_int8)
@@ -1028,6 +1030,7 @@ class SupportedOperators:
valid = len(op.ifm.shape) == len(op.ofm.shape)
return valid, f"Op has ifm shape={op.ifm.shape} and ofm shape={op.ofm.shape}"
+ @staticmethod
def constraint_mean_input_dims(op):
"Input tensor must be at least 2D"
dims = len(op.inputs[0].shape)
@@ -1045,9 +1048,25 @@ class SupportedOperators:
return valid, f"Axis is {axis}"
@classmethod
+ @docstring_format_args([mean_kernel_product_avgpool])
+ def constraint_mean_height_width_product_avgpool(cls, op):
+ """Product of height and width can be at most {}"""
+ shape = op.inputs[0].shape
+ hi = 0 if len(shape) < 4 else 1
+ h, w = shape[hi : hi + 2]
+ max_prod = cls.mean_kernel_product_avgpool
+ return h * w <= max_prod, f"Product of height and width is {h * w}"
+
+ @classmethod
@docstring_format_args([mean_kernel_product])
def constraint_mean_height_width_product(cls, op):
- "Product of height and width can be at most {}"
+ """Product of height and width can be at most {} when IFM and OFM have different scale or zero point,
+ or keep_dims is True"""
+ ifmq, ofmq = op.ifm.quantization, op.ofm.quantization
+ keep_dims = op.attrs.get("keep_dims")
+ # doesn't apply, size is checked by constraint_mean_height_width_product_avgpool
+ if not keep_dims and ifmq.scale_f32 == ofmq.scale_f32 and ifmq.zero_point == ofmq.zero_point:
+ return True, ""
shape = op.inputs[0].shape
hi = 0 if len(shape) < 4 else 1
h, w = shape[hi : hi + 2]
@@ -1064,6 +1083,8 @@ class SupportedOperators:
IFM datatype is int8"""
shape = op.ifm.shape
axis = op.inputs[1].values if op.inputs[1].shape == [] else list(op.inputs[1].values)
+ # doesn't apply, size is checked by constraint_mean_height_width_product_avgpool
+ # and constraint_mean_height_width_product
if (
len(shape) != 4
or op.ifm.dtype != DataType.int8
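A note on the docstring mechanics visible above: the constraint docstrings double as the user-facing text in SUPPORTED_OPS.md and in unsupported-op reports, and docstring_format_args fills the {} placeholders from the class-level limits when the class is defined (which is how "at most 65536" appears in the documentation hunk). A simplified sketch of what such a decorator amounts to; the real helper is Vela's own, this is only an illustration:

    def docstring_format_args(args):
        # Simplified illustration: substitute class-level limits (e.g.
        # mean_kernel_product_avgpool = 256 * 256) into the {} placeholders
        # of a constraint's docstring, so reports show concrete numbers.
        def decorator(fn):
            fn.__doc__ = fn.__doc__.format(*args)
            return fn
        return decorator

    @docstring_format_args([256 * 256])
    def constraint_example(op):
        """Product of height and width can be at most {}"""
        return True, ""

    assert "65536" in constraint_example.__doc__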
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py
index aad2849a..355b472c 100644
--- a/ethosu/vela/test/test_supported_operators.py
+++ b/ethosu/vela/test/test_supported_operators.py
@@ -864,7 +864,7 @@ def test_mean_axis():
def test_mean_hw_product():
- op = create_mean([1, 64, 64, 16], [1, 1, 16], [1, 2], DataType.uint8, {})
+ op = create_mean([1, 64, 64, 16], [1, 16], [1, 2], DataType.uint8, {})
assert support.is_operator_supported(op)
op = create_mean([1, 65, 64, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
assert not support.is_operator_supported(op)
@@ -875,3 +875,10 @@ def test_mean_hw_product_int8():
assert support.is_operator_supported(op)
op = create_mean([1, 16, 17, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
assert not support.is_operator_supported(op)
+
+
+def test_mean_hw_product_avgpool():
+ op = create_mean([1, 200, 200, 16], [1, 16], [1, 2], DataType.uint8, {"keep_dims": False})
+ assert support.is_operator_supported(op)
+ op = create_mean([1, 200, 200, 16], [1, 1, 1, 16], [1, 2], DataType.int8, {"keep_dims": True})
+ assert not support.is_operator_supported(op)