From c6ac1944d9934faf6d22825cdd3273afe55432a4 Mon Sep 17 00:00:00 2001 From: Dwight Lidman Date: Fri, 2 Oct 2020 14:55:45 +0200 Subject: MLBEDSW-3004: UnpackReshaped can't be serialised This commit fixes a bug where a rewritten Unpack operator is placed on the CPU and crashes Vela during serialisation due to the type having changed and there not being a mapping for the modified op type. The solution is to move the fixup_unpack_output function to the graph optimisation pass B, allowing the supported op check to run before it. Signed-off-by: Dwight Lidman Change-Id: Ic6bd4c70a478fd61adf377cb487f5b9253130314 --- ethosu/vela/graph_optimiser.py | 8 +++++--- ethosu/vela/operation.py | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py index d4423524..f6209ed2 100644 --- a/ethosu/vela/graph_optimiser.py +++ b/ethosu/vela/graph_optimiser.py @@ -421,7 +421,7 @@ def unfuse_activation_function(op, arch, nng): def fixup_unpack_output(tens, arch, nng): op = tens.ops[0] - if op.type in set((Op.Unpack, Op.StridedSlice)): + if op.run_on_npu and op.type in set((Op.Unpack, Op.StridedSlice)): # Unpack is also referred to as Unstack # Requires the rewrite_split function to be called on the op afterwards @@ -1061,7 +1061,7 @@ def optimise_graph_a(nng, arch, verbose_graph=False): for idx, sg in enumerate(nng.subgraphs): # rewrite graph pass nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( - nng, sg, arch, [fixup_unpack_output], op_rewrite_list, rewrite_unsupported=False + nng, sg, arch, [], op_rewrite_list, rewrite_unsupported=False, ) for idx, sg in enumerate(nng.subgraphs): @@ -1081,7 +1081,9 @@ def optimise_graph_b(nng, arch, verbose_graph=False): for idx, sg in enumerate(nng.subgraphs): # combined rewrite graph pass - nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(nng, sg, arch, [rewrite_concat, rewrite_split], []) + nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( + nng, sg, arch, [fixup_unpack_output, rewrite_concat, rewrite_split], [] + ) if verbose_graph: nng.print_graph() diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py index 710511c6..6e5b4820 100644 --- a/ethosu/vela/operation.py +++ b/ethosu/vela/operation.py @@ -314,7 +314,7 @@ def get_slice_offsets(input_shape, offset_tens, offset_mask, is_begin=True): class Operation: """Class representing a Neural Network operation. Has a name, a type, -input and output tensors, as well as an attribute dictionary.""" + input and output tensors, as well as an attribute dictionary.""" __slots__ = ( "type", -- cgit v1.2.1