From a151f597bef24e0d8b51dbe833338057e8bcbc92 Mon Sep 17 00:00:00 2001
From: Patrik Gustavsson
Date: Fri, 16 Oct 2020 13:59:52 +0200
Subject: MLBEDSW-3212 Enable overlap of elementwise input/output

Enable overlap of elementwise input/output

Signed-off-by: Patrik Gustavsson
Change-Id: I6e6f11953319c843c8203bf038f96778df194332
---
Review note (this area is not part of the commit message): in
merge_elementwise_op_ranges the ifm2 eligibility check now tests
elem_op.ifm2.mem_type instead of elem_op.ifm.mem_type, so the second
input is validated against its own memory type. The post-image blob id
on the live_range.py index line predates this one-token fix.

 ethosu/vela/live_range.py        | 177 +++++++++++++++++++--------------------
 ethosu/vela/scheduler.py         |   5 +-
 ethosu/vela/tensor_allocation.py |   1 -
 3 files changed, 85 insertions(+), 98 deletions(-)

diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py
index 23026c79..b8840355 100644
--- a/ethosu/vela/live_range.py
+++ b/ethosu/vela/live_range.py
@@ -98,18 +98,6 @@ class LiveRange:
         self.alignment = max(self.alignment, alignment)
 
 
-def merge_memory_op_ranges(sg, lr_graph, tensor_should_be_ignored, target_mem_area):
-    for ps in sg.passes:
-        if ps.placement == PassPlacement.MemoryOnly:
-            # For memory only passes, e.g. Reshape. Add input and output tensor to the same LiveRange
-            input_tensor = ps.inputs[0]
-            output_tensor = ps.outputs[0]
-            if not tensor_should_be_ignored(input_tensor, target_mem_area) and not tensor_should_be_ignored(
-                output_tensor, target_mem_area
-            ):
-                lr_graph.fuse_ranges(input_tensor, output_tensor)
-
-
 class LiveRangeGraph:
     def __init__(self):
         self.ranges = {}  # tens -> range
@@ -138,10 +126,79 @@ class LiveRangeGraph:
         return live_range
 
 
+def tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set):
+    if tens.mem_area != target_mem_area or tens.mem_type not in target_mem_type_set:
+        return True
+    if tens in lr_graph.ignore_tensors:
+        return True
+    if tens.name.endswith("reshape_shape_npu"):
+        # Reshape tensor, no need to allocate
+        lr_graph.ignore_tensors.add(tens)
+        return True
+    return False
+
+
+# Tries merging of ifm/ofm live ranges for memory only ops and elementwise ops
+def merge_op_ranges(sg, lr_graph, target_mem_area, target_mem_type_set):
+    for ps in sg.passes:
+        if ps.placement == PassPlacement.MemoryOnly:
+            # For memory only passes, e.g. Reshape. Add input and output tensor to the same LiveRange
+            input_tensor = ps.inputs[0]
+            output_tensor = ps.outputs[0]
+            if not tensor_should_be_ignored(lr_graph, input_tensor, target_mem_area, target_mem_type_set) and not (
+                tensor_should_be_ignored(lr_graph, output_tensor, target_mem_area, target_mem_type_set)
+            ):
+                lr_graph.fuse_ranges(input_tensor, output_tensor)
+        elif ps.is_element_wise:
+            merge_elementwise_op_ranges(ps, lr_graph, target_mem_area, target_mem_type_set)
+
+
+# Tries to merge ifm/ofm live ranges of elementwise op
+def merge_elementwise_op_ranges(ps, lr_graph, target_mem_area, target_mem_type_set):
+    elem_op = None
+    for op in ps.ops:
+        if op.type.is_elementwise_op():
+            assert elem_op is None
+            elem_op = op
+
+    if elem_op is not None and not tensor_should_be_ignored(
+        lr_graph, elem_op.ofm, target_mem_area, target_mem_type_set
+    ):
+        # Check if overwriting the inputs can be allowed
+        if elem_op.type not in (Op.SHL, Op.SHR):
+            inps = []
+            if (
+                elem_op.ifm is not None
+                and elem_op.ifm.shape != []
+                and elem_op.ifm.mem_area == target_mem_area
+                and elem_op.ifm.mem_type in target_mem_type_set
+            ):
+                inps.append(elem_op.ifm)
+            if (
+                elem_op.ifm2 is not None
+                and elem_op.ifm2.shape != []
+                and elem_op.ifm2.mem_area == target_mem_area
+                and elem_op.ifm2.mem_type in target_mem_type_set
+            ):
+                inps.append(elem_op.ifm2)
+
+            if len(inps) > 0:
+                for inp in inps:
+                    # check input format, dtype, broadcasting or if there are more input consumers
+                    if (
+                        inp.format == elem_op.ofm.format
+                        and inp.dtype == elem_op.ofm.dtype
+                        and inp.shape == elem_op.ofm.shape
+                        and (len(inp.consumer_list) == 1 and len(inp.ops) == 1)
+                    ):
+                        lr_graph.fuse_ranges(inp, elem_op.ofm)
+                        break
+
+
 def extract_live_ranges_from_passes(
     sg,
     target_mem_area,
-    mark_output_tensors_overlapping_with_input_tensors=False,
+    target_mem_type=set((MemType.Scratch, MemType.Scratch_fast)),
     ignore_subgraph_input_output_tensors=False,
 ):
     lr_graph = LiveRangeGraph()
