From fa2f92a51246630532d53b24228b7620b66595d1 Mon Sep 17 00:00:00 2001 From: Louis Verhaard Date: Mon, 21 Sep 2020 11:56:18 +0200 Subject: MLBEDSW-3035: Updated StridedSlice checks Updated supported operator checks for StridedSlice: - allow negative indices in begin/end values - added more checks on shapes Change-Id: I3ac76bfa6b313f0e2250f0749f152fb0e3aa033c Signed-off-by: Louis Verhaard --- ethosu/vela/supported_operators.py | 41 ++++++++++++++++++++++++++++++++++---- 1 file changed, 37 insertions(+), 4 deletions(-) (limited to 'ethosu/vela/supported_operators.py') diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py index 63eb01b5..9e9da8c6 100644 --- a/ethosu/vela/supported_operators.py +++ b/ethosu/vela/supported_operators.py @@ -19,6 +19,11 @@ import numpy as np from .data_type import BaseType from .data_type import DataType +from .operation import get_slice_offsets + + +def warn_cpu(op, msg): + print("Warning: {} {}, placing on CPU".format(op.type, msg)) class SupportedOperators: @@ -381,17 +386,45 @@ class SupportedOperators: def check_memory_only_restrictions(self, op): if op.type == "StridedSlice": - # check stride size - if len(op.inputs) > 3 and any(stride != 1 for stride in op.inputs[3].values): + if len(op.inputs) != 4: + warn_cpu(op, "has {} input tensors, only 4 inputs are supported".format(len(op.inputs))) return False - # check "end - begin" doesnt result in any zero or negative elements - if any((end - begin) <= 0 for begin, end in zip(op.inputs[1].values, op.inputs[2].values)): + input_tens, begin_tens, end_tens, strides_tens = op.inputs + if begin_tens.values is None or end_tens.values is None or strides_tens.values is None: + warn_cpu(op, "has a non-constant begin, end, or stride input tensor, which is not supported") + return False + if not ( + len(input_tens.shape) + == len(op.outputs[0].shape) + == len(begin_tens.values) + == len(end_tens.values) + == len(strides_tens.values) + ): + warn_cpu(op, "has input tensors with shapes that are not supported") + return False + # check stride size + if any(stride != 1 for stride in strides_tens.values): + warn_cpu(op, "has stride values {}, only stride 1 values are supported".format(strides_tens.values)) return False # check ellipsis_mask if op.attrs["ellipsis_mask"] != 0: + warn_cpu(op, "ellipsis_mask is {}, only 0 is supported".format(op.attrs["ellipsis_mask"])) return False # check if both new_axis_mask and shrink_axis_mask have bit set if op.attrs["new_axis_mask"] != 0 and op.attrs["shrink_axis_mask"] != 0: + warn_cpu(op, "new_axis_mask and shrink_axis_mask are both non-zero, which is not supported") + return False + # Calculate offset start/end + offset_start = get_slice_offsets(input_tens.shape, begin_tens, op.attrs["begin_mask"], is_begin=True) + offset_end = get_slice_offsets(input_tens.shape, end_tens, op.attrs["end_mask"], is_begin=False) + # check "end - begin" doesn't result in any zero or negative elements + if any((end - begin) <= 0 for begin, end in zip(offset_start, offset_end)): + warn_cpu( + op, + "has slice begin values {}, some of which are >= end values {}, which is illegal".format( + begin_tens.values, end_tens.values + ), + ) return False if op.type == "SplitV": # check that maximum one size is set to -1, indicating that size should be inferred -- cgit v1.2.1