From 224e99bd70a443e345d3ea454aedc51bf46cf261 Mon Sep 17 00:00:00 2001
From: Patrik Gustavsson
Date: Thu, 14 Jan 2021 10:55:43 +0100
Subject: MLBEDSW-3654 Fix for split/concat ops

Fix for split/concat ops
- set correct ifm_shapes in pass packing

Signed-off-by: Patrik Gustavsson
Change-Id: I7373b1743e4511b6c1dfaa398b927fbb1b454f60
---
 ethosu/vela/graph_optimiser.py | 6 ++++--
 ethosu/vela/pass_packing.py    | 4 +++-
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 6d2696c4..511ac954 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -106,7 +106,7 @@ def rewrite_concat(tens, arch, nng):
 
 
 def rewrite_split(tens, arch, nng):
-    if len(tens.ops) == 1 and tens.ops[0].type.is_split_op():
+    if len(tens.ops) == 1 and tens.ops[0].type.is_split_op() and tens.ops[0].type != Op.Unpack:
         split_op = tens.ops[0]
 
         # Not supported so leave it and run on CPU
@@ -125,6 +125,7 @@ def rewrite_split(tens, arch, nng):
         # Get the start and end of the split
         offset_start = [0] * 4
         for idx, out in enumerate(outputs):
+            split_op.ofm_shapes[idx] = Shape4D(out.shape)
             if out == tens:
                 break
             if axis >= 0:
@@ -143,7 +144,8 @@ def rewrite_split(tens, arch, nng):
         new_op.attrs["split_start"] = offset_start
         new_op.run_on_npu = True
         new_op.set_output_tensor(tens)
-        new_op.set_ifm_ofm_shapes()
+        new_op.ifm_shapes.append(Shape4D(inp.shape))
+        new_op.ofm_shapes.append(Shape4D(full_shape(4, tens.shape, 1)))
         DebugDatabase.add_optimised(split_op, new_op)
 
     return tens
diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py
index 7abf3b24..ee0d7128 100644
--- a/ethosu/vela/pass_packing.py
+++ b/ethosu/vela/pass_packing.py
@@ -244,6 +244,7 @@ def pack_into_passes(nng, arch, verbose_packing=False):
         input_set = set()
         ifm_tensor = None
         primary_op = None
+        ifm_shapes = None
 
         to_process = collections.deque()
         for start_op in start_ops_to_process:
@@ -279,6 +280,7 @@ def pack_into_passes(nng, arch, verbose_packing=False):
                 ):
                     assert len(curr_op.inputs) >= 1
                     ifm_tensor = curr_op.ifm
+                    ifm_shapes = curr_op.ifm_shapes.copy()
                     assert ifm_tensor is not None, "IFM missing in {}".format(curr_op)
                     assert ifm_tensor.purpose == TensorPurpose.FeatureMap
 
@@ -417,7 +419,7 @@ def pack_into_passes(nng, arch, verbose_packing=False):
         ps.ifm_tensor = ifm_tensor
         ps.ifm2_tensor = None
         if ps.primary_op is not None and ps.primary_op.run_on_npu:
-            ps.ifm_shapes.append(ps.primary_op.ifm_shapes[0])
+            ps.ifm_shapes.append(ifm_shapes[0])
 
         ps.ofm_tensor = ofm_tensor
         ps.ofm_shapes.append(ofm_shape)
-- 
cgit v1.2.1