@@ -150,50 +207,24 @@ def extract_live_ranges_from_passes(
         lr_graph.ignore_tensors.update(sg.input_tensors)
         lr_graph.ignore_tensors.update(sg.output_tensors)
 
-    def tensor_should_be_ignored(tens, target_mem_area):
-        if tens.mem_area != target_mem_area:
-            return True
-        if tens in lr_graph.ignore_tensors:
-            return True
-        if tens.name.endswith("reshape_shape_npu"):
-            # Reshape tensor, no need to allocate
-            lr_graph.ignore_tensors.add(tens)
-            return True
-        return False
-
-    # Merge only memory operations in the NPU subgraphs
+    # Try to merge live ranges of operations in the NPU subgraphs
     if sg.placement == PassPlacement.Npu:
-        merge_memory_op_ranges(sg, lr_graph, tensor_should_be_ignored, target_mem_area)
+        merge_op_ranges(sg, lr_graph, target_mem_area, target_mem_type)
 
     for idx, ps in enumerate(sg.passes):
         ps.time = 2 * idx
 
         time_for_pass = ps.time
 
-        for tens in ps.inputs:
-            if tensor_should_be_ignored(tens, target_mem_area):
-                continue
-            rng = lr_graph.get_or_create_range(tens)
-            rng.mark_usage(time_for_pass)
-
-        for tens in ps.intermediates:
-            if tensor_should_be_ignored(tens, target_mem_area):
+        for tens in ps.inputs + ps.intermediates + ps.outputs:
+            if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type):
                 continue
             rng = lr_graph.get_or_create_range(tens)
             rng.mark_usage(time_for_pass)
 
-        for tens in ps.outputs:
-            if tensor_should_be_ignored(tens, target_mem_area):
-                continue
-            rng = lr_graph.get_or_create_range(tens)
-            output_time = time_for_pass
-            if not mark_output_tensors_overlapping_with_input_tensors and ps.is_element_wise:
-                output_time += 1
-            rng.mark_usage(output_time)
-
     end_time = len(sg.passes) * 2
     for tens in sg.output_tensors:
-        if tensor_should_be_ignored(tens, target_mem_area):
+        if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type):
             continue
         rng = lr_graph.get_or_create_range(tens)
         rng.mark_usage(end_time)
