aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_graph_optimiser.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r--ethosu/vela/tflite_graph_optimiser.py35
1 files changed, 22 insertions, 13 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index a12eeb37..31d3ae1a 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -2004,16 +2004,9 @@ def convert_mean_to_depthwise_conv(op, arch, nng):
intermediate_shape.insert(i, 1)
# Reshape to 4D
- if dims == 2:
- # Reshape WxC -> 1xHxWx1 to support both axes
- reduce_axis = [False] + reduce_axis + [False]
- ifm_shape = [1] + ifm_shape + [1]
- intermediate_shape = [1] + intermediate_shape + [1]
- elif dims == 3:
- # Reshape to 4D HxWxC -> 1xHxWxC
- reduce_axis = [False] + reduce_axis
- ifm_shape = [1] + ifm_shape
- intermediate_shape = [1] + intermediate_shape
+ reduce_axis = full_shape(4, reduce_axis, False)
+ ifm_shape = full_shape(4, ifm_shape, 1)
+ intermediate_shape = full_shape(4, intermediate_shape, 1)
# If all dimensions to reduce have shape 1, the operation is essentially a memcpy.
# We can then remove the whole op by propagating ofm to previous ops
@@ -2022,9 +2015,25 @@ def convert_mean_to_depthwise_conv(op, arch, nng):
op = bypass_memory_only_ops(op, arch, nng)
return op
- # Compute kernel sizes for our convolutions.
- # batch and depth axes are only supported if their shapes are 1.
- # hence reduction in batch or depth axis is implicit.
+ # Support mean over depth-axis by left-shifting the C channel
+ # From semantics checks we can assume that one of H,W,C has shape 1
+ if reduce_axis[3] and ifm_shape[3] > 1:
+ assert 1 in ifm_shape[1:], "Mean reduction over depth channel, but none of H,W,C has shape 1"
+ # If W=1 reshape NxHx1xC -> NxHxCx1, else reshape Nx1xWxC -> NxWxCx1
+ idx_to_del = 2 if ifm_shape[2] == 1 else 1
+
+ # Delete axis with size 1
+ del reduce_axis[idx_to_del]
+ del ifm_shape[idx_to_del]
+ del intermediate_shape[idx_to_del]
+
+ # Add another element to set channel-axis to one
+ reduce_axis.append(False)
+ ifm_shape.append(1)
+ intermediate_shape.append(1)
+
+ # Compute kernel sizes for our convolutions
+ # Batch axis is implicit as it is only supported if batch size is 1.
h = ifm_shape[1] if reduce_axis[1] else 1
w = ifm_shape[2] if reduce_axis[2] else 1