diff options
author | Patrik Gustavsson <patrik.gustavsson@arm.com> | 2020-04-29 14:10:32 +0200 |
---|---|---|
committer | Tim Hall <tim.hall@arm.com> | 2020-06-18 17:53:52 +0100 |
commit | 4913433e433d379425b210254dc9589fa63f516a (patch) | |
tree | 50ee8d26ab977ef4a89d0bf567a6fb4480b33170 | |
parent | f995db7b503eb2e5690972d95f40b96199c5555c (diff) | |
download | ethos-u-vela-4913433e433d379425b210254dc9589fa63f516a.tar.gz |
MLBEDSW-1998: Add more support for Strided_slice
Add support for end_mask != begin_mask
Change-Id: I6775696de4e2365e0a7cdcbcdbc64a7bd4858fb5
Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
-rw-r--r-- | ethosu/vela/operation.py | 20 |
1 files changed, 9 insertions, 11 deletions
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py index d2f2806a..f36e61c6 100644 --- a/ethosu/vela/operation.py +++ b/ethosu/vela/operation.py @@ -250,7 +250,6 @@ input and output tensors, as well as an attribute dictionary.""" shrink_axis_mask = self.attrs["shrink_axis_mask"] # TODO: Either extend this to support these different masks or check # for this at an earlier stage and place the op on Cpu if needed - assert begin_mask == end_mask assert new_axis_mask == ellipsis_mask == 0 # shrink_axis_mask is not supported by the Operation class but the operation # may have the attribute modified and handled in the graph optimization phase. @@ -258,18 +257,17 @@ input and output tensors, as well as an attribute dictionary.""" assert len(input_tens.shape) == len(out_tens.shape) for idx in range(len(input_tens.shape)): - # If the i:th bit in begin_mask is set then the value on begin[i] should be ignored - if (begin_mask & (1 << idx)) == 0: - # Check if the op should slice in dimension idx - if end_tens.values[idx] != input_tens.shape[idx] or ( - end_tens.values[idx] == input_tens.shape[idx] and begin_tens.values[idx] != 0 - ): + # Check if slicing is needed in this axis + if end_tens.values[idx] != input_tens.shape[idx] or ( + end_tens.values[idx] == input_tens.shape[idx] and begin_tens.values[idx] != 0 + ): + # If the i:th bit in begin_mask is set then the value on begin[i] should be ignored + if (begin_mask & (1 << idx)) == 0: offset_start[idx] = begin_tens.values[idx] - offset_end[idx] = end_tens.values[idx] - else: - # Don't slice in this axis, instead use fullest possible range - continue + # If the i:th bit in end_mask is set then the value on end[i] should be ignored + if (end_mask & (1 << idx)) == 0: + offset_end[idx] = end_tens.values[idx] elif self.type == "UnpackReshaped": # Requires fixup_unpack_output to be called before this point |