@@ -205,7 +236,6 @@ def extract_live_ranges_from_cascaded_passes(
     sg,
     target_mem_area,
     target_mem_type_set,
-    mark_output_tensors_overlapping_with_input_tensors=False,
     use_ifm_ofm_overlap=True,
     ignore_subgraph_input_output_tensors=False,
     lr_graph=None,
@@ -222,41 +252,17 @@ def extract_live_ranges_from_cascaded_passes(
         lr_graph.ignore_tensors.update(sg.input_tensors)
         lr_graph.ignore_tensors.update(sg.output_tensors)
 
-    def tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
-        if tens.mem_area != target_mem_area or tens.mem_type not in target_mem_type_set:
-            return True
-        if tens in lr_graph.ignore_tensors:
-            return True
-        if tens.name.endswith("reshape_shape_npu"):
-            # Reshape tensor, no need to allocate
-            lr_graph.ignore_tensors.add(tens)
-            return True
-        return False
-
-    def merge_memory_op_ranges(sg, lr_graph, tensor_should_be_ignored, target_mem_area, target_mem_type_set):
-        for ps in sg.passes:
-            if ps.placement == PassPlacement.MemoryOnly:
-                # For memory only passes, e.g. Reshape. Add input and output tensor to the same LiveRange
-                input_tensor = ps.inputs[0]
-                output_tensor = ps.outputs[0]
-                if not tensor_should_be_ignored(input_tensor, target_mem_area, target_mem_type_set) and not (
-                    tensor_should_be_ignored(output_tensor, target_mem_area, target_mem_type_set)
-                ):
-                    lr_graph.fuse_ranges(input_tensor, output_tensor)
-
-    # Merge only memory operations in the NPU subgraphs
+    # Try to merge live ranges of operations in the NPU subgraphs
     if sg.placement == PassPlacement.Npu:
-        merge_memory_op_ranges(sg, lr_graph, tensor_should_be_ignored, target_mem_area, target_mem_type_set)
+        merge_op_ranges(sg, lr_graph, target_mem_area, target_mem_type_set)
 
     for cps in sg.cascaded_passes:
         cps.time = lr_graph.current_time
 
         time_for_pass = cps.time
 
-        is_element_wise = cps.is_element_wise
-
         for tens in cps.inputs:
-            if tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
+            if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set):
                 continue
             rng = lr_graph.get_or_create_range(tens, allocation_alignment)
             rng.mark_usage(time_for_pass)
@@ -273,33 +279,18 @@ def extract_live_ranges_from_cascaded_passes(
             # Use default allocation alignment of 16 for Npu tensors
             npu_sg = cps_primary_op.attrs["subgraph"]
             lr_graph = extract_live_ranges_from_cascaded_passes(
-                npu_sg,
-                target_mem_area,
-                target_mem_type_set,
-                mark_output_tensors_overlapping_with_input_tensors,
-                use_ifm_ofm_overlap,
-                False,
-                lr_graph,
+                npu_sg, target_mem_area, target_mem_type_set, use_ifm_ofm_overlap, False, lr_graph,
             )
             # Set the new time after handling the Npu subgraph
             time_for_pass = lr_graph.current_time
             cps.time = time_for_pass
 
-        for tens in cps.intermediates:
-            if tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
+        for tens in cps.intermediates + cps.outputs:
+            if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set):
                 continue
             rng = lr_graph.get_or_create_range(tens, allocation_alignment)
             rng.mark_usage(time_for_pass)
 
-        for tens in cps.outputs:
-            if tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
-                continue
-            rng = lr_graph.get_or_create_range(tens, allocation_alignment)
-            output_time = time_for_pass
-            if not mark_output_tensors_overlapping_with_input_tensors and is_element_wise:
-                output_time += 1
-            rng.mark_usage(output_time)
-
         if use_ifm_ofm_overlap:
             # fill allowed overlap for ifm and ofm tensor
             ifm_tensor = cps.passes[0].ifm_tensor
@@ -307,8 +298,8 @@ def extract_live_ranges_from_cascaded_passes(
             if (
                 ifm_tensor is not None
                 and ofm_tensor is not None
-                and not tensor_should_be_ignored(ifm_tensor, target_mem_area, target_mem_type_set)
-                and not tensor_should_be_ignored(ofm_tensor, target_mem_area, target_mem_type_set)
+                and not tensor_should_be_ignored(lr_graph, ifm_tensor, target_mem_area, target_mem_type_set)
+                and not tensor_should_be_ignored(lr_graph, ofm_tensor, target_mem_area, target_mem_type_set)
             ):
                 lr_graph.allowed_overlaps[(ifm_tensor, ofm_tensor)] = calc_allowed_ofm_ifm_overlap_for_cascaded_pass(
                     cps
@@ -322,7 +313,7 @@ def extract_live_ranges_from_cascaded_passes(
         end_time = max(end_time, rng.end_time)
 
     for tens in sg.output_tensors:
-        if tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
+        if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set):
             continue
         rng = lr_graph.get_or_create_range(tens, allocation_alignment)
         rng.mark_usage(end_time)
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index 41e15294..31e6383a 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -480,10 +480,7 @@ class DynamicProgrammingScheduler:
     def calc_non_local_mem_usage(self):
         ignore_subgraph_input_output_tensors = self.sg.placement == PassPlacement.Cpu
         range_set = live_range.extract_live_ranges_from_passes(
-            self.sg,
-            self.mem_area,
-            mark_output_tensors_overlapping_with_input_tensors=True,
-            ignore_subgraph_input_output_tensors=ignore_subgraph_input_output_tensors,
+            self.sg, self.mem_area, ignore_subgraph_input_output_tensors=ignore_subgraph_input_output_tensors,
         )
         range_dict = range_set.ranges
 
diff --git a/ethosu/vela/tensor_allocation.py b/ethosu/vela/tensor_allocation.py
index 1efcd686..9f14ec4c 100644
--- a/ethosu/vela/tensor_allocation.py
+++ b/ethosu/vela/tensor_allocation.py
@@ -137,7 +137,6 @@ def allocate_tensors(
         sg,
         mem_area,
         mem_type_set,
-        mark_output_tensors_overlapping_with_input_tensors=False,
         use_ifm_ofm_overlap=use_ifm_ofm_overlap,
         ignore_subgraph_input_output_tensors=ignore_subgraph_input_output_tensors,
         lr_graph=lr_graph,
-- 
cgit v1.2.1