author      Tim Hall <tim.hall@arm.com>    2021-05-27 18:49:40 +0100
committer   Tim Hall <tim.hall@arm.com>    2021-05-27 18:57:39 +0100
commit      d8339a75c9b655c0507e34238078fdad068b4023 (patch)
tree        36a14726b30760169a83c0356803b480992fade8 /ethosu/vela/live_range.py
parent      64556f32ff7bfca6036a6598034464b13b64a4ef (diff)
download    ethos-u-vela-d8339a75c9b655c0507e34238078fdad068b4023.tar.gz
MLBEDSW-4034: New Scheduler Size or Performance Optimisation
- Merged dev/scheduler at 83639f90e8c828f70de6e29142355a940224959b

Signed-off-by: Tim Hall <tim.hall@arm.com>
Change-Id: I0050529d4b42da93768c7264296434dd877fb5b4
Diffstat (limited to 'ethosu/vela/live_range.py')
-rw-r--r--    ethosu/vela/live_range.py    125
1 file changed, 118 insertions(+), 7 deletions(-)
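
For orientation before reading the diff: the central behavioural change is that LiveRange.mark_usage() now accepts an op_length, so a single live range can span several consecutive time slots (the patch uses this to mark double-buffered weight tensors one slot early). The standalone sketch below mirrors that logic; SketchLiveRange and its sentinel start/end values are illustrative assumptions, not code from vela:

# Minimal sketch of the extended mark_usage() semantics introduced below.
class SketchLiveRange:
    def __init__(self, size):
        self.size = size
        self.start_time = 1 << 62  # "not yet used" sentinel (assumed)
        self.end_time = -1

    def mark_usage(self, op_time, op_length=1):
        # A usage can now cover op_length consecutive time slots; negative
        # start times clamp to 0 and an empty interval is ignored, exactly
        # as in the hunk below.
        op_time_start = max(op_time, 0)
        op_time_end = op_time + op_length
        if op_time_end <= op_time_start:
            return
        self.start_time = min(self.start_time, op_time_start)
        self.end_time = max(self.end_time, op_time_end)

# A pre-buffered weight tensor scheduled at time slot 4 is marked one slot
# early over two slots, mirroring rng.mark_usage(time_to_set - 1, 2):
rng = SketchLiveRange(size=1024)
rng.mark_usage(3, 2)
print(rng.start_time, rng.end_time)  # -> 3 5
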
diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py
index de001e56..d75a167d 100644
--- a/ethosu/vela/live_range.py
+++ b/ethosu/vela/live_range.py
@@ -18,10 +18,14 @@
# Can work with either a pass packed subgraph or a scheduled subgraph.
from typing import List
+import numpy as np
+
from .nn_graph import PassPlacement
from .operation import Op
+from .tensor import MemArea
from .tensor import MemType
from .tensor import Tensor
+from .tensor import TensorPurpose
class LiveRange:
@@ -32,6 +36,7 @@ class LiveRange:
self.size = 0
self.name = ""
self.alignment = alignment
+ self.mem_area = tens.mem_area if tens else MemArea.Unknown
if tens:
self.add_tensor(tens)
@@ -52,15 +57,19 @@ class LiveRange:
self.tensors.append(tens)
- def mark_usage(self, op_time):
- if op_time == -1:
+ def mark_usage(self, op_time, op_length=1):
+ op_time_start = max(op_time, 0)
+ op_time_end = op_time + op_length
+ if op_time_end <= op_time_start:
return
- op_time_start = op_time
- op_time_end = op_time + 1
self.start_time = min(self.start_time, op_time_start)
self.end_time = max(self.end_time, op_time_end)
+ def set_buffer_size(self, buffer_size):
+ self.size = buffer_size
+ self.mem_area = MemArea.Sram
+
def overlaps_ranges(self, other):
return max(self.start_time, other.start_time) < min(self.end_time, other.end_time)
@@ -106,6 +115,7 @@ class LiveRangeGraph:
self.ignore_tensors = set()
self.processed_subgraphs = set()
self.current_time = 0
+ self.end_time = None
def get_or_create_range(self, tens, alignment=Tensor.AllocationQuantum):
# Return the live range of the tensor (or any of its clones)
@@ -127,6 +137,23 @@ class LiveRangeGraph:
self.ranges[out_tens] = live_range
return live_range
+ def update_endtime(self):
+ self.end_time = 0
+ for rng in self.ranges.values():
+ self.end_time = max(self.end_time, rng.end_time)
+ return self.end_time + 1
+
+ def get_temporal_memory_usage(self, target_mem_area):
+ if not self.end_time:
+ self.update_endtime()
+ usage = np.zeros(self.end_time, dtype=np.int32)
+ for rng in self.ranges.values():
+ if rng.mem_area == target_mem_area:
+ # End time is inclusive
+ usage[rng.start_time : rng.end_time + 1] += rng.size
+
+ return usage
+
def tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set):
if tens.mem_area != target_mem_area or tens.mem_type not in target_mem_type_set:
@@ -279,9 +306,7 @@ def extract_live_ranges_from_cascaded_passes(
# is called. Go into said subgraph and extract live ranges before continuing.
# Use default allocation alignment of 16 for Npu tensors
npu_sg = cps_primary_op.attrs["subgraph"]
- lr_graph = extract_live_ranges_from_cascaded_passes(
- npu_sg, target_mem_area, target_mem_type_set, False, lr_graph,
- )
+ lr_graph = _extract_live_ranges_from_schedule(npu_sg, target_mem_area, target_mem_type_set, lr_graph)
# Set the new time after handling the Npu subgraph
time_for_pass = lr_graph.current_time
cps.time = time_for_pass
@@ -308,3 +333,89 @@ def extract_live_ranges_from_cascaded_passes(
# Add subgraph to set of processed subgraphs
lr_graph.processed_subgraphs.add(sg)
return lr_graph
+
+
+def create_linear_live_range_graph(sg, target_mem_area, target_mem_type_set, lr_graph):
+ assert lr_graph is not None
+ sg_time = lr_graph.current_time
+ for ps in sg.passes:
+ for tens in ps.inputs + ps.outputs + ps.intermediates:
+ if tens.purpose == TensorPurpose.Weights or tensor_should_be_ignored(
+ lr_graph, tens, target_mem_area, target_mem_type_set
+ ):
+ continue
+
+ rng = lr_graph.get_or_create_range(tens)
+ rng.mark_usage(sg_time)
+
+ for sched_op, op_info in sg.schedule.cost_map.items():
+ if op_info.npu_weights_tensor and not (
+ tensor_should_be_ignored(lr_graph, op_info.npu_weights_tensor, target_mem_area, target_mem_type_set)
+ ):
+ rng = lr_graph.get_or_create_range(op_info.npu_weights_tensor)
+ rng.mark_usage(sg_time)
+
+ lr_graph.current_time += 1
+ return lr_graph
+
+
+def _extract_live_ranges_from_schedule(sg, target_mem_area, target_mem_type_set, lr_graph):
+ time_for_cascade = {}
+ for sched_op in sg.sched_ops:
+ op_info = sg.schedule.cost_map[sched_op]
+ cascade = op_info.cascade
+ cascade_info = sg.schedule.cascades.get(cascade, None)
+
+ time_to_set = time_for_cascade.get(cascade, lr_graph.current_time)
+
+ op_info.time_index = time_to_set
+
+ # Mark usage for all tensors related to this Pass
+ ps = sched_op.parent_ps
+ for tens in ps.inputs + ps.outputs + ps.intermediates:
+ if (
+ target_mem_area == MemArea.Sram
+ and cascade_info
+ and tens == ps.ifm_tensor
+ and sched_op in cascade_info.buffers
+ ):
+ # This tensor is a rolling buffer in a cascade and the size of the LiveRange needs to be modified
+ # for enabling temporal memory snapshots without modifying the original Tensor
+ rng = lr_graph.get_or_create_range(tens)
+ rng.set_buffer_size(cascade_info.buffers[sched_op].elements() * sched_op.ifm.dtype.size_in_bytes())
+ elif (
+ tens.purpose == TensorPurpose.Weights
+ or tens.purpose == TensorPurpose.FSBias
+ or tens.mem_type not in target_mem_type_set
+ or tens.mem_area != target_mem_area
+ ):
+ continue
+
+ else:
+ rng = lr_graph.get_or_create_range(tens)
+
+ rng.mark_usage(time_to_set)
+
+ weight_tens = op_info.buffered_weight_tensor
+ if weight_tens and weight_tens.mem_type in target_mem_type_set and weight_tens.mem_area == target_mem_area:
+ rng = lr_graph.get_or_create_range(weight_tens)
+ if weight_tens.pre_buffer:
+ rng.mark_usage(time_to_set - 1, 2)
+ else:
+ rng.mark_usage(time_to_set)
+
+ if time_to_set == lr_graph.current_time:
+ lr_graph.current_time += 2
+
+ if cascade != 0:
+ time_for_cascade[cascade] = time_to_set
+
+ end_time = lr_graph.update_endtime()
+
+ for tens in sg.output_tensors:
+ if tens.mem_type not in target_mem_type_set or tens.mem_area != target_mem_area:
+ continue
+ rng = lr_graph.get_or_create_range(tens)
+ rng.mark_usage(end_time)
+
+ return lr_graph
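
As a worked example of the new per-time-slot bookkeeping: Rng and temporal_usage below are a hypothetical reduction of LiveRangeGraph.get_temporal_memory_usage(), keeping only the fields the query touches and assuming end_time is inclusive, as the in-code comment states:

import numpy as np

class Rng:
    # Hypothetical stand-in carrying just the fields the query reads.
    def __init__(self, start_time, end_time, size):
        self.start_time, self.end_time, self.size = start_time, end_time, size

def temporal_usage(ranges, end_time):
    # Mirrors get_temporal_memory_usage(): one int32 counter per time slot,
    # with each range's size added over its inclusive time span.
    usage = np.zeros(end_time, dtype=np.int32)
    for rng in ranges:
        usage[rng.start_time : rng.end_time + 1] += rng.size
    return usage

# Two overlapping SRAM ranges: slots [0, 2] and [2, 4], sizes 100 and 50.
print(temporal_usage([Rng(0, 2, 100), Rng(2, 4, 50)], end_time=6))
# -> [100 100 150  50  50   0]

np.amax() over such an array then gives the peak memory demand in that area across the whole schedule, which is presumably what the "temporal memory snapshots" mentioned in the rolling-buffer comment above are used for.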