aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/operation.py
diff options
context:
space:
mode:
authorPatrik Gustavsson <patrik.gustavsson@arm.com>2020-09-01 09:15:27 +0200
committerpatrik.gustavsson <patrik.gustavsson@arm.com>2020-09-03 11:00:51 +0000
commit271ddc31745ca2aefa193f3a7308753126ac7c89 (patch)
tree3cee2a4446eeda79151fb78c27d46ba15fd08c5d /ethosu/vela/operation.py
parent0628a8c0136eebf3af8db7fd40b7aed94ff5d670 (diff)
downloadethos-u-vela-271ddc31745ca2aefa193f3a7308753126ac7c89.tar.gz
MLBEDSW-2814 Add support for inferred size in SplitV
For SplitV sizesplit can contain one -1 indicating that dimension is to be inferred. Support added to handle this. Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com> Change-Id: Ib9fc8dd2ee1749e81a978d85f2d4a016698bb441
Diffstat (limited to 'ethosu/vela/operation.py')
-rw-r--r--ethosu/vela/operation.py8
1 files changed, 8 insertions, 0 deletions
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 8dec379d..4b83b39b 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -259,9 +259,17 @@ input and output tensors, as well as an attribute dictionary."""
size_tens = self.inputs[1]
assert len(size_tens.ops) == 1 and size_tens.ops[0].type == "Const"
sizes = size_tens.values
+
axis_tens = self.inputs[2]
assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == "Const"
axis = int(axis_tens.values)
+
+ for idx, size in enumerate(sizes):
+ # One but only one size might be set to -1, indicating that size should be inferred
+ if size == -1:
+ sizes[idx] = input_tens.shape[axis] - (sum(sizes) + 1)
+ break
+
outputs = self.outputs
assert num_splits == len(outputs)
assert sum(sizes) == input_tens.shape[axis]