aboutsummaryrefslogtreecommitdiff
path: root/ethosu
diff options
context:
space:
mode:
authorPatrik Gustavsson <patrik.gustavsson@arm.com>2020-04-29 14:10:32 +0200
committerTim Hall <tim.hall@arm.com>2020-06-18 17:53:52 +0100
commit4913433e433d379425b210254dc9589fa63f516a (patch)
tree50ee8d26ab977ef4a89d0bf567a6fb4480b33170 /ethosu
parentf995db7b503eb2e5690972d95f40b96199c5555c (diff)
downloadethos-u-vela-4913433e433d379425b210254dc9589fa63f516a.tar.gz
MLBEDSW-1998: Add more support for Strided_slice
Add support for end_mask != begin_mask Change-Id: I6775696de4e2365e0a7cdcbcdbc64a7bd4858fb5 Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Diffstat (limited to 'ethosu')
-rw-r--r--ethosu/vela/operation.py20
1 files changed, 9 insertions, 11 deletions
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index d2f2806a..f36e61c6 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -250,7 +250,6 @@ input and output tensors, as well as an attribute dictionary."""
shrink_axis_mask = self.attrs["shrink_axis_mask"]
# TODO: Either extend this to support these different masks or check
# for this at an earlier stage and place the op on Cpu if needed
- assert begin_mask == end_mask
assert new_axis_mask == ellipsis_mask == 0
# shrink_axis_mask is not supported by the Operation class but the operation
# may have the attribute modified and handled in the graph optimization phase.
@@ -258,18 +257,17 @@ input and output tensors, as well as an attribute dictionary."""
assert len(input_tens.shape) == len(out_tens.shape)
for idx in range(len(input_tens.shape)):
- # If the i:th bit in begin_mask is set then the value on begin[i] should be ignored
- if (begin_mask & (1 << idx)) == 0:
- # Check if the op should slice in dimension idx
- if end_tens.values[idx] != input_tens.shape[idx] or (
- end_tens.values[idx] == input_tens.shape[idx] and begin_tens.values[idx] != 0
- ):
+ # Check if slicing is needed in this axis
+ if end_tens.values[idx] != input_tens.shape[idx] or (
+ end_tens.values[idx] == input_tens.shape[idx] and begin_tens.values[idx] != 0
+ ):
+ # If the i:th bit in begin_mask is set then the value on begin[i] should be ignored
+ if (begin_mask & (1 << idx)) == 0:
offset_start[idx] = begin_tens.values[idx]
- offset_end[idx] = end_tens.values[idx]
- else:
- # Don't slice in this axis, instead use fullest possible range
- continue
+ # If the i:th bit in end_mask is set then the value on end[i] should be ignored
+ if (end_mask & (1 << idx)) == 0:
+ offset_end[idx] = end_tens.values[idx]
elif self.type == "UnpackReshaped":
# Requires fixup_unpack_output to be called before this point