diff options
author | Johan Alfvén <johan.alfven@arm.com> | 2022-09-26 13:46:51 +0200 |
---|---|---|
committer | Fredrik Svedberg <fredrik.svedberg@arm.com> | 2022-09-27 12:18:43 +0000 |
commit | e84ed6b6663bb9158cd87d11cb21b48abed1033d (patch) | |
tree | f053cb4007aaebcd7760ae593482197aad45f792 /ethosu | |
parent | fd0a338fa52844e33d922a8084819c2437ff16fa (diff) | |
download | ethos-u-vela-e84ed6b6663bb9158cd87d11cb21b48abed1033d.tar.gz |
MLBEDSW-6962: MEAN height is greater than max kernel height
Fixed a bug that occurred when the height is greater than the max kernel
height. In that case the IFM is reshaped from HxW to 1x(HxW), and the shape
of the weight must match the reshaped IFM shape.
Signed-off-by: Johan Alfven <johan.alfven@arm.com>
Change-Id: I901a8af2edd5858bb15d53d85ef8e2389049ada7
Diffstat (limited to 'ethosu')
-rw-r--r-- | ethosu/vela/tflite_graph_optimiser.py | 11 |
1 file changed, 9 insertions, 2 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py index 1310ee63..028151fd 100644 --- a/ethosu/vela/tflite_graph_optimiser.py +++ b/ethosu/vela/tflite_graph_optimiser.py @@ -1752,9 +1752,15 @@ def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng): op.set_ifm_ofm_shapes() # If height is greater than max kernel height, reshape from HxW to 1x(HxW) + weight_shape = None if (h > 64 and op.type == Op.DepthwiseConv2DBias) or (h > 256 and op.type == Op.AvgPool): + # This can only happen and be done for multiple axes, and + # h * w <= 256 for DepthwiseConv2DBias + # h * w <= 4096 for AvgPool + # which is checked in supported ops shape = [shape[0], 1, h * w, shape[3]] op.ifm_shapes[0] = Shape4D(shape) + weight_shape = [1, h * w, shape[3], shape[0]] if h > 256 and op.type == Op.AvgPool: op.attrs.update({"ksize": (1, 1, h * w, 1), "filter_height": 1, "filter_width": h * w}) @@ -1769,8 +1775,9 @@ def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng): weight_quant.scale_f32 = weight_scale weight_quant.zero_point = 0 - # Set weight shape to [H,W,C,B] - weight_shape = [h, w, shape[3], shape[0]] + if weight_shape is None: + # Set weight shape to [H,W,C,B] + weight_shape = [h, w, shape[3], shape[0]] # Add unit weight tensor op.set_input_tensor( |