diff options
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r-- | ethosu/vela/tflite_graph_optimiser.py | 6 |
1 files changed, 6 insertions, 0 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py index 574d298a..0ba5abf5 100644 --- a/ethosu/vela/tflite_graph_optimiser.py +++ b/ethosu/vela/tflite_graph_optimiser.py @@ -634,6 +634,12 @@ def rewrite_fully_connected_input(op: Operation, arch, nng): new_shape = op.ifm.get_shape_as_2d(op.weights.shape[-2]) assert new_shape is not None, "Tensor can not be reshaped to 2D" op.ifm_shapes[0] = new_shape + + if op.ifm_shapes[0].batch > 1 and op.ofm_shapes[0].batch == 1: + # If IFM is batching then also make sure OFM is batching + h, w = op.ofm_shapes[0].height, op.ofm_shapes[0].width + op.ofm_shapes[0] = Shape4D([h * w, 1, 1, op.ofm_shapes[0].depth]) + return op |