aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohan Alfvén <johan.alfven@arm.com>2022-09-26 13:46:51 +0200
committerFredrik Svedberg <fredrik.svedberg@arm.com>2022-09-27 12:18:43 +0000
commite84ed6b6663bb9158cd87d11cb21b48abed1033d (patch)
treef053cb4007aaebcd7760ae593482197aad45f792
parentfd0a338fa52844e33d922a8084819c2437ff16fa (diff)
downloadethos-u-vela-e84ed6b6663bb9158cd87d11cb21b48abed1033d.tar.gz
MLBEDSW-6962: MEAN height is greater than max kernel height
Fixed bug when height is greater than max kernel height. The shape of the weight must match the ifm shape. Signed-off-by: Johan Alfven <johan.alfven@arm.com> Change-Id: I901a8af2edd5858bb15d53d85ef8e2389049ada7
-rw-r--r--ethosu/vela/tflite_graph_optimiser.py11
1 files changed, 9 insertions, 2 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 1310ee63..028151fd 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -1752,9 +1752,15 @@ def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
op.set_ifm_ofm_shapes()
# If height is greater than max kernel height, reshape from HxW to 1x(HxW)
+ weight_shape = None
if (h > 64 and op.type == Op.DepthwiseConv2DBias) or (h > 256 and op.type == Op.AvgPool):
+ # This can only happen and be done for multiple axes, and
+ # h * w <= 256 for DepthwiseConv2DBias
+ # h * w <= 4096 for AvgPool
+ # which is checked in supported ops
shape = [shape[0], 1, h * w, shape[3]]
op.ifm_shapes[0] = Shape4D(shape)
+ weight_shape = [1, h * w, shape[3], shape[0]]
if h > 256 and op.type == Op.AvgPool:
op.attrs.update({"ksize": (1, 1, h * w, 1), "filter_height": 1, "filter_width": h * w})
@@ -1769,8 +1775,9 @@ def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
weight_quant.scale_f32 = weight_scale
weight_quant.zero_point = 0
- # Set weight shape to [H,W,C,B]
- weight_shape = [h, w, shape[3], shape[0]]
+ if weight_shape is None:
+ # Set weight shape to [H,W,C,B]
+ weight_shape = [h, w, shape[3], shape[0]]
# Add unit weight tensor
op.set_input_tensor(