diff options
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r-- | ethosu/vela/tflite_graph_optimiser.py | 35 |
1 files changed, 22 insertions, 13 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py index a12eeb37..31d3ae1a 100644 --- a/ethosu/vela/tflite_graph_optimiser.py +++ b/ethosu/vela/tflite_graph_optimiser.py @@ -2004,16 +2004,9 @@ def convert_mean_to_depthwise_conv(op, arch, nng): intermediate_shape.insert(i, 1) # Reshape to 4D - if dims == 2: - # Reshape WxC -> 1xHxWx1 to support both axes - reduce_axis = [False] + reduce_axis + [False] - ifm_shape = [1] + ifm_shape + [1] - intermediate_shape = [1] + intermediate_shape + [1] - elif dims == 3: - # Reshape to 4D HxWxC -> 1xHxWxC - reduce_axis = [False] + reduce_axis - ifm_shape = [1] + ifm_shape - intermediate_shape = [1] + intermediate_shape + reduce_axis = full_shape(4, reduce_axis, False) + ifm_shape = full_shape(4, ifm_shape, 1) + intermediate_shape = full_shape(4, intermediate_shape, 1) # If all dimensions to reduce have shape 1, the operation is essentially a memcpy. # We can then remove the whole op by propagating ofm to previous ops @@ -2022,9 +2015,25 @@ def convert_mean_to_depthwise_conv(op, arch, nng): op = bypass_memory_only_ops(op, arch, nng) return op - # Compute kernel sizes for our convolutions. - # batch and depth axes are only supported if their shapes are 1. - # hence reduction in batch or depth axis is implicit. + # Support mean over depth-axis by left-shifting the C channel + # From semantics checks we can assume that one of H,W,C has shape 1 + if reduce_axis[3] and ifm_shape[3] > 1: + assert 1 in ifm_shape[1:], "Mean reduction over depth channel, but none of H,W,C has shape 1" + # If W=1 reshape NxHx1xC -> NxHxCx1, else reshape Nx1xWxC -> NxWxCx1 + idx_to_del = 2 if ifm_shape[2] == 1 else 1 + + # Delete axis with size 1 + del reduce_axis[idx_to_del] + del ifm_shape[idx_to_del] + del intermediate_shape[idx_to_del] + + # Add another element to set channel-axis to one + reduce_axis.append(False) + ifm_shape.append(1) + intermediate_shape.append(1) + + # Compute kernel sizes for our convolutions + # Batch axis is implicit as it is only supported if batch size is 1. h = ifm_shape[1] if reduce_axis[1] else 1 w = ifm_shape[2] if reduce_axis[2] else 1 |