From e84ed6b6663bb9158cd87d11cb21b48abed1033d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johan=20Alfv=C3=A9n?= Date: Mon, 26 Sep 2022 13:46:51 +0200 Subject: MLBEDSW-6962: MEAN height is greater than max kernel height Fixed bug when height is greater than max kernel height. The shape of the weight must match the ifm shape. Signed-off-by: Johan Alfven Change-Id: I901a8af2edd5858bb15d53d85ef8e2389049ada7 --- ethosu/vela/tflite_graph_optimiser.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) (limited to 'ethosu') diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py index 1310ee63..028151fd 100644 --- a/ethosu/vela/tflite_graph_optimiser.py +++ b/ethosu/vela/tflite_graph_optimiser.py @@ -1752,9 +1752,15 @@ def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng): op.set_ifm_ofm_shapes() # If height is greater than max kernel height, reshape from HxW to 1x(HxW) + weight_shape = None if (h > 64 and op.type == Op.DepthwiseConv2DBias) or (h > 256 and op.type == Op.AvgPool): + # This can only happen and be done for multiple axes, and + # h * w <= 256 for DepthwiseConv2DBias + # h * w <= 4096 for AvgPool + # which is checked in supported ops shape = [shape[0], 1, h * w, shape[3]] op.ifm_shapes[0] = Shape4D(shape) + weight_shape = [1, h * w, shape[3], shape[0]] if h > 256 and op.type == Op.AvgPool: op.attrs.update({"ksize": (1, 1, h * w, 1), "filter_height": 1, "filter_width": h * w}) @@ -1769,8 +1775,9 @@ def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng): weight_quant.scale_f32 = weight_scale weight_quant.zero_point = 0 - # Set weight shape to [H,W,C,B] - weight_shape = [h, w, shape[3], shape[0]] + if weight_shape is None: + # Set weight shape to [H,W,C,B] + weight_shape = [h, w, shape[3], shape[0]] # Add unit weight tensor op.set_input_tensor( -- cgit v1.2.1