about | summary | refs | log | tree | commit | diff
path: root/ethosu/vela/live_range.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/live_range.py')
-rw-r--r-- ethosu/vela/live_range.py | 27
1 file changed, 5 insertions, 22 deletions
diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py
index b884035..a29cafe 100644
--- a/ethosu/vela/live_range.py
+++ b/ethosu/vela/live_range.py
@@ -16,7 +16,6 @@
# Description:
# Build a live range graph for tensors in one or more subgraphs. Used for tensor allocation as well as in the scheduler.
# Can work with either a pass packed subgraph or a scheduled subgraph.
-from .high_level_command_stream_generator import calc_allowed_ofm_ifm_overlap_for_cascaded_pass
from .nn_graph import PassPlacement
from .operation import Op
from .tensor import MemType
@@ -101,7 +100,6 @@ class LiveRange:
class LiveRangeGraph:
def __init__(self):
self.ranges = {} # tens -> range
- self.allowed_overlaps = {} # (tens,tens) -> overlap_int
self.ignore_tensors = set()
self.processed_subgraphs = set()
self.current_time = 0
@@ -198,7 +196,7 @@ def merge_elementwise_op_ranges(ps, lr_graph, target_mem_area, target_mem_type_s
def extract_live_ranges_from_passes(
sg,
target_mem_area,
- target_mem_type=set((MemType.Scratch, MemType.Scratch_fast)),
+ target_mem_type_set=set((MemType.Scratch, MemType.Scratch_fast)),
ignore_subgraph_input_output_tensors=False,
):
lr_graph = LiveRangeGraph()
@@ -209,7 +207,7 @@ def extract_live_ranges_from_passes(
# Try to merge live ranges of operations in the NPU subgraphs
if sg.placement == PassPlacement.Npu:
- merge_op_ranges(sg, lr_graph, target_mem_area, target_mem_type)
+ merge_op_ranges(sg, lr_graph, target_mem_area, target_mem_type_set)
for idx, ps in enumerate(sg.passes):
ps.time = 2 * idx
@@ -217,14 +215,14 @@ def extract_live_ranges_from_passes(
time_for_pass = ps.time
for tens in ps.inputs + ps.intermediates + ps.outputs:
- if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type):
+ if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set):
continue
rng = lr_graph.get_or_create_range(tens)
rng.mark_usage(time_for_pass)
end_time = len(sg.passes) * 2
for tens in sg.output_tensors:
- if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type):
+ if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set):
continue
rng = lr_graph.get_or_create_range(tens)
rng.mark_usage(end_time)
@@ -236,7 +234,6 @@ def extract_live_ranges_from_cascaded_passes(
sg,
target_mem_area,
target_mem_type_set,
- use_ifm_ofm_overlap=True,
ignore_subgraph_input_output_tensors=False,
lr_graph=None,
allocation_alignment=Tensor.AllocationQuantum,
@@ -279,7 +276,7 @@ def extract_live_ranges_from_cascaded_passes(
# Use default allocation alignment of 16 for Npu tensors
npu_sg = cps_primary_op.attrs["subgraph"]
lr_graph = extract_live_ranges_from_cascaded_passes(
- npu_sg, target_mem_area, target_mem_type_set, use_ifm_ofm_overlap, False, lr_graph,
+ npu_sg, target_mem_area, target_mem_type_set, False, lr_graph,
)
# Set the new time after handling the Npu subgraph
time_for_pass = lr_graph.current_time
@@ -291,20 +288,6 @@ def extract_live_ranges_from_cascaded_passes(
rng = lr_graph.get_or_create_range(tens, allocation_alignment)
rng.mark_usage(time_for_pass)
- if use_ifm_ofm_overlap:
- # fill allowed overlap for ifm and ofm tensor
- ifm_tensor = cps.passes[0].ifm_tensor
- ofm_tensor = cps.passes[-1].ofm_tensor
- if (
- ifm_tensor is not None
- and ofm_tensor is not None
- and not tensor_should_be_ignored(lr_graph, ifm_tensor, target_mem_area, target_mem_type_set)
- and not tensor_should_be_ignored(lr_graph, ofm_tensor, target_mem_area, target_mem_type_set)
- ):
- lr_graph.allowed_overlaps[(ifm_tensor, ofm_tensor)] = calc_allowed_ofm_ifm_overlap_for_cascaded_pass(
- cps
- )
-
lr_graph.current_time += 2
end_time = 0