From 3d73717f793100ba6705441fb42514f938780c1e Mon Sep 17 00:00:00 2001 From: Patrik Gustavsson Date: Tue, 22 Dec 2020 10:40:51 +0100 Subject: MLBEDSW-3791 Fix converting axis to 4D axis Fix converting axis to 4D axis. Signed-off-by: Patrik Gustavsson Change-Id: I83501494738f402b374efd8a369e5001f17b8152 --- ethosu/vela/graph_optimiser.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) (limited to 'ethosu') diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py index 0754f7e1..c3216785 100644 --- a/ethosu/vela/graph_optimiser.py +++ b/ethosu/vela/graph_optimiser.py @@ -73,10 +73,14 @@ def rewrite_concat(tens, arch, nng): tens.ops = [] offset = 0 for idx, inp in enumerate(inputs): + if axis >= 0: + axis_4D = axis + (4 - len(inp.shape)) + else: + axis_4D = axis new_op = Operation(Op.ConcatSliceWrite, concat_op.name + str(idx)) new_op.inputs = [inp] new_op.outputs = [tens] - new_op.attrs["concat_axis"] = axis + (4 - len(inp.shape)) + new_op.attrs["concat_axis"] = axis_4D new_op.attrs["concat_start"] = offset offset += inp.shape[axis] new_op.attrs["concat_end"] = offset @@ -122,7 +126,10 @@ def rewrite_split(tens, arch, nng): for idx, out in enumerate(outputs): if out == tens: break - axis_4D = axis + (4 - len(out.shape)) + if axis >= 0: + axis_4D = axis + (4 - len(out.shape)) + else: + axis_4D = axis offset_start[axis_4D] += split_op.ofm_shapes[idx].get_dim(axis_4D) -- cgit v1.2.1