From a2ec5aa747633da72b6310ce7e5552c39f7f54bb Mon Sep 17 00:00:00 2001
From: Ayaan Masood
Date: Thu, 21 Apr 2022 14:28:03 +0100
Subject: MLBEDSW-5384 FC layers run on NPU if underlying shape is 2D

*Added generic function which checks if underlying shape of FullyConnected operation is 2D and performs shape reduction
*Fully connected operation >2 dimensions now run on NPU if the above case is satisfied
*constraint_fc_output_2d and rewrite_fully_connected_input refactored
*Added unit test to confirm this functionality

Signed-off-by: Ayaan Masood
Change-Id: I0e29c767e5b84841eb53bbc44464b36a454f7b38
---
 ethosu/vela/tflite_graph_optimiser.py | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

(limited to 'ethosu/vela/tflite_graph_optimiser.py')

diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index b2a34195..06395784 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -379,14 +379,12 @@ def convert_nop_split_to_identity(op, arch, nng):
     return op
 
 
-def rewrite_fully_connected_input(op, arch, nng):
-    if op.type == Op.FullyConnected:
-        n_in_elems = op.weights.shape[-2]
-        elms = op.ifm.elements()
-        batch_size = elms // n_in_elems
-        assert batch_size * n_in_elems == elms
+def rewrite_fully_connected_input(op: Operation, arch, nng):
 
-        op.ifm_shapes[0] = Shape4D([batch_size, 1, 1, n_in_elems])
+    if op.type == Op.FullyConnected:
+        new_shape = op.ifm.get_shape_as_2d(op.weights.shape[-2])
+        assert new_shape is not None, "Tensor can not be reshaped to 2D"
+        op.ifm_shapes[0] = new_shape
     return op
 
 
-- 
cgit v1.2.1