aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHenrik G Olsson <henrik.olsson@arm.com>2021-04-12 14:53:18 +0200
committerpatrik.gustavsson <patrik.gustavsson@arm.com>2021-04-16 15:37:05 +0000
commita93f02629f9d69257f2b0010a7d6db12f965e889 (patch)
treea00d406f149e459536a2934fcc41cb9a55f1bfbe
parente22f96b9d1a1a87398b901d34cd52f11a5647f96 (diff)
downloadethos-u-vela-a93f02629f9d69257f2b0010a7d6db12f965e889.tar.gz
MLBEDSW-4132 Fix off-by-one error for negative packing axis
Also applies to unpack. Signed-off-by: Henrik G Olsson <henrik.olsson@arm.com> Change-Id: I07e7083aeb6aefd6e26f9d134b858080f28f1719
-rw-r--r--ethosu/vela/graph_optimiser.py20
1 files changed, 8 insertions, 12 deletions
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index f59b685a..642f1349 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -95,12 +95,12 @@ def rewrite_concat_ops(op, arch):
if op.type == Op.Pack:
# Pack is also referred to as Stack
axis = int(op.attrs["axis"])
+ if axis < 0: # Convert to positive axis
+ axis = len(op.inputs[0].shape) + 1 + axis
+
desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:]
- if axis >= 0:
- axis_4D = axis + (4 - len(desired_shape))
- else:
- axis_4D = axis
+ axis_4D = axis + (4 - len(desired_shape))
for idx, inp in enumerate(op.inputs):
op.ifm_shapes[idx] = Shape4D(desired_shape)
@@ -651,20 +651,16 @@ def rewrite_unpack_output(op, arch, nng):
if op.run_on_npu and op.type == Op.Unpack:
# Unpack is also referred to as Unstack
axis = int(op.attrs["axis"])
+ if axis < 0: # Convert to positive axis
+ axis = len(op.inputs[0].shape) + 1 + axis
op.type = Op.UnpackReshaped
desired_output_shape = tens.shape[:axis] + [1] + tens.shape[axis:]
- if axis >= 0:
- axis_4D = axis + (4 - len(desired_output_shape))
- else:
- axis_4D = axis
+ axis_4D = axis + (4 - len(desired_output_shape))
+ op.attrs["split_axis_4D"] = [axis_4D] * len(op.outputs)
- axis_4D_list = [0] * len(op.outputs)
for idx, out_tens in enumerate(op.outputs):
op.ofm_shapes[idx] = Shape4D(desired_output_shape)
- axis_4D_list[idx] = axis_4D
-
- op.attrs["split_axis_4D"] = axis_4D_list
return op