aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela
diff options
context:
space:
mode:
authorJohan Alfvén <johan.alfven@arm.com>2022-10-13 10:49:30 +0200
committerJohan Alfvén <johan.alfven@arm.com>2022-10-18 10:06:11 +0200
commit65835e0118935e37e740d3d6fa2025549f31a2e0 (patch)
tree997d805d123d717d7e7a94f0158652851db86c77 /ethosu/vela
parenta64616c4d7a33c5b2b4e5fb38c57217dc65bc2ea (diff)
downloadethos-u-vela-65835e0118935e37e740d3d6fa2025549f31a2e0.tar.gz
MLBEDSW-6941: Set correct OFM shape for fc op
If IFM operator shape is rewritten so that batching is greater than one for fully connect, the OFM batch must also be calculated. This change will fix output diffs for networks that have fully connect OFM with rank greater than 2. Signed-off-by: Johan Alfven <johan.alfven@arm.com> Change-Id: I5009edc647a1449a02c8116b45808c1c68beffe6
Diffstat (limited to 'ethosu/vela')
-rw-r--r--ethosu/vela/operation.py6
-rw-r--r--ethosu/vela/tflite_graph_optimiser.py6
2 files changed, 8 insertions, 4 deletions
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 8189793e..4a56f1f0 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -879,10 +879,8 @@ class Operation:
else:
# Special case, handled in graph optimization
self.ifm_shapes.append(Shape4D(ifm_tensor.get_full_shape()))
- if len(self.ofm.shape) == 2:
- self.ofm_shapes.append(Shape4D([self.ofm.shape[0], 1, 1, self.ofm.shape[1]]))
- else:
- self.ofm_shapes.append(Shape4D(ofm_tensor.get_full_shape()))
+ self.ofm_shapes.append(Shape4D(ofm_tensor.get_full_shape()))
+
elif self.type == Op.Softmax:
self.ifm_shapes.append(Shape4D(ifm_tensor.get_full_shape()))
self.ofm_shapes.append(Shape4D(ofm_tensor.get_full_shape()))
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 574d298a..0ba5abf5 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -634,6 +634,12 @@ def rewrite_fully_connected_input(op: Operation, arch, nng):
new_shape = op.ifm.get_shape_as_2d(op.weights.shape[-2])
assert new_shape is not None, "Tensor can not be reshaped to 2D"
op.ifm_shapes[0] = new_shape
+
+ if op.ifm_shapes[0].batch > 1 and op.ofm_shapes[0].batch == 1:
+ # If IFM is batching then also make sure OFM is batching
+ h, w = op.ofm_shapes[0].height, op.ofm_shapes[0].width
+ op.ofm_shapes[0] = Shape4D([h * w, 1, 1, op.ofm_shapes[0].depth])
+
return op