aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFredrik Svedberg <fredrik.svedberg@arm.com>2021-10-27 13:58:03 +0200
committerFredrik Svedberg <fredrik.svedberg@arm.com>2021-10-27 13:58:03 +0200
commit0ae2848580d376c66d6a3c3b0fb55b30a0234247 (patch)
treee36aacb086a50d61a3c413faef3e83f78938ed27
parent2b5939f639d3ceb9fcf75c2edc78d128676119b3 (diff)
downloadethos-u-vela-0ae2848580d376c66d6a3c3b0fb55b30a0234247.tar.gz
MLBEDSW-5450 MLCE: Vela to handle skip Tensor
Added checks to avoid merging elementwise op live ranges for subgraph inputs and outputs, which sometimes caused problems when parts of the network run on CPU. Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com> Change-Id: Id07ab277a205b8550d19a276559f8904b9a4b4be
-rw-r--r--ethosu/vela/live_range.py34
-rw-r--r--ethosu/vela/scheduler.py4
-rw-r--r--ethosu/vela/tensor_allocation.py8
3 files changed, 15 insertions, 31 deletions
diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py
index 0b94b19..5aec0df 100644
--- a/ethosu/vela/live_range.py
+++ b/ethosu/vela/live_range.py
@@ -112,7 +112,6 @@ class LiveRangeGraph:
def __init__(self):
self.lrs: List[LiveRange] = [] # List of all created ranges
self.ranges = {} # tens -> range
- self.ignore_tensors = set()
self.processed_subgraphs = set()
self.current_time = 0
self.end_time = None
@@ -151,17 +150,17 @@ class LiveRangeGraph:
return usage
-def tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set):
+def tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
if tens.mem_area != target_mem_area or tens.mem_type not in target_mem_type_set:
return True
- if tens in lr_graph.ignore_tensors:
- return True
return False
-def merge_elementwise_op_ranges(sched_op, lr_graph, target_mem_area, target_mem_type_set):
+def merge_elementwise_op_ranges(sg, sched_op, lr_graph, target_mem_area, target_mem_type_set):
def _tensor_should_be_ignored(tens):
- return tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set)
+ if tens in sg.input_tensors + sg.output_tensors:
+ return True
+ return tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set)
# Tries to merge ifm/ofm live ranges of elementwise op
if sched_op.op_type.is_elementwise_op():
@@ -198,12 +197,7 @@ def merge_elementwise_op_ranges(sched_op, lr_graph, target_mem_area, target_mem_
def extract_live_ranges_from_cascaded_passes(
- sg,
- target_mem_area,
- target_mem_type_set,
- ignore_subgraph_input_output_tensors=False,
- lr_graph=None,
- cpu_tensor_alignment=Tensor.AllocationQuantum,
+ sg, target_mem_area, target_mem_type_set, lr_graph=None, cpu_tensor_alignment=Tensor.AllocationQuantum,
):
if lr_graph is None:
lr_graph = LiveRangeGraph()
@@ -212,17 +206,13 @@ def extract_live_ranges_from_cascaded_passes(
# if subgraph has been processed already, return the lr_graph as is
return lr_graph
- if ignore_subgraph_input_output_tensors:
- lr_graph.ignore_tensors.update(sg.input_tensors)
- lr_graph.ignore_tensors.update(sg.output_tensors)
-
for cps in sg.cascaded_passes:
cps.time = lr_graph.current_time
time_for_pass = cps.time
for tens in cps.inputs:
- if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set):
+ if tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
continue
rng = lr_graph.get_or_create_range(tens, cpu_tensor_alignment)
rng.mark_usage(time_for_pass)
@@ -244,7 +234,7 @@ def extract_live_ranges_from_cascaded_passes(
cps.time = time_for_pass
for tens in cps.intermediates + cps.outputs:
- if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set):
+ if tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
continue
rng = lr_graph.get_or_create_range(tens, cpu_tensor_alignment)
rng.mark_usage(time_for_pass)
@@ -257,7 +247,7 @@ def extract_live_ranges_from_cascaded_passes(
end_time = max(end_time, rng.end_time)
for tens in sg.output_tensors:
- if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set):
+ if tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
continue
rng = lr_graph.get_or_create_range(tens, cpu_tensor_alignment)
rng.mark_usage(end_time)
@@ -273,7 +263,7 @@ def create_linear_live_range_graph(sg, target_mem_area, target_mem_type_set, lr_
for ps in sg.passes:
for tens in ps.inputs + ps.outputs + ps.intermediates:
if tens.purpose == TensorPurpose.Weights or tensor_should_be_ignored(
- lr_graph, tens, target_mem_area, target_mem_type_set
+ tens, target_mem_area, target_mem_type_set
):
continue
rng = lr_graph.get_or_create_range(tens)
@@ -281,7 +271,7 @@ def create_linear_live_range_graph(sg, target_mem_area, target_mem_type_set, lr_
for _, op_info in sg.schedule.cost_map.items():
for tensor in [op_info.npu_weights_tensor, op_info.npu_scales_tensor]:
- if tensor and not (tensor_should_be_ignored(lr_graph, tensor, target_mem_area, target_mem_type_set)):
+ if tensor and not (tensor_should_be_ignored(tensor, target_mem_area, target_mem_type_set)):
rng = lr_graph.get_or_create_range(tensor)
rng.mark_usage(sg_time)
@@ -292,7 +282,7 @@ def create_linear_live_range_graph(sg, target_mem_area, target_mem_type_set, lr_
def _extract_live_ranges_from_schedule(sg, target_mem_area, target_mem_type_set, lr_graph):
time_for_cascade = {}
for sched_op in sg.sched_ops:
- merge_elementwise_op_ranges(sched_op, lr_graph, target_mem_area, target_mem_type_set)
+ merge_elementwise_op_ranges(sg, sched_op, lr_graph, target_mem_area, target_mem_type_set)
op_info = sg.schedule.cost_map[sched_op]
cascade = op_info.cascade
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index 2ac4787..782e8d9 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -425,7 +425,7 @@ class Scheduler:
lr_graph = live_range.LiveRangeGraph()
for mem_area, mem_type_set in memories_list:
live_range.extract_live_ranges_from_cascaded_passes(
- self.nng.get_root_subgraph(), mem_area, mem_type_set, False, lr_graph, Tensor.AllocationQuantum,
+ self.nng.get_root_subgraph(), mem_area, mem_type_set, lr_graph, Tensor.AllocationQuantum,
)
# Populate time-array with memory used by live ranges
@@ -918,7 +918,7 @@ class Scheduler:
lr_graph = live_range.LiveRangeGraph()
for mem_area, mem_type_set in memories_list:
live_range.extract_live_ranges_from_cascaded_passes(
- self.nng.get_root_subgraph(), mem_area, mem_type_set, False, lr_graph, Tensor.AllocationQuantum,
+ self.nng.get_root_subgraph(), mem_area, mem_type_set, lr_graph, Tensor.AllocationQuantum,
)
# Iterate over live ranges and evict tensors that doesn't fit
diff --git a/ethosu/vela/tensor_allocation.py b/ethosu/vela/tensor_allocation.py
index 4b5e5e4..db261ae 100644
--- a/ethosu/vela/tensor_allocation.py
+++ b/ethosu/vela/tensor_allocation.py
@@ -182,14 +182,8 @@ def allocate(
cpu_tensor_alignment=Tensor.AllocationQuantum,
):
# Allocates addresses to tensors, returns False if tensors could not be fit within max_size
- ignore_subgraph_input_output_tensors = False
lrs = live_range.extract_live_ranges_from_cascaded_passes(
- sg,
- mem_area,
- mem_type_set,
- ignore_subgraph_input_output_tensors=ignore_subgraph_input_output_tensors,
- lr_graph=lr_graph,
- cpu_tensor_alignment=cpu_tensor_alignment,
+ sg, mem_area, mem_type_set, lr_graph=lr_graph, cpu_tensor_alignment=cpu_tensor_alignment,
)
total_sz = 0
if lrs.ranges: