aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/insert_dma.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/insert_dma.py')
-rw-r--r--ethosu/vela/insert_dma.py60
1 files changed, 60 insertions, 0 deletions
diff --git a/ethosu/vela/insert_dma.py b/ethosu/vela/insert_dma.py
new file mode 100644
index 00000000..b63c1ea1
--- /dev/null
+++ b/ethosu/vela/insert_dma.py
@@ -0,0 +1,60 @@
+# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# Description:
+# Insert DMA operations into the graph for transferring weights.
+
+from .nn_graph import Operation, MemArea, TensorPurpose, NpuBlockType
+from . import rewrite_graph
+
+
def insert_dma_cmd(op, arch):
    """Graph-rewrite pass: stage weight tensors into fast storage via DMA.

    For each weight input of `op` that lives in DRAM or off-chip flash (and
    not already in the architecture's fast storage), a "DMA" Operation is
    spliced in that copies the tensor into a fast-storage clone, and the
    clone replaces the original as the op's input.

    Returns `op` (possibly with rewritten inputs), as required by the
    rewrite_graph pre-order traversal.
    """
    if op.type == "DMA":
        # This op was produced by a previous application of this pass.
        return op
    for idx, tens in enumerate(op.inputs):
        in_slow_memory = tens.mem_area in (MemArea.Dram, MemArea.OffChipFlash)
        if not in_slow_memory or tens.mem_area == arch.fast_storage_mem_area:
            continue
        if tens.purpose != TensorPurpose.Weights:
            continue
        # Vector products read their weights only once, so streaming them
        # straight from flash/DRAM is fine and no DMA is needed. Any other
        # consumer re-reads the tensor, which is better served from SRAM.
        only_vector_product_consumers = all(
            consumer is not None
            and consumer.attrs.get("npu_block_type") == NpuBlockType.VectorProduct
            for consumer in tens.consumers()
        )
        if only_vector_product_consumers:
            continue
        # Create an SRAM-resident clone and a DMA op that fills it.
        new_tens = tens.clone_into_fast_storage(arch)
        dma_cmd = Operation("DMA", tens.ops[0].name + "_dma")
        dma_cmd.inputs = [tens]
        dma_cmd.outputs = [new_tens]
        dma_cmd.attrs["source"] = tens.mem_area
        dma_cmd.attrs["destination"] = new_tens.mem_area
        dma_cmd.run_on_npu = True
        new_tens.ops = [dma_cmd]
        op.inputs[idx] = new_tens
    return op
+
+
def insert_dma_commands(nng, arch, verbose_graph=False):
    """Apply the DMA-insertion rewrite to every subgraph of the network.

    Each subgraph is passed through rewrite_graph.rewrite_graph_pre_order
    with insert_dma_cmd as the sole op rewriter; the (possibly new) subgraph
    objects replace the originals in place. When `verbose_graph` is set the
    resulting graph is printed. Returns the updated `nng`.
    """
    rewritten = [
        rewrite_graph.rewrite_graph_pre_order(sg, arch, [], [insert_dma_cmd])
        for sg in nng.subgraphs
    ]
    # Slice-assign so the existing subgraph list is mutated in place.
    nng.subgraphs[:] = rewritten
    if verbose_graph:
        nng.print_graph()
    return nng