aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/operation.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/operation.py')
-rw-r--r--ethosu/vela/operation.py31
1 files changed, 15 insertions, 16 deletions
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 6bc5a32d..252f03b7 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -43,6 +43,19 @@ def create_avgpool_nop(name):
return op
+def get_slice_offsets(input_shape, offset_tens, offset_mask, is_begin=True):
+ # For strided slice operator: get start or end offsets
+ offsets = len(input_shape) * [0] if is_begin else input_shape[:]
+ for idx in range(len(input_shape)):
+ # If the i:th bit in the mask is set then the value on offset_tens[i] should be ignored
+ if (offset_mask & (1 << idx)) == 0:
+ offsets[idx] = offset_tens.values[idx]
+ if offsets[idx] < 0:
+ # Convert offset to positive value
+ offsets[idx] += input_shape[idx]
+ return offsets
+
+
class Operation:
"""Class representing a Neural Network operation. Has a name, a type,
input and output tensors, as well as an attribute dictionary."""
@@ -309,8 +322,6 @@ input and output tensors, as well as an attribute dictionary."""
input_tens, begin_tens, end_tens, strides_tens = self.inputs
outputs = self.outputs
out_tens = outputs[0]
- offset_start = [0] * len(outputs[0].shape)
- offset_end = [0] * len(outputs[0].shape)
# Extract masks
begin_mask = self.attrs["begin_mask"]
@@ -323,20 +334,8 @@ input and output tensors, as well as an attribute dictionary."""
# may have the attribute modified and handled in the graph optimization phase.
assert shrink_axis_mask == new_axis_mask == ellipsis_mask == 0
assert len(input_tens.shape) == len(out_tens.shape)
-
- for idx in range(len(input_tens.shape)):
- # Check if slicing is needed in this axis
- if end_tens.values[idx] != input_tens.shape[idx] or (
- end_tens.values[idx] == input_tens.shape[idx] and begin_tens.values[idx] != 0
- ):
- # If the i:th bit in begin_mask is set then the value on begin[i] should be ignored
- if (begin_mask & (1 << idx)) == 0:
- offset_start[idx] = begin_tens.values[idx]
-
- # If the i:th bit in end_mask is set then the value on end[i] should be ignored
- if (end_mask & (1 << idx)) == 0:
- offset_end[idx] = end_tens.values[idx]
-
+ offset_start = get_slice_offsets(input_tens.shape, begin_tens, begin_mask, is_begin=True)
+ offset_end = get_slice_offsets(input_tens.shape, end_tens, end_mask, is_begin=False)
elif self.type == "UnpackReshaped":
# Requires fixup_unpack_output to be called before this point
input_tens = self.inputs[0]