aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ethosu/vela/graph_optimiser.py16
-rw-r--r--ethosu/vela/supported_operators.py17
-rw-r--r--ethosu/vela/test/test_supported_operators.py6
3 files changed, 4 insertions, 35 deletions
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 32f97d2f..e31348b5 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -422,19 +422,12 @@ def unfuse_activation_function(op, arch, nng):
def fixup_stridedslice_output(tens, arch, nng):
op = tens.ops[0]
- if op.type == Op.StridedSlice:
+ if op.run_on_npu and op.type == Op.StridedSlice:
reshape_input_shape = tens.shape
new_axis_mask = op.attrs["new_axis_mask"]
shrink_axis_mask = op.attrs["shrink_axis_mask"]
- ellipsis_mask = op.attrs["ellipsis_mask"]
- if (new_axis_mask != 0 and shrink_axis_mask != 0) or ellipsis_mask != 0:
- # Not supported, will be put on CPU
- return tens
- if shrink_axis_mask == 0 and new_axis_mask == 0:
- # Equal Rank StridedSlice, no need to insert reshape
- return tens
- elif shrink_axis_mask != 0:
+ if shrink_axis_mask != 0:
n = 0
axis = 0
while shrink_axis_mask:
@@ -446,7 +439,6 @@ def fixup_stridedslice_output(tens, arch, nng):
assert len(tens.shape) == (len(op.inputs[0].shape) - n)
op.attrs["shrink_axis_mask"] = 0
-
elif new_axis_mask != 0:
n = 0
axis = 0
@@ -1092,7 +1084,7 @@ def optimise_graph_a(nng, arch, verbose_graph=False):
for idx, sg in enumerate(nng.subgraphs):
# rewrite graph pass
nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
- nng, sg, arch, [fixup_stridedslice_output], op_rewrite_list, rewrite_unsupported=False,
+ nng, sg, arch, [], op_rewrite_list, rewrite_unsupported=False,
)
for idx, sg in enumerate(nng.subgraphs):
@@ -1113,7 +1105,7 @@ def optimise_graph_b(nng, arch, verbose_graph=False):
for idx, sg in enumerate(nng.subgraphs):
# combined rewrite graph pass
nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
- nng, sg, arch, [fixup_unpack_output, rewrite_concat, rewrite_split], []
+ nng, sg, arch, [fixup_unpack_output, fixup_stridedslice_output, rewrite_concat, rewrite_split], []
)
if verbose_graph:
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 04cda1da..3e649e09 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -219,7 +219,6 @@ class SupportedOperators:
# StridedSlice specific checks:
self.specific_constraints[Op.StridedSlice].append(SupportedOperators.constraint_stridedslice_input_count)
self.specific_constraints[Op.StridedSlice].append(SupportedOperators.constraint_stridedslice_inputs_const)
- self.specific_constraints[Op.StridedSlice].append(SupportedOperators.constraint_stridedslice_tens_size_matches)
self.specific_constraints[Op.StridedSlice].append(SupportedOperators.constraint_stridedslice_stride_values)
self.specific_constraints[Op.StridedSlice].append(SupportedOperators.constraint_ellipsis_mask)
self.specific_constraints[Op.StridedSlice].append(SupportedOperators.constraint_axis_masks)
@@ -728,22 +727,6 @@ class SupportedOperators:
return valid, f"Op has non-constant tensors: {extra}"
@staticmethod
- def constraint_stridedslice_tens_size_matches(op):
- "All Input sizes must match OFM size"
- ifm, begin, end, strides = op.inputs
- ifm_size = len(ifm.shape)
- ofm_size = len(op.ofm.shape)
- begin_size = len(begin.values)
- end_size = len(end.values)
- strides_size = len(strides.values)
- valid = ifm_size == ofm_size == begin_size == end_size == strides_size
- extra = (
- f"Op has ofm_size={ofm_size}, ifm_size={ifm_size},"
- f" begin_size={begin_size}, end_size={end_size} and strides_size={strides_size}"
- )
- return valid, extra
-
- @staticmethod
def constraint_stridedslice_stride_values(op):
"All Strides values must be 1"
strides = op.inputs[3]
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py
index 595ea590..245ebcf9 100644
--- a/ethosu/vela/test/test_supported_operators.py
+++ b/ethosu/vela/test/test_supported_operators.py
@@ -486,12 +486,6 @@ def test_constraint_stridedslice_inputs_const():
assert not support.is_operator_supported(op)
-def test_constraint_stridedslice_tens_size_matches():
- op = create_strided_slice()
- op.inputs[1].values = [1, 1, 1, 1, 1, 1, 1, 1]
- assert not support.is_operator_supported(op)
-
-
def test_constraint_stridedslice_stride_values():
# Unsupported strides
op = create_strided_slice()