diff options
Diffstat (limited to 'ethosu/vela/graph_optimiser.py')
-rw-r--r-- | ethosu/vela/graph_optimiser.py | 11 |
1 files changed, 9 insertions, 2 deletions
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py index 0754f7e1..c3216785 100644 --- a/ethosu/vela/graph_optimiser.py +++ b/ethosu/vela/graph_optimiser.py @@ -73,10 +73,14 @@ def rewrite_concat(tens, arch, nng): tens.ops = [] offset = 0 for idx, inp in enumerate(inputs): + if axis >= 0: + axis_4D = axis + (4 - len(inp.shape)) + else: + axis_4D = axis new_op = Operation(Op.ConcatSliceWrite, concat_op.name + str(idx)) new_op.inputs = [inp] new_op.outputs = [tens] - new_op.attrs["concat_axis"] = axis + (4 - len(inp.shape)) + new_op.attrs["concat_axis"] = axis_4D new_op.attrs["concat_start"] = offset offset += inp.shape[axis] new_op.attrs["concat_end"] = offset @@ -122,7 +126,10 @@ def rewrite_split(tens, arch, nng): for idx, out in enumerate(outputs): if out == tens: break - axis_4D = axis + (4 - len(out.shape)) + if axis >= 0: + axis_4D = axis + (4 - len(out.shape)) + else: + axis_4D = axis offset_start[axis_4D] += split_op.ofm_shapes[idx].get_dim(axis_4D) |