aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_graph_optimiser.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r--ethosu/vela/tflite_graph_optimiser.py11
1 files changed, 9 insertions, 2 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 1310ee63..028151fd 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -1752,9 +1752,15 @@ def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
op.set_ifm_ofm_shapes()
# If height is greater than max kernel height, reshape from HxW to 1x(HxW)
+ weight_shape = None
if (h > 64 and op.type == Op.DepthwiseConv2DBias) or (h > 256 and op.type == Op.AvgPool):
+ # This reshape only occurs, and is only valid, when reducing over multiple axes, with
+ # h * w <= 256 for DepthwiseConv2DBias
+ # h * w <= 4096 for AvgPool
+ # These limits are checked in supported ops
shape = [shape[0], 1, h * w, shape[3]]
op.ifm_shapes[0] = Shape4D(shape)
+ weight_shape = [1, h * w, shape[3], shape[0]]
if h > 256 and op.type == Op.AvgPool:
op.attrs.update({"ksize": (1, 1, h * w, 1), "filter_height": 1, "filter_width": h * w})
@@ -1769,8 +1775,9 @@ def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
weight_quant.scale_f32 = weight_scale
weight_quant.zero_point = 0
- # Set weight shape to [H,W,C,B]
- weight_shape = [h, w, shape[3], shape[0]]
+ if weight_shape is None:
+ # Set weight shape to [H,W,C,B]
+ weight_shape = [h, w, shape[3], shape[0]]
# Add unit weight tensor
op.set_input_tensor(