From 65835e0118935e37e740d3d6fa2025549f31a2e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johan=20Alfv=C3=A9n?= Date: Thu, 13 Oct 2022 10:49:30 +0200 Subject: MLBEDSW-6941: Set correct OFM shape for fc op If IFM operator shape is rewritten so that batching is greater than one for fully connect, the OFM batch must also be calculated. This change will fix output diffs for networks that have fully connect OFM with rank greater than 2. Signed-off-by: Johan Alfven Change-Id: I5009edc647a1449a02c8116b45808c1c68beffe6 --- ethosu/vela/operation.py | 6 ++---- ethosu/vela/tflite_graph_optimiser.py | 6 ++++++ 2 files changed, 8 insertions(+), 4 deletions(-) (limited to 'ethosu') diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py index 8189793e..4a56f1f0 100644 --- a/ethosu/vela/operation.py +++ b/ethosu/vela/operation.py @@ -879,10 +879,8 @@ class Operation: else: # Special case, handled in graph optimization self.ifm_shapes.append(Shape4D(ifm_tensor.get_full_shape())) - if len(self.ofm.shape) == 2: - self.ofm_shapes.append(Shape4D([self.ofm.shape[0], 1, 1, self.ofm.shape[1]])) - else: - self.ofm_shapes.append(Shape4D(ofm_tensor.get_full_shape())) + self.ofm_shapes.append(Shape4D(ofm_tensor.get_full_shape())) + elif self.type == Op.Softmax: self.ifm_shapes.append(Shape4D(ifm_tensor.get_full_shape())) self.ofm_shapes.append(Shape4D(ofm_tensor.get_full_shape())) diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py index 574d298a..0ba5abf5 100644 --- a/ethosu/vela/tflite_graph_optimiser.py +++ b/ethosu/vela/tflite_graph_optimiser.py @@ -634,6 +634,12 @@ def rewrite_fully_connected_input(op: Operation, arch, nng): new_shape = op.ifm.get_shape_as_2d(op.weights.shape[-2]) assert new_shape is not None, "Tensor can not be reshaped to 2D" op.ifm_shapes[0] = new_shape + + if op.ifm_shapes[0].batch > 1 and op.ofm_shapes[0].batch == 1: + # If IFM is batching then also make sure OFM is batching + h, w = op.ofm_shapes[0].height, op.ofm_shapes[0].width + op.ofm_shapes[0] = Shape4D([h * w, 1, 1, op.ofm_shapes[0].depth]) + return op -- cgit v1.2.1