From 4913433e433d379425b210254dc9589fa63f516a Mon Sep 17 00:00:00 2001 From: Patrik Gustavsson Date: Wed, 29 Apr 2020 14:10:32 +0200 Subject: MLBEDSW-1998: Add more support for Strided_slice Add support for end_mask != begin_mask Change-Id: I6775696de4e2365e0a7cdcbcdbc64a7bd4858fb5 Signed-off-by: Patrik Gustavsson --- ethosu/vela/operation.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py index d2f2806a..f36e61c6 100644 --- a/ethosu/vela/operation.py +++ b/ethosu/vela/operation.py @@ -250,7 +250,6 @@ input and output tensors, as well as an attribute dictionary.""" shrink_axis_mask = self.attrs["shrink_axis_mask"] # TODO: Either extend this to support these different masks or check # for this at an earlier stage and place the op on Cpu if needed - assert begin_mask == end_mask assert new_axis_mask == ellipsis_mask == 0 # shrink_axis_mask is not supported by the Operation class but the operation # may have the attribute modified and handled in the graph optimization phase. @@ -258,18 +257,17 @@ input and output tensors, as well as an attribute dictionary.""" assert len(input_tens.shape) == len(out_tens.shape) for idx in range(len(input_tens.shape)): - # If the i:th bit in begin_mask is set then the value on begin[i] should be ignored - if (begin_mask & (1 << idx)) == 0: - # Check if the op should slice in dimension idx - if end_tens.values[idx] != input_tens.shape[idx] or ( - end_tens.values[idx] == input_tens.shape[idx] and begin_tens.values[idx] != 0 - ): + # Check if slicing is needed in this axis + if end_tens.values[idx] != input_tens.shape[idx] or ( + end_tens.values[idx] == input_tens.shape[idx] and begin_tens.values[idx] != 0 + ): + # If the i:th bit in begin_mask is set then the value on begin[i] should be ignored + if (begin_mask & (1 << idx)) == 0: offset_start[idx] = begin_tens.values[idx] - offset_end[idx] = end_tens.values[idx] - else: - # Don't slice in this axis, instead use fullest possible range - continue + # If the i:th bit in end_mask is set then the value on end[i] should be ignored + if (end_mask & (1 << idx)) == 0: + offset_end[idx] = end_tens.values[idx] elif self.type == "UnpackReshaped": # Requires fixup_unpack_output to be called before this point -- cgit v1.2.1