diff options
author | Patrik Gustavsson <patrik.gustavsson@arm.com> | 2020-12-22 10:40:51 +0100 |
---|---|---|
committer | patrik.gustavsson <patrik.gustavsson@arm.com> | 2020-12-22 14:25:05 +0000 |
commit | 3d73717f793100ba6705441fb42514f938780c1e (patch) | |
tree | 95e6072f8673655fb6f8774dfb70541e34967f54 /ethosu | |
parent | cc6915ce439a985e5f36e5ebc317c2a3f8bf9ce3 (diff) | |
download | ethos-u-vela-3d73717f793100ba6705441fb42514f938780c1e.tar.gz |
MLBEDSW-3791 Fix converting axis to 4D axis
Fix converting axis to 4D axis.
Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: I83501494738f402b374efd8a369e5001f17b8152
Diffstat (limited to 'ethosu')
-rw-r--r-- | ethosu/vela/graph_optimiser.py | 11 |
1 files changed, 9 insertions, 2 deletions
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py index 0754f7e1..c3216785 100644 --- a/ethosu/vela/graph_optimiser.py +++ b/ethosu/vela/graph_optimiser.py @@ -73,10 +73,14 @@ def rewrite_concat(tens, arch, nng): tens.ops = [] offset = 0 for idx, inp in enumerate(inputs): + if axis >= 0: + axis_4D = axis + (4 - len(inp.shape)) + else: + axis_4D = axis new_op = Operation(Op.ConcatSliceWrite, concat_op.name + str(idx)) new_op.inputs = [inp] new_op.outputs = [tens] - new_op.attrs["concat_axis"] = axis + (4 - len(inp.shape)) + new_op.attrs["concat_axis"] = axis_4D new_op.attrs["concat_start"] = offset offset += inp.shape[axis] new_op.attrs["concat_end"] = offset @@ -122,7 +126,10 @@ def rewrite_split(tens, arch, nng): for idx, out in enumerate(outputs): if out == tens: break - axis_4D = axis + (4 - len(out.shape)) + if axis >= 0: + axis_4D = axis + (4 - len(out.shape)) + else: + axis_4D = axis offset_start[axis_4D] += split_op.ofm_shapes[idx].get_dim(axis_4D) |