From f49370003956d4f6f7d177114a68edb07b364fe9 Mon Sep 17 00:00:00 2001 From: Johan Alfven Date: Sat, 20 Apr 2024 08:04:27 +0200 Subject: MLBEDSW-8969: Enable weight buffering for fully connected with batch shape - Fully connected with batch shape will use the weights more than once. Models with these type of fully connected will benefit from weight buffering. - If a fully connected op with this shape is detected it is changed to a conv2d and the normal weight buffering flow will be used. Change-Id: I272741a32390e036d5e04bd5af41d4538162e86e Signed-off-by: Johan Alfven --- ethosu/vela/tflite_graph_optimiser.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py index 687e5d4f..3af8588c 100644 --- a/ethosu/vela/tflite_graph_optimiser.py +++ b/ethosu/vela/tflite_graph_optimiser.py @@ -827,7 +827,7 @@ def convert_batched_fc_shape(op: Operation, arch, nng) -> Operation: if op.type == Op.FullyConnected: # Check if the first dimension indicates batching if op.ifm_shapes[0].batch > 1: - batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)} + batching_split = {4: (2, 2), 6: (2, 3), 8: (2, 4), 9: (3, 3), 12: (3, 4), 16: (4, 4)} n = op.ifm_shapes[0].batch h, w = batching_split.get(n, (1, n)) op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth]) @@ -840,6 +840,13 @@ def convert_batched_fc_shape(op: Operation, arch, nng) -> Operation: n = op.ofm_shapes[0].batch h, w = batching_split.get(n, (1, n)) op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth]) + if h == 1 and w > 4: + # If batch can not be found in the split set the weights are going to be + # read from memory several times. Convert op to conv2d since this + # enables weight buffering. + op.type = Op.Conv2DBias + op.attrs["padding"] = Padding.SAME + DebugDatabase.add_optimised(op, op) return op -- cgit v1.2.1