aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/extract_npu_subgraphs.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/extract_npu_subgraphs.py')
-rw-r--r--ethosu/vela/extract_npu_subgraphs.py253
1 files changed, 253 insertions, 0 deletions
diff --git a/ethosu/vela/extract_npu_subgraphs.py b/ethosu/vela/extract_npu_subgraphs.py
new file mode 100644
index 00000000..5b9ba8b0
--- /dev/null
+++ b/ethosu/vela/extract_npu_subgraphs.py
@@ -0,0 +1,253 @@
+# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# Description:
+# Vela separates CPU operations and NPU operations into separate internal subgraphs. The CPU operations are left
+# untouched in the final output.
+#
+# Vela does this by identifying NPU passes and pulling them out from the main CPU graph into separate subgraphs, invoked
+# by NpuOp operations. Later, Vela generates command streams and compressed weight streams for the NPU subgraphs and
+# attaches them to the NpuOp. This encapsulates everything the NPU subgraph is supposed to do.
+
+from .nn_graph import Pass, PassPlacement, NpuBlockType, Subgraph
+from .operation import Operation
+import numpy as np
+
+
+def make_npu_call_op_pass(npu_subgraph):
+ op = Operation("NpuOp", "call_" + npu_subgraph.name)
+ op.attrs["subgraph"] = npu_subgraph
+ ps = Pass(op.name, PassPlacement.MemoryOnly, False, NpuBlockType.Default)
+ ps.ops = [op]
+ ps.primary_op = op
+ op.attrs["npu_block_type"] = ps.npu_block_type
+ op.scheduled_pass = ps
+
+ # Inputs and outputs filled in later as we cut the graphs
+ return ps
+
+
+def switch_tensor_for_op(op, orig_tens, new_tens):
+
+ op.inputs = [new_tens if tens == orig_tens else tens for tens in op.inputs]
+ op.outputs = [new_tens if tens == orig_tens else tens for tens in op.outputs]
+
+ ps = op.scheduled_pass
+ if ps is None:
+ return
+
+ ps.inputs = [new_tens if tens == orig_tens else tens for tens in ps.inputs]
+ ps.outputs = [new_tens if tens == orig_tens else tens for tens in ps.outputs]
+
+ if ps.ifm_tensor == orig_tens:
+ ps.ifm_tensor = new_tens
+ if ps.ifm2_tensor == orig_tens:
+ ps.ifm2_tensor = new_tens
+ if ps.ofm_tensor == orig_tens:
+ ps.ofm_tensor = new_tens
+ if ps.weight_tensor == orig_tens:
+ ps.weight_tensor = new_tens
+ if ps.scale_tensor == orig_tens:
+ ps.scale_tensor = new_tens
+
+
+def rewrite_tensor_cpu_producer_npu_consumers(
+ orig_tens, call_ps, startup_init_ps, npu_subgraph, cpu_subgraph, subgraph_for_pass
+):
+ is_const = orig_tens.ops[0].type == "Const"
+
+ new_tens = orig_tens.clone("_npu")
+ orig_tens.npu_tensor = new_tens
+ new_tens.cpu_tensor = orig_tens
+
+ op_type = "SubgraphInput"
+ if is_const:
+ op_type = "Const"
+ op = Operation(op_type, orig_tens.name + "_input")
+ op.attrs["npu_block_type"] = NpuBlockType.Default
+ op.outputs = [new_tens]
+ op.scheduled_pass = startup_init_ps
+ new_tens.ops = [op]
+ startup_init_ps.ops.append(op)
+ startup_init_ps.outputs.append(new_tens)
+
+ if not is_const:
+ call_ps.inputs.append(orig_tens)
+ call_ps.primary_op.inputs.append(orig_tens)
+
+ for op in list(orig_tens.consumers()):
+ if op is None:
+ continue # Subgraph consumers handled separately.
+ ps = op.scheduled_pass
+ if subgraph_for_pass[ps] == npu_subgraph:
+ switch_tensor_for_op(op, orig_tens, new_tens)
+ orig_tens.consumer_list.remove(op)
+ new_tens.consumer_list.append(op)
+
+ # Deal with output tensors for the NPU graph. These are special.
+ npu_subgraph.output_tensors = [new_tens if tens == orig_tens else tens for tens in npu_subgraph.output_tensors]
+
+
+def rewrite_tensor_npu_producer_cpu_consumers(
+ orig_tens, call_ps, startup_init_ps, npu_subgraph, cpu_subgraph, subgraph_for_pass
+):
+
+ new_tens = orig_tens.clone("_cpu")
+ new_tens.npu_tensor = orig_tens
+ orig_tens.cpu_tensor = new_tens
+
+ npu_subgraph.output_tensors.append(orig_tens)
+
+ call_ps.outputs.append(new_tens)
+ call_ps.primary_op.outputs.append(new_tens)
+ new_tens.ops = [call_ps.primary_op]
+
+ for op in list(orig_tens.consumers()):
+ if op is None:
+ continue # Subgraph consumers handled separately.
+ ps = op.scheduled_pass
+ if subgraph_for_pass[ps] != npu_subgraph:
+ switch_tensor_for_op(op, orig_tens, new_tens)
+ orig_tens.consumer_list.remove(op)
+ new_tens.consumer_list.append(op)
+
+ # Deal with output tensors for the CPU graph. These are special.
+ cpu_subgraph.output_tensors = [new_tens if tens == orig_tens else tens for tens in cpu_subgraph.output_tensors]
+
+
+def extract_subgraph(nng, orig_sg, arch):
+ assert orig_sg.placement == PassPlacement.Cpu
+
+ passes = list(orig_sg.passes)
+ place_vec = np.array([ps.placement for ps in passes])
+ place_vec[
+ place_vec == PassPlacement.StartupInit
+ ] = PassPlacement.Cpu # Keep the startup init pass on the CPU, we'll make new ones to move onto NPU.
+
+ # MemoryOnly passes that are either squeezed between NPU passes or on the boundary of NPU and CPU
+ # passes should be assigned to the NPU.
+
+ # Forward, then backwards
+ for is_reversed in range(2):
+ last_place = PassPlacement.Cpu
+ seq = enumerate(place_vec)
+ if is_reversed:
+ seq = reversed(list(seq))
+ for idx, place in seq:
+ if place == PassPlacement.MemoryOnly:
+ if last_place == PassPlacement.Npu:
+ place = PassPlacement.Npu
+ place_vec[idx] = place
+
+ if place != PassPlacement.MemoryOnly:
+ last_place = place
+
+ # Anything left, assign to the CPU.
+ place_vec[place_vec == PassPlacement.MemoryOnly] = PassPlacement.Cpu
+
+ if np.all(place_vec == PassPlacement.Cpu):
+ return [] # Nothing to do
+
+ # Create the subgraphs and split passes between them
+
+ new_subgraphs = []
+ split_count = 0
+ subgraph_for_pass = {}
+ orig_sg.passes = []
+ call_pass = {}
+ startup_init_passes = {}
+
+ last_place = PassPlacement.Cpu
+ curr_sg = orig_sg
+
+ for idx, place in enumerate(place_vec):
+ if place != last_place:
+ if place == PassPlacement.Npu:
+ split_count += 1
+ curr_sg = Subgraph("%s_split_%d" % (orig_sg.name, split_count), PassPlacement.Npu)
+ new_subgraphs.append(curr_sg)
+ call_ps = make_npu_call_op_pass(curr_sg)
+ subgraph_for_pass[call_ps] = orig_sg
+ orig_sg.passes.append(call_ps)
+ call_pass[curr_sg] = call_ps
+
+ startup_init_ps = Pass(
+ curr_sg.name + "_startup_init", PassPlacement.StartupInit, False, NpuBlockType.Default
+ )
+ curr_sg.passes.append(startup_init_ps)
+ startup_init_passes[curr_sg] = startup_init_ps
+ subgraph_for_pass[startup_init_ps] = curr_sg
+
+ else:
+ curr_sg = orig_sg
+ last_place = place
+ ps = passes[idx]
+ subgraph_for_pass[ps] = curr_sg
+ curr_sg.passes.append(ps)
+
+ # Rewrite tensors to fix up graphs.
+
+ for curr_sg in new_subgraphs:
+ for ps in curr_sg.passes:
+ for tens in ps.inputs:
+ source_sgs = [subgraph_for_pass[op.scheduled_pass] for op in tens.ops]
+ assert len(source_sgs) >= 0
+ producer_sg = source_sgs[0]
+ for sg in source_sgs:
+ assert sg == producer_sg # All need to be the same.
+
+ if producer_sg != curr_sg:
+ assert (
+ producer_sg == orig_sg
+ ) # Because we go in-order, all the producers must be the original graph.
+ rewrite_tensor_cpu_producer_npu_consumers(
+ tens, call_pass[curr_sg], startup_init_passes[curr_sg], curr_sg, orig_sg, subgraph_for_pass
+ )
+
+ for tens in ps.outputs:
+
+ dest_sgs = [subgraph_for_pass[op.scheduled_pass] for op in tens.consumers() if op is not None]
+ need_rewrite = False
+ for sg in dest_sgs:
+ if sg != curr_sg:
+ need_rewrite = True
+ break
+ if tens in orig_sg.output_tensors:
+ need_rewrite = True
+
+ if need_rewrite:
+ rewrite_tensor_npu_producer_cpu_consumers(
+ tens, call_pass[curr_sg], startup_init_passes[curr_sg], curr_sg, orig_sg, subgraph_for_pass
+ )
+
+ return new_subgraphs
+
+
+def extract_npu_subgraphs(nng, arch):
+
+ nng.refresh_after_modification()
+
+ for sg in list(nng.subgraphs):
+ if sg.placement == PassPlacement.Cpu:
+ new_subgraphs = extract_subgraph(nng, sg, arch)
+ nng.subgraphs += new_subgraphs
+
+ nng.refresh_after_modification()
+ nng.prune_startup_init_pass()
+
+ for sg in nng.subgraphs:
+ sg.build_pass_links()