diff options
author | Henrik G Olsson <henrik.olsson@arm.com> | 2021-04-12 14:53:18 +0200 |
---|---|---|
committer | patrik.gustavsson <patrik.gustavsson@arm.com> | 2021-04-16 15:37:05 +0000 |
commit | a93f02629f9d69257f2b0010a7d6db12f965e889 (patch) | |
tree | a00d406f149e459536a2934fcc41cb9a55f1bfbe /ethosu | |
parent | e22f96b9d1a1a87398b901d34cd52f11a5647f96 (diff) | |
download | ethos-u-vela-a93f02629f9d69257f2b0010a7d6db12f965e889.tar.gz |
MLBEDSW-4132 Fix off-by-one error for negative packing axis
Also applies to unpack.
Signed-off-by: Henrik G Olsson <henrik.olsson@arm.com>
Change-Id: I07e7083aeb6aefd6e26f9d134b858080f28f1719
Diffstat (limited to 'ethosu')
-rw-r--r-- | ethosu/vela/graph_optimiser.py | 20 |
1 files changed, 8 insertions, 12 deletions
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py index f59b685a..642f1349 100644 --- a/ethosu/vela/graph_optimiser.py +++ b/ethosu/vela/graph_optimiser.py @@ -95,12 +95,12 @@ def rewrite_concat_ops(op, arch): if op.type == Op.Pack: # Pack is also referred to as Stack axis = int(op.attrs["axis"]) + if axis < 0: # Convert to positive axis + axis = len(op.inputs[0].shape) + 1 + axis + desired_shape = op.inputs[0].shape[:axis] + [1] + op.inputs[0].shape[axis:] - if axis >= 0: - axis_4D = axis + (4 - len(desired_shape)) - else: - axis_4D = axis + axis_4D = axis + (4 - len(desired_shape)) for idx, inp in enumerate(op.inputs): op.ifm_shapes[idx] = Shape4D(desired_shape) @@ -651,20 +651,16 @@ def rewrite_unpack_output(op, arch, nng): if op.run_on_npu and op.type == Op.Unpack: # Unpack is also referred to as Unstack axis = int(op.attrs["axis"]) + if axis < 0: # Convert to positive axis + axis = len(op.inputs[0].shape) + 1 + axis op.type = Op.UnpackReshaped desired_output_shape = tens.shape[:axis] + [1] + tens.shape[axis:] - if axis >= 0: - axis_4D = axis + (4 - len(desired_output_shape)) - else: - axis_4D = axis + axis_4D = axis + (4 - len(desired_output_shape)) + op.attrs["split_axis_4D"] = [axis_4D] * len(op.outputs) - axis_4D_list = [0] * len(op.outputs) for idx, out_tens in enumerate(op.outputs): op.ofm_shapes[idx] = Shape4D(desired_output_shape) - axis_4D_list[idx] = axis_4D - - op.attrs["split_axis_4D"] = axis_4D_list return op |