aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohan Alfvén <johan.alfven@arm.com>2022-10-13 10:49:30 +0200
committerJohan Alfvén <johan.alfven@arm.com>2022-10-18 10:06:11 +0200
commit65835e0118935e37e740d3d6fa2025549f31a2e0 (patch)
tree997d805d123d717d7e7a94f0158652851db86c77
parenta64616c4d7a33c5b2b4e5fb38c57217dc65bc2ea (diff)
downloadethos-u-vela-65835e0118935e37e740d3d6fa2025549f31a2e0.tar.gz
MLBEDSW-6941: Set correct OFM shape for fc op
If IFM operator shape is rewritten so that batching is greater than one for fully connect, the OFM batch must also be calculated. This change will fix output diffs for networks that have fully connect OFM with rank greater than 2. Signed-off-by: Johan Alfven <johan.alfven@arm.com> Change-Id: I5009edc647a1449a02c8116b45808c1c68beffe6
-rw-r--r--ethosu/vela/operation.py6
-rw-r--r--ethosu/vela/tflite_graph_optimiser.py6
2 files changed, 8 insertions, 4 deletions
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 8189793e..4a56f1f0 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -879,10 +879,8 @@ class Operation:
else:
# Special case, handled in graph optimization
self.ifm_shapes.append(Shape4D(ifm_tensor.get_full_shape()))
- if len(self.ofm.shape) == 2:
- self.ofm_shapes.append(Shape4D([self.ofm.shape[0], 1, 1, self.ofm.shape[1]]))
- else:
- self.ofm_shapes.append(Shape4D(ofm_tensor.get_full_shape()))
+ self.ofm_shapes.append(Shape4D(ofm_tensor.get_full_shape()))
+
elif self.type == Op.Softmax:
self.ifm_shapes.append(Shape4D(ifm_tensor.get_full_shape()))
self.ofm_shapes.append(Shape4D(ofm_tensor.get_full_shape()))
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 574d298a..0ba5abf5 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -634,6 +634,12 @@ def rewrite_fully_connected_input(op: Operation, arch, nng):
new_shape = op.ifm.get_shape_as_2d(op.weights.shape[-2])
assert new_shape is not None, "Tensor can not be reshaped to 2D"
op.ifm_shapes[0] = new_shape
+
+ if op.ifm_shapes[0].batch > 1 and op.ofm_shapes[0].batch == 1:
+ # If IFM is batching then also make sure OFM is batching
+ h, w = op.ofm_shapes[0].height, op.ofm_shapes[0].width
+ op.ofm_shapes[0] = Shape4D([h * w, 1, 1, op.ofm_shapes[0].depth])
+
return op