aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_graph_optimiser.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r--ethosu/vela/tflite_graph_optimiser.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index e01433d0..7b10f86a 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -1301,8 +1301,8 @@ def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
if dims == 2:
shape += [1]
- # If height is greater than max kernel height, reshape to from HxW to 1x(HxW)
- if h > 64:
+ # If height is greater than max kernel height, reshape from HxW to 1x(HxW)
+ if (h > 64 and op.type == Op.DepthwiseConv2DBias) or (h > 256 and op.type == Op.AvgPool):
shape = [shape[0], 1, h * w, shape[3]]
op.ifm_shapes[0] = Shape4D(shape)
if h > 256 and op.type == Op.AvgPool: