diff options
-rw-r--r-- | ethosu/vela/graph_optimiser.py | 16 | ||||
-rw-r--r-- | ethosu/vela/supported_operators.py | 17 | ||||
-rw-r--r-- | ethosu/vela/test/test_supported_operators.py | 6 |
3 files changed, 4 insertions, 35 deletions
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py index 32f97d2f..e31348b5 100644 --- a/ethosu/vela/graph_optimiser.py +++ b/ethosu/vela/graph_optimiser.py @@ -422,19 +422,12 @@ def unfuse_activation_function(op, arch, nng): def fixup_stridedslice_output(tens, arch, nng): op = tens.ops[0] - if op.type == Op.StridedSlice: + if op.run_on_npu and op.type == Op.StridedSlice: reshape_input_shape = tens.shape new_axis_mask = op.attrs["new_axis_mask"] shrink_axis_mask = op.attrs["shrink_axis_mask"] - ellipsis_mask = op.attrs["ellipsis_mask"] - if (new_axis_mask != 0 and shrink_axis_mask != 0) or ellipsis_mask != 0: - # Not supported, will be put on CPU - return tens - if shrink_axis_mask == 0 and new_axis_mask == 0: - # Equal Rank StridedSlice, no need to insert reshape - return tens - elif shrink_axis_mask != 0: + if shrink_axis_mask != 0: n = 0 axis = 0 while shrink_axis_mask: @@ -446,7 +439,6 @@ def fixup_stridedslice_output(tens, arch, nng): assert len(tens.shape) == (len(op.inputs[0].shape) - n) op.attrs["shrink_axis_mask"] = 0 - elif new_axis_mask != 0: n = 0 axis = 0 @@ -1092,7 +1084,7 @@ def optimise_graph_a(nng, arch, verbose_graph=False): for idx, sg in enumerate(nng.subgraphs): # rewrite graph pass nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( - nng, sg, arch, [fixup_stridedslice_output], op_rewrite_list, rewrite_unsupported=False, + nng, sg, arch, [], op_rewrite_list, rewrite_unsupported=False, ) for idx, sg in enumerate(nng.subgraphs): @@ -1113,7 +1105,7 @@ def optimise_graph_b(nng, arch, verbose_graph=False): for idx, sg in enumerate(nng.subgraphs): # combined rewrite graph pass nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( - nng, sg, arch, [fixup_unpack_output, rewrite_concat, rewrite_split], [] + nng, sg, arch, [fixup_unpack_output, fixup_stridedslice_output, rewrite_concat, rewrite_split], [] ) if verbose_graph: diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py index 04cda1da..3e649e09 100644 --- a/ethosu/vela/supported_operators.py +++ b/ethosu/vela/supported_operators.py @@ -219,7 +219,6 @@ class SupportedOperators: # StridedSlice specific checks: self.specific_constraints[Op.StridedSlice].append(SupportedOperators.constraint_stridedslice_input_count) self.specific_constraints[Op.StridedSlice].append(SupportedOperators.constraint_stridedslice_inputs_const) - self.specific_constraints[Op.StridedSlice].append(SupportedOperators.constraint_stridedslice_tens_size_matches) self.specific_constraints[Op.StridedSlice].append(SupportedOperators.constraint_stridedslice_stride_values) self.specific_constraints[Op.StridedSlice].append(SupportedOperators.constraint_ellipsis_mask) self.specific_constraints[Op.StridedSlice].append(SupportedOperators.constraint_axis_masks) @@ -728,22 +727,6 @@ class SupportedOperators: return valid, f"Op has non-constant tensors: {extra}" @staticmethod - def constraint_stridedslice_tens_size_matches(op): - "All Input sizes must match OFM size" - ifm, begin, end, strides = op.inputs - ifm_size = len(ifm.shape) - ofm_size = len(op.ofm.shape) - begin_size = len(begin.values) - end_size = len(end.values) - strides_size = len(strides.values) - valid = ifm_size == ofm_size == begin_size == end_size == strides_size - extra = ( - f"Op has ofm_size={ofm_size}, ifm_size={ifm_size}," - f" begin_size={begin_size}, end_size={end_size} and strides_size={strides_size}" - ) - return valid, extra - - @staticmethod def constraint_stridedslice_stride_values(op): "All Strides values must be 1" strides = op.inputs[3] diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py index 595ea590..245ebcf9 100644 --- a/ethosu/vela/test/test_supported_operators.py +++ b/ethosu/vela/test/test_supported_operators.py @@ -486,12 +486,6 @@ def test_constraint_stridedslice_inputs_const(): assert not support.is_operator_supported(op) -def test_constraint_stridedslice_tens_size_matches(): - op = create_strided_slice() - op.inputs[1].values = [1, 1, 1, 1, 1, 1, 1, 1] - assert not support.is_operator_supported(op) - - def test_constraint_stridedslice_stride_values(): # Unsupported strides op = create_strided_slice() |