From fad90c2db9e1b3f19f3a3700b17cf69ed08aea04 Mon Sep 17 00:00:00 2001 From: Patrik Gustavsson Date: Tue, 3 Nov 2020 13:07:40 +0100 Subject: MLBEDSW-3212 Remove CLI opt ifm-ofm-overlap Removed the CLI opt ifm-ofm-overlap Signed-off-by: Patrik Gustavsson Change-Id: I23faa0d10c3e71972c543e22e8155086fce73556 --- ethosu/vela/live_range.py | 27 +++++---------------------- 1 file changed, 5 insertions(+), 22 deletions(-) (limited to 'ethosu/vela/live_range.py') diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py index b8840355..a29cafe0 100644 --- a/ethosu/vela/live_range.py +++ b/ethosu/vela/live_range.py @@ -16,7 +16,6 @@ # Description: # Build a live range graph for tensors in one or more subgraphs. Used for tensor allocation as well as in the scheduler. # Can work with either a pass packed subgraph or a scheduled subgraph. -from .high_level_command_stream_generator import calc_allowed_ofm_ifm_overlap_for_cascaded_pass from .nn_graph import PassPlacement from .operation import Op from .tensor import MemType @@ -101,7 +100,6 @@ class LiveRange: class LiveRangeGraph: def __init__(self): self.ranges = {} # tens -> range - self.allowed_overlaps = {} # (tens,tens) -> overlap_int self.ignore_tensors = set() self.processed_subgraphs = set() self.current_time = 0 @@ -198,7 +196,7 @@ def merge_elementwise_op_ranges(ps, lr_graph, target_mem_area, target_mem_type_s def extract_live_ranges_from_passes( sg, target_mem_area, - target_mem_type=set((MemType.Scratch, MemType.Scratch_fast)), + target_mem_type_set=set((MemType.Scratch, MemType.Scratch_fast)), ignore_subgraph_input_output_tensors=False, ): lr_graph = LiveRangeGraph() @@ -209,7 +207,7 @@ def extract_live_ranges_from_passes( # Try to merge live ranges of operations in the NPU subgraphs if sg.placement == PassPlacement.Npu: - merge_op_ranges(sg, lr_graph, target_mem_area, target_mem_type) + merge_op_ranges(sg, lr_graph, target_mem_area, target_mem_type_set) for idx, ps in enumerate(sg.passes): ps.time = 2 * idx @@ -217,14 +215,14 @@ def extract_live_ranges_from_passes( time_for_pass = ps.time for tens in ps.inputs + ps.intermediates + ps.outputs: - if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type): + if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set): continue rng = lr_graph.get_or_create_range(tens) rng.mark_usage(time_for_pass) end_time = len(sg.passes) * 2 for tens in sg.output_tensors: - if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type): + if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set): continue rng = lr_graph.get_or_create_range(tens) rng.mark_usage(end_time) @@ -236,7 +234,6 @@ def extract_live_ranges_from_cascaded_passes( sg, target_mem_area, target_mem_type_set, - use_ifm_ofm_overlap=True, ignore_subgraph_input_output_tensors=False, lr_graph=None, allocation_alignment=Tensor.AllocationQuantum, @@ -279,7 +276,7 @@ def extract_live_ranges_from_cascaded_passes( # Use default allocation alignment of 16 for Npu tensors npu_sg = cps_primary_op.attrs["subgraph"] lr_graph = extract_live_ranges_from_cascaded_passes( - npu_sg, target_mem_area, target_mem_type_set, use_ifm_ofm_overlap, False, lr_graph, + npu_sg, target_mem_area, target_mem_type_set, False, lr_graph, ) # Set the new time after handling the Npu subgraph time_for_pass = lr_graph.current_time @@ -291,20 +288,6 @@ def extract_live_ranges_from_cascaded_passes( rng = lr_graph.get_or_create_range(tens, allocation_alignment) rng.mark_usage(time_for_pass) - if use_ifm_ofm_overlap: - # fill allowed overlap for ifm and ofm tensor - ifm_tensor = cps.passes[0].ifm_tensor - ofm_tensor = cps.passes[-1].ofm_tensor - if ( - ifm_tensor is not None - and ofm_tensor is not None - and not tensor_should_be_ignored(lr_graph, ifm_tensor, target_mem_area, target_mem_type_set) - and not tensor_should_be_ignored(lr_graph, ofm_tensor, target_mem_area, target_mem_type_set) - ): - lr_graph.allowed_overlaps[(ifm_tensor, ofm_tensor)] = calc_allowed_ofm_ifm_overlap_for_cascaded_pass( - cps - ) - lr_graph.current_time += 2 end_time = 0 -- cgit v1.2.1