diff options
author | Johan Alfvén <johan.alfven@arm.com> | 2022-10-13 10:49:30 +0200 |
---|---|---|
committer | Johan Alfvén <johan.alfven@arm.com> | 2022-10-18 10:06:11 +0200 |
commit | 65835e0118935e37e740d3d6fa2025549f31a2e0 (patch) | |
tree | 997d805d123d717d7e7a94f0158652851db86c77 /ethosu/vela/tflite_graph_optimiser.py | |
parent | a64616c4d7a33c5b2b4e5fb38c57217dc65bc2ea (diff) | |
download | ethos-u-vela-65835e0118935e37e740d3d6fa2025549f31a2e0.tar.gz |
MLBEDSW-6941: Set correct OFM shape for fc op
If IFM operator shape is rewritten so that batching
is greater than one for fully connect, the OFM batch
must also be calculated. This change will fix output diffs
for networks that have fully connect OFM with rank greater
than 2.
Signed-off-by: Johan Alfven <johan.alfven@arm.com>
Change-Id: I5009edc647a1449a02c8116b45808c1c68beffe6
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r-- | ethosu/vela/tflite_graph_optimiser.py | 6 |
1 files changed, 6 insertions, 0 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py index 574d298a..0ba5abf5 100644 --- a/ethosu/vela/tflite_graph_optimiser.py +++ b/ethosu/vela/tflite_graph_optimiser.py @@ -634,6 +634,12 @@ def rewrite_fully_connected_input(op: Operation, arch, nng): new_shape = op.ifm.get_shape_as_2d(op.weights.shape[-2]) assert new_shape is not None, "Tensor can not be reshaped to 2D" op.ifm_shapes[0] = new_shape + + if op.ifm_shapes[0].batch > 1 and op.ofm_shapes[0].batch == 1: + # If IFM is batching then also make sure OFM is batching + h, w = op.ofm_shapes[0].height, op.ofm_shapes[0].width + op.ofm_shapes[0] = Shape4D([h * w, 1, 1, op.ofm_shapes[0].depth]) + return op |