aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ethosu/vela/tflite_graph_optimiser.py9
1 files changed, 8 insertions, 1 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 687e5d4..3af8588 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -827,7 +827,7 @@ def convert_batched_fc_shape(op: Operation, arch, nng) -> Operation:
if op.type == Op.FullyConnected:
# Check if the first dimension indicates batching
if op.ifm_shapes[0].batch > 1:
- batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
+ batching_split = {4: (2, 2), 6: (2, 3), 8: (2, 4), 9: (3, 3), 12: (3, 4), 16: (4, 4)}
n = op.ifm_shapes[0].batch
h, w = batching_split.get(n, (1, n))
op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth])
@@ -840,6 +840,13 @@ def convert_batched_fc_shape(op: Operation, arch, nng) -> Operation:
n = op.ofm_shapes[0].batch
h, w = batching_split.get(n, (1, n))
op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth])
+ if h == 1 and w > 4:
+ # If the batch is not found in batching_split, the weights are going to be
+ # read from memory several times. Convert the op to Conv2DBias since this
+ # enables weight buffering.
+ op.type = Op.Conv2DBias
+ op.attrs["padding"] = Padding.SAME
+ DebugDatabase.add_optimised(op, op)
return op