diff options
author | Patrik Gustavsson <patrik.gustavsson@arm.com> | 2020-09-01 09:15:27 +0200 |
---|---|---|
committer | patrik.gustavsson <patrik.gustavsson@arm.com> | 2020-09-03 11:00:51 +0000 |
commit | 271ddc31745ca2aefa193f3a7308753126ac7c89 (patch) | |
tree | 3cee2a4446eeda79151fb78c27d46ba15fd08c5d /ethosu | |
parent | 0628a8c0136eebf3af8db7fd40b7aed94ff5d670 (diff) | |
download | ethos-u-vela-271ddc31745ca2aefa193f3a7308753126ac7c89.tar.gz |
MLBEDSW-2814 Add support for inferred size in SplitV
For SplitV sizesplit can contain one -1 indicating that
dimension is to be inferred.
Support added to handle this.
Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: Ib9fc8dd2ee1749e81a978d85f2d4a016698bb441
Diffstat (limited to 'ethosu')
-rw-r--r-- | ethosu/vela/operation.py | 8 | ||||
-rw-r--r-- | ethosu/vela/supported_operators.py | 12 |
2 files changed, 20 insertions, 0 deletions
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py index 8dec379d..4b83b39b 100644 --- a/ethosu/vela/operation.py +++ b/ethosu/vela/operation.py @@ -259,9 +259,17 @@ input and output tensors, as well as an attribute dictionary.""" size_tens = self.inputs[1] assert len(size_tens.ops) == 1 and size_tens.ops[0].type == "Const" sizes = size_tens.values + axis_tens = self.inputs[2] assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == "Const" axis = int(axis_tens.values) + + for idx, size in enumerate(sizes): + # One but only one size might be set to -1, indicating that size should be inferred + if size == -1: + sizes[idx] = input_tens.shape[axis] - (sum(sizes) + 1) + break + outputs = self.outputs assert num_splits == len(outputs) assert sum(sizes) == input_tens.shape[axis] diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py index 7cff0ee4..e0ee6163 100644 --- a/ethosu/vela/supported_operators.py +++ b/ethosu/vela/supported_operators.py @@ -378,6 +378,18 @@ class SupportedOperators: # check if both new_axis_mask and shrink_axis_mask have bit set if op.attrs["new_axis_mask"] != 0 and op.attrs["shrink_axis_mask"] != 0: return False + if op.type == "SplitV": + # check that maximum one size is set to -1, indicating that size should be inferred + sizes = op.inputs[1].values + num_to_be_inferred = 0 + for size in sizes: + if size == -1: + num_to_be_inferred += 1 + + if num_to_be_inferred > 1: + print("Warning:", op.type, "has more than one size to be inferred, which is illegal, placing on CPU") + return False + return True def check_quantization_restrictions_binary_elem_wise(self, op): |