Diffstat (limited to 'ethosu/vela/nn_graph.py')
-rw-r--r--  ethosu/vela/nn_graph.py  548
1 file changed, 548 insertions, 0 deletions
diff --git a/ethosu/vela/nn_graph.py b/ethosu/vela/nn_graph.py
new file mode 100644
index 00000000..8d335bd8
--- /dev/null
+++ b/ethosu/vela/nn_graph.py
@@ -0,0 +1,548 @@
+# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# Description:
+# Neural network graph classes and enums.
+# Pass - A packed pass containing one or more Operations.
+# CascadedPass - A scheduled pass containing one or more Passes, as well as a scheduling strategy and block
+# configurations.
+# Subgraph - Holds a neural network subgraph, pointing at Tensors, Operations, Passes, and CascadedPasses.
+# Graph - A full neural network graph with one or more Subgraphs.
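+#
+# A rough containment/usage sketch (illustrative only; the object and variable
+# names below are made up, but the classes are the ones defined in this file):
+#
+#   nng = Graph("my_network", batch_size=1)
+#   sg = Subgraph("main", PassPlacement.Npu)
+#   nng.subgraphs.append(sg)
+#   # after packing/scheduling, sg.passes holds Pass objects and
+#   # sg.cascaded_passes holds CascadedPass objects
+#   nng.print_graph()  # prints each operator's type and name, per subgraph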
+
+import enum
+from .data_type import BaseType, DataType
+from .tensor import MemArea, TensorPurpose, TensorSubPurpose, TensorFormat, Tensor
+from .operation import Operation, NpuBlockType
+
+
+class PassPlacement(enum.Enum):
+ Unknown = 0
+ Cpu = 1
+ Npu = 2
+ MemoryOnly = 3
+ StartupInit = 4
+
+
+class TensorAllocator(enum.Enum):
+ LinearAlloc = 1
+ Greedy = 2
+
+ def __str__(self):
+ return self.name
+
+
+class Pass:
+ def __init__(self, name, placement, is_element_wise, npu_block_type):
+ self.inputs = []
+ self.intermediates = []
+ self.outputs = []
+ self.ops = []
+ self.primary_op = None
+ self.ifm_tensor = None
+ self.ifm2_tensor = None
+ self.ofm_tensor = None
+ self.weight_tensor = None
+ self.scale_tensor = None
+ self.name = name
+ self.cascade = None
+ self.placement = placement
+
+ # TODO: rename is_element_wise because it is not the same as an ElementWise operator. It is used by the tensor
+ # allocation and requires that the OFM and IFM have the exact same address, i.e. a complete overlap.
+ self.is_element_wise = is_element_wise
+ self.npu_block_type = npu_block_type
+ self.block_config = None # will be filled in by scheduler
+ self.shared_buffer = None # will be filled in by scheduler
+
+ self.predecessors = []
+ self.successors = []
+
+ def __str__(self):
+ return "<nng.Pass '%s', %s, ops=%s>" % (self.name, self.placement, [op.type for op in self.ops])
+
+ __repr__ = __str__
+
+ def get_primary_op_ifm_weights(self):
+ if not self.primary_op:
+ return None, None
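+ # get_ifm_ifm2_weights_ofm() returns (ifm, ifm2, weights, ofm); the [::2] slice keeps (ifm, weights)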
+ return self.primary_op.get_ifm_ifm2_weights_ofm()[::2]
+
+ def get_primary_op_ifm_ifm2_weights_ofm(self):
+ if not self.primary_op:
+ return None, None, None, None
+ return self.primary_op.get_ifm_ifm2_weights_ofm()
+
+ def get_primary_op_ifm_weights_biases_ofm(self):
+ if not self.primary_op:
+ return None, None, None, None
+ return self.primary_op.get_ifm_weights_biases_ofm()
+
+
+class SchedulingStrategy(enum.Enum):
+ Unknown = -1
+ IfmStream = 0
+ WeightStream = 1
+
+
+class SchedulerRewrite(enum.Enum):
+ Nop = 0
+ ChangeTensorSubPurpose = 1
+
+
+class CascadedPass:
+ def __init__(self, name, strat, inputs, intermediates, outputs, passes, placement, is_element_wise):
+ self.name = name
+ self.strategy = strat
+ self.inputs = inputs
+ self.intermediates = intermediates
+ self.outputs = outputs
+ self.passes = passes
+ self.placement = placement
+ self.is_element_wise = is_element_wise
+
+ self.predecessors = []
+ self.successors = []
+
+ def __str__(self):
+ return "<nng.CascadedPass strategy=%s x %s '%s', passes=%s, block_configs=%s>" % (
+ self.strategy,
+ len(self.passes),
+ self.name,
+ [ps.name for ps in self.passes],
+ [ps.block_config for ps in self.passes],
+ )
+
+ __repr__ = __str__
+
+
+class Subgraph:
+ def __init__(self, name="<unnamed>", placement=PassPlacement.Cpu):
+ self.output_tensors = []
+ self.input_tensors = []
+ self.original_inputs = [] # Preserve the original input order
+ self.passes = []
+ self.cascaded_passes = []
+ self.name = name
+ self.high_level_command_stream = []
+ self.placement = placement
+ self.command_stream_tensor = None
+ self.flash_tensor = None
+
+ self.memory_used = {}
+
+ def __str__(self):
+ return "<nng.Subgraph '%s', n_passes=%d, n_cascaded_passes=%d>" % (
+ self.name,
+ len(self.passes),
+ len(self.cascaded_passes),
+ )
+
+ __repr__ = __str__
+
+ def update_consumers(self):
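+ # Rebuild every reachable tensor's consumer_list and recompute self.input_tensors,
+ # first walking backwards from the output tensors, then over all ops in all passes.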
+ visit_op_set = set()
+ visit_tensor_set = set()
+ self.input_tensors = []
+
+ print_visit = False
+
+ def visit_op(op):
+ if op in visit_op_set:
+ return
+
+ visit_op_set.add(op)
+ for inp in op.inputs:
+ if print_visit:
+ print(inp, "adding consumer", op)
+ visit_tensor(inp)
+ inp.consumer_list.append(op)
+
+ if op.type in set(("Placeholder", "SubgraphInput")):
+ assert len(op.outputs) == 1
+ self.input_tensors.append(op.outputs[0])
+
+ for out in op.outputs:
+ if out not in visit_tensor_set:
+ out.consumer_list = [] # reset unvisited output, just in case
+
+ def visit_tensor(tens):
+ if tens in visit_tensor_set:
+ return
+ visit_tensor_set.add(tens)
+ tens.consumer_list = []
+ for op in tens.ops:
+ visit_op(op)
+
+ for ps in self.passes:
+ for tens in ps.outputs + ps.inputs:
+ tens.consumer_list = [] # reset unvisited tensors to start with
+
+ for tens in self.output_tensors:
+ visit_tensor(tens)
+ tens.consumer_list.append(None) # special op to indicate that the graph consumes the result
+
+ print_visit = True
+ for ps in self.passes:
+ for op in ps.ops:
+ visit_op(op)
+ for tens in ps.inputs:
+ visit_tensor(tens)
+
+ def build_pass_links(self):
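+ # Give every pass an even time stamp and link it to the passes that produce its
+ # inputs (predecessors) and to the passes that consume its outputs (successors).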
+ for idx, ps in enumerate(self.passes):
+ ps.time = 2 * idx
+ ps.predecessors = []
+ ps.successors = []
+
+ for ps in self.passes:
+ for tens in ps.inputs:
+ for op in tens.ops:
+ pred_pass = op.scheduled_pass
+ assert pred_pass.time < ps.time
+ if ps not in pred_pass.successors:
+ pred_pass.successors.append(ps)
+
+ if pred_pass not in ps.predecessors:
+ ps.predecessors.append(pred_pass)
+
+ assert tens in pred_pass.outputs
+
+ def build_pass_dag_predecessors(self):
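+ # Depth-first search starting from the final passes (those with no successors); a
+ # predecessor that is still being visited closes a cycle, so that link is dropped
+ # and dag_predecessors always describes an acyclic graph.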
+ for ps in self.passes:
+ ps.dag_predecessors = []
+
+ class State(enum.Enum):
+ NotVisited = 0
+ BeingVisited = 1
+ Visited = 2
+
+ pass_visit_dict = {}
+
+ def visit_pass(ps):
+ state = pass_visit_dict.get(ps, State.NotVisited)
+ if state == State.Visited:
+ return True
+ elif state == State.BeingVisited:
+ return False # this is a loop, need to remove this link
+ elif state == State.NotVisited:
+ pass_visit_dict[ps] = State.BeingVisited
+
+ ps.dag_predecessors = []
+ for pred in ps.predecessors:
+ if visit_pass(pred):
+ ps.dag_predecessors.append(pred)
+
+ pass_visit_dict[ps] = State.Visited
+ return True
+
+ for ps in self.passes:
+ if not ps.successors:
+ visit_pass(ps)
+
+ def build_cascaded_pass_links(self):
+ for cps in self.cascaded_passes:
+ cps.predecessors = []
+ cps.successors = []
+
+ for cps in self.cascaded_passes:
+ for tens in cps.inputs:
+ for op in tens.ops:
+ pred_cpass = op.scheduled_pass.cascade
+ if cps not in pred_cpass.successors:
+ pred_cpass.successors.append(cps)
+
+ if pred_cpass not in cps.predecessors:
+ cps.predecessors.append(pred_cpass)
+
+ assert tens in pred_cpass.outputs
+
+ def refresh_after_modification(self):
+ self.update_consumers()
+
+ def prune_startup_init_pass(self):
+ assert len(self.passes) >= 1
+ ps = self.passes[0]
+ assert ps.placement == PassPlacement.StartupInit
+
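+ # Drop startup-init outputs that are never consumed, and the ops that produce them.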
+ ps.outputs = [out_tens for out_tens in ps.outputs if len(out_tens.consumers()) > 0]
+ ps.ops = [op for op in ps.ops if op.outputs[0] in ps.outputs]
+
+ def get_all_ops(self):
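+ # Collect every op reachable from the output tensors, depth-first, so that each op
+ # is appended after all of its producers.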
+ all_ops = []
+ visit_op_set = set()
+ visit_tensor_set = set()
+
+ def visit_op(op):
+ if op in visit_op_set:
+ return
+ visit_op_set.add(op)
+ for inp in op.inputs:
+ visit_tensor(inp)
+
+ all_ops.append(op)
+
+ def visit_tensor(tens):
+ if tens in visit_tensor_set:
+ return
+ visit_tensor_set.add(tens)
+ for op in tens.ops:
+ visit_op(op)
+
+ for tens in self.output_tensors:
+ visit_tensor(tens)
+
+ return all_ops
+
+ def print_operators(self):
+ all_ops = self.get_all_ops()
+ unique_ops = []
+ print("print_operators")
+ for op in all_ops:
+ if op.type in set(("Const", "Identity", "Placeholder")):
+ continue
+
+ attrs = op.attrs
+ if (
+ op.type == "Conv2D"
+ or op.type == "DepthwiseConv2dNative"
+ or op.type == "Conv2DBiasAct"
+ or op.type == "DepthwiseConv2dBiasAct"
+ ):
+ kshape = op.inputs[1].shape
+ attrs["kshape"] = [kshape[0], kshape[1]]
+ attrs["type"] = op.type
+ attrs.pop("use_cudnn_on_gpu", None)
+ if attrs not in unique_ops:
+ unique_ops.append(attrs)
+ # print attributes in human readable format
+ a = attrs.copy()
+ s = a.pop("type")
+ data_format = a.pop("data_format", None)
+ if data_format and data_format != b"NHWC":
+ s += " " + str(data_format)
+ t = a.pop("T", None)
+ if t:
+ s += " " + str(t)[9:-2]
+ srct = a.pop("SrcT", None)
+ if srct:
+ s += " " + str(srct)[9:-2]
+ dstt = a.pop("DstT", None)
+ if dstt:
+ s += "->" + str(dstt)[9:-2]
+ print(s + " " + str(a))
+
+ def print_graph(self):
+ all_ops = self.get_all_ops()
+ for idx, op in enumerate(all_ops):
+ print(idx, op.type, op.name)
+
+ def print_graph_with_tensors(self):
+ all_ops = self.get_all_ops()
+ for idx, op in enumerate(all_ops):
+ print(idx, op.type, op.name)
+ for idx, tens in enumerate(op.inputs):
+ print(" Input %02d %20s %20s %s" % (idx, tens.purpose.name, tens.mem_area.name, tens))
+ for idx, tens in enumerate(op.outputs):
+ print(" Output %02d %20s %20s %s" % (idx, tens.purpose.name, tens.mem_area.name, tens))
+ print()
+
+ def print_graph_with_tensor_quantization(self):
+ all_ops = self.get_all_ops()
+ for idx, op in enumerate(all_ops):
+ print(idx, op.type, op.name)
+ for idx, tens in enumerate(op.inputs):
+ q = tens.quantization
+ if q is None:
+ print(" Input %02d %10s NO QUANTIZATION INFO %s" % (idx, tens.dtype, tens.name))
+ else:
+ print(
+ " Input %02d %10s min=%s max=%s scale=%s zero_point=%s %s"
+ % (idx, tens.dtype, q.min, q.max, q.scale_f32, q.zero_point, tens.name)
+ )
+ for idx, tens in enumerate(op.outputs):
+ q = tens.quantization
+ if q is None:
+ print(" Output %02d %10s NO QUANTIZATION INFO %s" % (idx, tens.dtype, tens.name))
+ else:
+ print(
+ " Output %02d %10s min=%s max=%s scale=%s zero_point=%s %s"
+ % (idx, tens.dtype, q.min, q.max, q.scale_f32, q.zero_point, tens.name)
+ )
+ print()
+
+ def print_passes(self):
+ for idx, ps in enumerate(self.passes):
+ print("%03d %s" % (idx * 2, ps))
+
+ def print_passes_with_tensors(self):
+ for idx, ps in enumerate(self.passes):
+ print("%3d %s" % (idx * 2, ps))
+ for idx, tens in enumerate(ps.inputs):
+ print(
+ " Input %2d %-15s %-15s %-15s %s"
+ % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
+ )
+ for idx, tens in enumerate(ps.intermediates):
+ print(
+ " Intermediate %2d %-15s %-15s %-15s %s"
+ % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
+ )
+ for idx, tens in enumerate(ps.outputs):
+ print(
+ " Output %2d %-15s %-15s %-15s %s"
+ % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
+ )
+ print()
+
+ def print_cascaded_passes(self):
+ for idx, ps in enumerate(self.cascaded_passes):
+ print("%3d %s SRAM used %.1f KB" % (idx * 2, ps, ps.sram_used / 1024))
+
+ def print_cascaded_passes_with_tensors(self):
+ for idx, ps in enumerate(self.cascaded_passes):
+ print("%3d %s SRAM used %.1f KB" % (idx * 2, ps, ps.sram_used / 1024))
+ for idx, tens in enumerate(ps.inputs):
+ print(
+ " Input %2d %-15s %-15s %-15s %s"
+ % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
+ )
+ for idx, tens in enumerate(ps.intermediates):
+ print(
+ " Intermediate %2d %-15s %-15s %-15s %s"
+ % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
+ )
+ for idx, tens in enumerate(ps.outputs):
+ print(
+ " Output %2d %-15s %-15s %-15s %s"
+ % (idx, tens.purpose.name, tens.mem_area.name, tens.format.name, tens.name)
+ )
+ print()
+
+ def print_cascaded_passes_with_tensor_sizes(self):
+ for idx, ps in enumerate(self.cascaded_passes):
+ print("%3d %s SRAM used %.1f KB" % (idx * 2, ps, ps.sram_used / 1024))
+ for idx, tens in enumerate(ps.inputs):
+ print(
+ " Input %2d %7.1f KB %-24s %-15s %-15s %-20s %s"
+ % (
+ idx,
+ tens.storage_size() / 1024,
+ tens.storage_shape,
+ tens.mem_area.name,
+ tens.purpose.name,
+ tens.format.name,
+ tens.name,
+ )
+ )
+ for idx, tens in enumerate(ps.intermediates):
+ print(
+ " Intermediate %2d %7.1f KB %-24s %-15s %-15s %-20s %s"
+ % (
+ idx,
+ tens.storage_size() / 1024,
+ tens.storage_shape,
+ tens.mem_area.name,
+ tens.purpose.name,
+ tens.format.name,
+ tens.name,
+ )
+ )
+ for idx, tens in enumerate(ps.outputs):
+ print(
+ " Output %2d %7.1f KB %-24s %-15s %-15s %-20s %s"
+ % (
+ idx,
+ tens.storage_size() / 1024,
+ tens.storage_shape,
+ tens.mem_area.name,
+ tens.purpose.name,
+ tens.format.name,
+ tens.name,
+ )
+ )
+ print()
+
+ def print_high_level_command_stream(self):
+ for idx, cmd in enumerate(self.high_level_command_stream):
+ print("%3d %s" % (idx, cmd))
+
+
+class Graph:
+ def __init__(self, name="<unnamed>", batch_size=1):
+ self.name = name
+ self.batch_size = batch_size
+ self.subgraphs = []
+
+ self.memory_used = {}
+ self.bits_per_element = {}
+ self.total_size = {}
+ self.total_elements = {}
+
+ def get_root_subgraph(self):
+ return self.subgraphs[0]
+
+ def prune_startup_init_pass(self):
+ for sg in self.subgraphs:
+ sg.prune_startup_init_pass()
+
+ def update_consumers(self):
+ for sg in self.subgraphs:
+ sg.update_consumers()
+
+ def refresh_after_modification(self):
+ for sg in self.subgraphs:
+ sg.refresh_after_modification()
+
+ def print_operators(self):
+ for sg in self.subgraphs:
+ sg.print_operators()
+
+ def print_graph(self):
+ for sg in self.subgraphs:
+ sg.print_graph()
+
+ def print_graph_with_tensors(self):
+ for sg in self.subgraphs:
+ sg.print_graph_with_tensors()
+
+ def print_graph_with_tensor_quantization(self):
+ for sg in self.subgraphs:
+ sg.print_graph_with_tensor_quantization()
+
+ def print_passes(self):
+ for sg in self.subgraphs:
+ sg.print_passes()
+
+ def print_passes_with_tensors(self):
+ for sg in self.subgraphs:
+ sg.print_passes_with_tensors()
+
+ def print_cascaded_passes(self):
+ for sg in self.subgraphs:
+ sg.print_cascaded_passes()
+
+ def print_cascaded_passes_with_tensors(self):
+ for sg in self.subgraphs:
+ sg.print_cascaded_passes_with_tensors()
+
+ def print_cascaded_passes_with_tensor_sizes(self):
+ for sg in self.subgraphs:
+ sg.print_cascaded_passes_with_tensor_sizes()
+
+ def print_high_level_command_stream(self):
+ for sg in self.subgraphs:
+ sg.print_high_level_command_stream()