diff options
Diffstat (limited to 'ethosu/vela/graph_optimiser.py')
-rw-r--r-- | ethosu/vela/graph_optimiser.py | 16 |
1 files changed, 4 insertions, 12 deletions
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py index 32f97d2f..e31348b5 100644 --- a/ethosu/vela/graph_optimiser.py +++ b/ethosu/vela/graph_optimiser.py @@ -422,19 +422,12 @@ def unfuse_activation_function(op, arch, nng): def fixup_stridedslice_output(tens, arch, nng): op = tens.ops[0] - if op.type == Op.StridedSlice: + if op.run_on_npu and op.type == Op.StridedSlice: reshape_input_shape = tens.shape new_axis_mask = op.attrs["new_axis_mask"] shrink_axis_mask = op.attrs["shrink_axis_mask"] - ellipsis_mask = op.attrs["ellipsis_mask"] - if (new_axis_mask != 0 and shrink_axis_mask != 0) or ellipsis_mask != 0: - # Not supported, will be put on CPU - return tens - if shrink_axis_mask == 0 and new_axis_mask == 0: - # Equal Rank StridedSlice, no need to insert reshape - return tens - elif shrink_axis_mask != 0: + if shrink_axis_mask != 0: n = 0 axis = 0 while shrink_axis_mask: @@ -446,7 +439,6 @@ def fixup_stridedslice_output(tens, arch, nng): assert len(tens.shape) == (len(op.inputs[0].shape) - n) op.attrs["shrink_axis_mask"] = 0 - elif new_axis_mask != 0: n = 0 axis = 0 @@ -1092,7 +1084,7 @@ def optimise_graph_a(nng, arch, verbose_graph=False): for idx, sg in enumerate(nng.subgraphs): # rewrite graph pass nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( - nng, sg, arch, [fixup_stridedslice_output], op_rewrite_list, rewrite_unsupported=False, + nng, sg, arch, [], op_rewrite_list, rewrite_unsupported=False, ) for idx, sg in enumerate(nng.subgraphs): @@ -1113,7 +1105,7 @@ def optimise_graph_b(nng, arch, verbose_graph=False): for idx, sg in enumerate(nng.subgraphs): # combined rewrite graph pass nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( - nng, sg, arch, [fixup_unpack_output, rewrite_concat, rewrite_split], [] + nng, sg, arch, [fixup_unpack_output, fixup_stridedslice_output, rewrite_concat, rewrite_split], [] ) if verbose_graph: |