aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/graph_optimiser.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/graph_optimiser.py')
-rw-r--r--ethosu/vela/graph_optimiser.py11
1 files changed, 9 insertions, 2 deletions
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 0754f7e1..c3216785 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -73,10 +73,14 @@ def rewrite_concat(tens, arch, nng):
tens.ops = []
offset = 0
for idx, inp in enumerate(inputs):
+ if axis >= 0:
+ axis_4D = axis + (4 - len(inp.shape))
+ else:
+ axis_4D = axis
new_op = Operation(Op.ConcatSliceWrite, concat_op.name + str(idx))
new_op.inputs = [inp]
new_op.outputs = [tens]
- new_op.attrs["concat_axis"] = axis + (4 - len(inp.shape))
+ new_op.attrs["concat_axis"] = axis_4D
new_op.attrs["concat_start"] = offset
offset += inp.shape[axis]
new_op.attrs["concat_end"] = offset
@@ -122,7 +126,10 @@ def rewrite_split(tens, arch, nng):
for idx, out in enumerate(outputs):
if out == tens:
break
- axis_4D = axis + (4 - len(out.shape))
+ if axis >= 0:
+ axis_4D = axis + (4 - len(out.shape))
+ else:
+ axis_4D = axis
offset_start[axis_4D] += split_op.ofm_shapes[idx].get_dim(axis_4D)