aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/supported_operators.py
diff options
context:
space:
mode:
authorLouis Verhaard <louis.verhaard@arm.com>2020-09-21 11:56:18 +0200
committerpatrik.gustavsson <patrik.gustavsson@arm.com>2020-09-28 10:21:20 +0000
commitfa2f92a51246630532d53b24228b7620b66595d1 (patch)
tree0b48513eeaf1d8d23e4b5258680bae9f5ab80cde /ethosu/vela/supported_operators.py
parentf3d737ea14eabffede935cb418611b1f624e180a (diff)
downloadethos-u-vela-fa2f92a51246630532d53b24228b7620b66595d1.tar.gz
MLBEDSW-3035: Updated StridedSlice checks
Updated supported operator checks for StridedSlice: - allow negative indices in begin/end values - added more checks on shapes Change-Id: I3ac76bfa6b313f0e2250f0749f152fb0e3aa033c Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
Diffstat (limited to 'ethosu/vela/supported_operators.py')
-rw-r--r--ethosu/vela/supported_operators.py41
1 files changed, 37 insertions, 4 deletions
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 63eb01b5..9e9da8c6 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -19,6 +19,11 @@ import numpy as np
from .data_type import BaseType
from .data_type import DataType
+from .operation import get_slice_offsets
+
+
+def warn_cpu(op, msg):
+ print("Warning: {} {}, placing on CPU".format(op.type, msg))
class SupportedOperators:
@@ -381,17 +386,45 @@ class SupportedOperators:
def check_memory_only_restrictions(self, op):
if op.type == "StridedSlice":
- # check stride size
- if len(op.inputs) > 3 and any(stride != 1 for stride in op.inputs[3].values):
+ if len(op.inputs) != 4:
+ warn_cpu(op, "has {} input tensors, only 4 inputs are supported".format(len(op.inputs)))
return False
- # check "end - begin" doesnt result in any zero or negative elements
- if any((end - begin) <= 0 for begin, end in zip(op.inputs[1].values, op.inputs[2].values)):
+ input_tens, begin_tens, end_tens, strides_tens = op.inputs
+ if begin_tens.values is None or end_tens.values is None or strides_tens.values is None:
+ warn_cpu(op, "has a non-constant begin, end, or stride input tensor, which is not supported")
+ return False
+ if not (
+ len(input_tens.shape)
+ == len(op.outputs[0].shape)
+ == len(begin_tens.values)
+ == len(end_tens.values)
+ == len(strides_tens.values)
+ ):
+ warn_cpu(op, "has input tensors with shapes that are not supported")
+ return False
+ # check stride size
+ if any(stride != 1 for stride in strides_tens.values):
+ warn_cpu(op, "has stride values {}, only stride 1 values are supported".format(strides_tens.values))
return False
# check ellipsis_mask
if op.attrs["ellipsis_mask"] != 0:
+ warn_cpu(op, "ellipsis_mask is {}, only 0 is supported".format(op.attrs["ellipsis_mask"]))
return False
# check if both new_axis_mask and shrink_axis_mask have bit set
if op.attrs["new_axis_mask"] != 0 and op.attrs["shrink_axis_mask"] != 0:
+ warn_cpu(op, "new_axis_mask and shrink_axis_mask are both non-zero, which is not supported")
+ return False
+ # Calculate offset start/end
+ offset_start = get_slice_offsets(input_tens.shape, begin_tens, op.attrs["begin_mask"], is_begin=True)
+ offset_end = get_slice_offsets(input_tens.shape, end_tens, op.attrs["end_mask"], is_begin=False)
+ # check "end - begin" doesn't result in any zero or negative elements
+ if any((end - begin) <= 0 for begin, end in zip(offset_start, offset_end)):
+ warn_cpu(
+ op,
+ "has slice begin values {}, some of which are >= end values {}, which is illegal".format(
+ begin_tens.values, end_tens.values
+ ),
+ )
return False
if op.type == "SplitV":
# check that maximum one size is set to -1, indicating that size should be inferred