aboutsummaryrefslogtreecommitdiff
path: root/ethosu
diff options
context:
space:
mode:
authorPatrik Gustavsson <patrik.gustavsson@arm.com>2020-12-22 10:40:51 +0100
committerpatrik.gustavsson <patrik.gustavsson@arm.com>2020-12-22 14:25:05 +0000
commit3d73717f793100ba6705441fb42514f938780c1e (patch)
tree95e6072f8673655fb6f8774dfb70541e34967f54 /ethosu
parentcc6915ce439a985e5f36e5ebc317c2a3f8bf9ce3 (diff)
downloadethos-u-vela-3d73717f793100ba6705441fb42514f938780c1e.tar.gz
MLBEDSW-3791 Fix converting axis to 4D axis
Fix converting axis to 4D axis. Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com> Change-Id: I83501494738f402b374efd8a369e5001f17b8152
Diffstat (limited to 'ethosu')
-rw-r--r--ethosu/vela/graph_optimiser.py11
1 files changed, 9 insertions, 2 deletions
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 0754f7e1..c3216785 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -73,10 +73,14 @@ def rewrite_concat(tens, arch, nng):
tens.ops = []
offset = 0
for idx, inp in enumerate(inputs):
+ if axis >= 0:
+ axis_4D = axis + (4 - len(inp.shape))
+ else:
+ axis_4D = axis
new_op = Operation(Op.ConcatSliceWrite, concat_op.name + str(idx))
new_op.inputs = [inp]
new_op.outputs = [tens]
- new_op.attrs["concat_axis"] = axis + (4 - len(inp.shape))
+ new_op.attrs["concat_axis"] = axis_4D
new_op.attrs["concat_start"] = offset
offset += inp.shape[axis]
new_op.attrs["concat_end"] = offset
@@ -122,7 +126,10 @@ def rewrite_split(tens, arch, nng):
for idx, out in enumerate(outputs):
if out == tens:
break
- axis_4D = axis + (4 - len(out.shape))
+ if axis >= 0:
+ axis_4D = axis + (4 - len(out.shape))
+ else:
+ axis_4D = axis
offset_start[axis_4D] += split_op.ofm_shapes[idx].get_dim(axis_4D)