-rw-r--r--  ethosu/vela/live_range.py        | 34
-rw-r--r--  ethosu/vela/scheduler.py         |  4
-rw-r--r--  ethosu/vela/tensor_allocation.py |  8
3 files changed, 15 insertions, 31 deletions
diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py
index 0b94b197..5aec0dfb 100644
--- a/ethosu/vela/live_range.py
+++ b/ethosu/vela/live_range.py
@@ -112,7 +112,6 @@ class LiveRangeGraph:
def __init__(self):
self.lrs: List[LiveRange] = [] # List of all created ranges
self.ranges = {} # tens -> range
- self.ignore_tensors = set()
self.processed_subgraphs = set()
self.current_time = 0
self.end_time = None
@@ -151,17 +150,17 @@ class LiveRangeGraph:
return usage
-def tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set):
+def tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
if tens.mem_area != target_mem_area or tens.mem_type not in target_mem_type_set:
return True
- if tens in lr_graph.ignore_tensors:
- return True
return False
-def merge_elementwise_op_ranges(sched_op, lr_graph, target_mem_area, target_mem_type_set):
+def merge_elementwise_op_ranges(sg, sched_op, lr_graph, target_mem_area, target_mem_type_set):
def _tensor_should_be_ignored(tens):
- return tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set)
+ if tens in sg.input_tensors + sg.output_tensors:
+ return True
+ return tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set)
# Tries to merge ifm/ofm live ranges of elementwise op
if sched_op.op_type.is_elementwise_op():
@@ -198,12 +197,7 @@ def merge_elementwise_op_ranges(sched_op, lr_graph, target_mem_area, target_mem_
def extract_live_ranges_from_cascaded_passes(
- sg,
- target_mem_area,
- target_mem_type_set,
- ignore_subgraph_input_output_tensors=False,
- lr_graph=None,
- cpu_tensor_alignment=Tensor.AllocationQuantum,
+ sg, target_mem_area, target_mem_type_set, lr_graph=None, cpu_tensor_alignment=Tensor.AllocationQuantum,
):
if lr_graph is None:
lr_graph = LiveRangeGraph()
@@ -212,17 +206,13 @@ def extract_live_ranges_from_cascaded_passes(
# if subgraph has been processed already, return the lr_graph as is
return lr_graph
- if ignore_subgraph_input_output_tensors:
- lr_graph.ignore_tensors.update(sg.input_tensors)
- lr_graph.ignore_tensors.update(sg.output_tensors)
-
for cps in sg.cascaded_passes:
cps.time = lr_graph.current_time
time_for_pass = cps.time
for tens in cps.inputs:
- if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set):
+ if tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
continue
rng = lr_graph.get_or_create_range(tens, cpu_tensor_alignment)
rng.mark_usage(time_for_pass)
@@ -244,7 +234,7 @@ def extract_live_ranges_from_cascaded_passes(
cps.time = time_for_pass
for tens in cps.intermediates + cps.outputs:
- if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set):
+ if tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
continue
rng = lr_graph.get_or_create_range(tens, cpu_tensor_alignment)
rng.mark_usage(time_for_pass)
@@ -257,7 +247,7 @@ def extract_live_ranges_from_cascaded_passes(
end_time = max(end_time, rng.end_time)
for tens in sg.output_tensors:
- if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set):
+ if tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
continue
rng = lr_graph.get_or_create_range(tens, cpu_tensor_alignment)
rng.mark_usage(end_time)
@@ -273,7 +263,7 @@ def create_linear_live_range_graph(sg, target_mem_area, target_mem_type_set, lr_
for ps in sg.passes:
for tens in ps.inputs + ps.outputs + ps.intermediates:
if tens.purpose == TensorPurpose.Weights or tensor_should_be_ignored(
- lr_graph, tens, target_mem_area, target_mem_type_set
+ tens, target_mem_area, target_mem_type_set
):
continue
rng = lr_graph.get_or_create_range(tens)
@@ -281,7 +271,7 @@ def create_linear_live_range_graph(sg, target_mem_area, target_mem_type_set, lr_
for _, op_info in sg.schedule.cost_map.items():
for tensor in [op_info.npu_weights_tensor, op_info.npu_scales_tensor]:
- if tensor and not (tensor_should_be_ignored(lr_graph, tensor, target_mem_area, target_mem_type_set)):
+ if tensor and not (tensor_should_be_ignored(tensor, target_mem_area, target_mem_type_set)):
rng = lr_graph.get_or_create_range(tensor)
rng.mark_usage(sg_time)
@@ -292,7 +282,7 @@ def create_linear_live_range_graph(sg, target_mem_area, target_mem_type_set, lr_
def _extract_live_ranges_from_schedule(sg, target_mem_area, target_mem_type_set, lr_graph):
time_for_cascade = {}
for sched_op in sg.sched_ops:
- merge_elementwise_op_ranges(sched_op, lr_graph, target_mem_area, target_mem_type_set)
+ merge_elementwise_op_ranges(sg, sched_op, lr_graph, target_mem_area, target_mem_type_set)
op_info = sg.schedule.cost_map[sched_op]
cascade = op_info.cascade
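
The net effect in live_range.py is that the decision to skip a subgraph's input and output tensors now sits next to the elementwise-merge logic instead of in a LiveRangeGraph.ignore_tensors set. A minimal sketch of the resulting helper pair, reconstructed from the hunks above (the Tensor, subgraph and SchedulerOperation objects are assumed to behave as in the surrounding Vela code; the merge body itself is elided):

def tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
    # Skip tensors that live outside the memory area/type currently being allocated
    if tens.mem_area != target_mem_area or tens.mem_type not in target_mem_type_set:
        return True
    return False

def merge_elementwise_op_ranges(sg, sched_op, lr_graph, target_mem_area, target_mem_type_set):
    def _tensor_should_be_ignored(tens):
        # Subgraph inputs/outputs are never merged; this check replaces the old
        # LiveRangeGraph.ignore_tensors set
        if tens in sg.input_tensors + sg.output_tensors:
            return True
        return tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set)

    # Tries to merge ifm/ofm live ranges of elementwise op
    if sched_op.op_type.is_elementwise_op():
        ...  # merge logic is unchanged by this commit
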
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index 2ac47878..782e8d98 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -425,7 +425,7 @@ class Scheduler:
lr_graph = live_range.LiveRangeGraph()
for mem_area, mem_type_set in memories_list:
live_range.extract_live_ranges_from_cascaded_passes(
- self.nng.get_root_subgraph(), mem_area, mem_type_set, False, lr_graph, Tensor.AllocationQuantum,
+ self.nng.get_root_subgraph(), mem_area, mem_type_set, lr_graph, Tensor.AllocationQuantum,
)
# Populate time-array with memory used by live ranges
@@ -918,7 +918,7 @@ class Scheduler:
lr_graph = live_range.LiveRangeGraph()
for mem_area, mem_type_set in memories_list:
live_range.extract_live_ranges_from_cascaded_passes(
- self.nng.get_root_subgraph(), mem_area, mem_type_set, False, lr_graph, Tensor.AllocationQuantum,
+ self.nng.get_root_subgraph(), mem_area, mem_type_set, lr_graph, Tensor.AllocationQuantum,
)
# Iterate over live ranges and evict tensors that don't fit
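
Both Scheduler call sites simply drop the positional False that used to fill ignore_subgraph_input_output_tensors. An equivalent, more explicit form of the updated call would use keyword arguments throughout (a sketch only; memories_list, self.nng and lr_graph come from the surrounding Scheduler code shown above):

for mem_area, mem_type_set in memories_list:
    live_range.extract_live_ranges_from_cascaded_passes(
        self.nng.get_root_subgraph(),  # root subgraph of the network
        mem_area,                      # target memory area for this iteration
        mem_type_set,                  # memory types considered in this area
        lr_graph=lr_graph,             # accumulate ranges across iterations
        cpu_tensor_alignment=Tensor.AllocationQuantum,
    )
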
diff --git a/ethosu/vela/tensor_allocation.py b/ethosu/vela/tensor_allocation.py
index 4b5e5e42..db261ae5 100644
--- a/ethosu/vela/tensor_allocation.py
+++ b/ethosu/vela/tensor_allocation.py
@@ -182,14 +182,8 @@ def allocate(
cpu_tensor_alignment=Tensor.AllocationQuantum,
):
# Allocates addresses to tensors, returns False if tensors could not be fit within max_size
- ignore_subgraph_input_output_tensors = False
lrs = live_range.extract_live_ranges_from_cascaded_passes(
- sg,
- mem_area,
- mem_type_set,
- ignore_subgraph_input_output_tensors=ignore_subgraph_input_output_tensors,
- lr_graph=lr_graph,
- cpu_tensor_alignment=cpu_tensor_alignment,
+ sg, mem_area, mem_type_set, lr_graph=lr_graph, cpu_tensor_alignment=cpu_tensor_alignment,
)
total_sz = 0
if lrs.ranges:
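
Taken together, the change amounts to one signature simplification plus matching call-site updates. Before and after, as reconstructed from the hunks above:

# before this commit
def extract_live_ranges_from_cascaded_passes(
    sg,
    target_mem_area,
    target_mem_type_set,
    ignore_subgraph_input_output_tensors=False,
    lr_graph=None,
    cpu_tensor_alignment=Tensor.AllocationQuantum,
): ...

# after this commit
def extract_live_ranges_from_cascaded_passes(
    sg, target_mem_area, target_mem_type_set, lr_graph=None, cpu_tensor_alignment=Tensor.AllocationQuantum,
): ...

The two Scheduler call sites and tensor_allocation.allocate() switch to the shorter signature, while the elementwise merge now always leaves subgraph input and output tensors unmerged.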