diff options
Diffstat (limited to 'ethosu/vela/graph_optimiser.py')
-rw-r--r-- | ethosu/vela/graph_optimiser.py | 6 |
1 files changed, 6 insertions, 0 deletions
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py index f4472f9e..d2598aec 100644 --- a/ethosu/vela/graph_optimiser.py +++ b/ethosu/vela/graph_optimiser.py @@ -141,6 +141,7 @@ def rewrite_split_ops(tens, arch, nng): new_op = Operation(Op.SplitSliceRead, split_op.name) new_op.inputs = [inp] ofm_shape_idx = 0 + read_shape = offset_end # For Split the offset cannot be extracted from the tensor so it has to # be calculated from the index of the output tensor @@ -160,11 +161,13 @@ def rewrite_split_ops(tens, arch, nng): if out == tens: ofm_shape_idx = idx + read_shape = split_op.ofm_shapes[idx] break offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D] new_op.read_offsets[0] = Shape4D.from_list(offset_start, 0) + new_op.read_shapes[0] = read_shape new_op.run_on_npu = True new_op.set_output_tensor(tens) new_op.ifm_shapes.append(Shape4D(inp.shape)) @@ -189,10 +192,12 @@ def remove_SplitSliceRead(op, arch): cons_op = op.ofm.consumer_list[0] if cons_op.ifm == op.ofm: cons_op.read_offsets[0] = op.read_offsets[0] + cons_op.read_shapes[0] = op.read_shapes[0] cons_op.set_input_tensor(op.ifm, cons_op.type.info.indices.ifms[0]) cons_op.ifm_shapes[0] = op.ifm_shapes[0] elif cons_op.type.is_binary_elementwise_op() and cons_op.ifm2 == op.ofm: cons_op.read_offsets[1] = op.read_offsets[0] + cons_op.read_shapes[1] = op.read_shapes[0] cons_op.set_input_tensor(op.ifm, cons_op.type.info.indices.ifms[1]) cons_op.ifm_shapes[1] = op.ifm_shapes[0] @@ -212,6 +217,7 @@ def remove_SplitSliceRead(op, arch): avgpool_op.ifm_shapes.append(op.ifm_shapes[0]) avgpool_op.ofm_shapes.append(op.ofm_shapes[0]) avgpool_op.read_offsets[0] = op.read_offsets[0] + avgpool_op.read_shapes[0] = op.read_shapes[0] op.ifm.consumer_list.remove(op) DebugDatabase.add_optimised(op, avgpool_op) |