From 98bfecd20f600a12de2f6a282d2fdbddb23dc081 Mon Sep 17 00:00:00 2001 From: Jacob Bohlin Date: Mon, 21 Jun 2021 17:22:20 +0200 Subject: MLBEDSW-4807 Elementwise IFM/OFM overlap Reinstated allowing the IFM and OFM tensor to overlap for Elementwise operations. Signed-off-by: Jacob Bohlin Change-Id: Ide6db7781f3ca7a36c8ff9e3efdc7943a7bf6d7f --- ethosu/vela/live_range.py | 133 +++++++++++++--------------------------------- 1 file changed, 36 insertions(+), 97 deletions(-) (limited to 'ethosu/vela') diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py index b687a9e7..2795b668 100644 --- a/ethosu/vela/live_range.py +++ b/ethosu/vela/live_range.py @@ -20,7 +20,6 @@ from typing import List import numpy as np -from .nn_graph import PassPlacement from .operation import Op from .tensor import MemArea from .tensor import MemType @@ -167,98 +166,40 @@ def tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_se return False -# Tries merging of ifm/ofm live ranges for memory only ops and elementwise ops -def merge_op_ranges(sg, lr_graph, target_mem_area, target_mem_type_set): - for ps in sg.passes: - if ps.placement == PassPlacement.MemoryOnly: - # For memory only passes, e.g. Reshape. Add input and output tensor to the same LiveRange - input_tensor = ps.inputs[0] - output_tensor = ps.outputs[0] - if not tensor_should_be_ignored(lr_graph, input_tensor, target_mem_area, target_mem_type_set) and not ( - tensor_should_be_ignored(lr_graph, output_tensor, target_mem_area, target_mem_type_set) - ): - lr_graph.fuse_ranges(input_tensor, output_tensor) - elif ps.is_element_wise: - merge_elementwise_op_ranges(ps, lr_graph, target_mem_area, target_mem_type_set) - - -# Tries to merge ifm/ofm live of elementwise op -def merge_elementwise_op_ranges(ps, lr_graph, target_mem_area, target_mem_type_set): - elem_op = None - for op in ps.ops: - if op.type.is_elementwise_op(): - assert elem_op is None - elem_op = op - - if elem_op is not None and not tensor_should_be_ignored( - lr_graph, elem_op.ofm, target_mem_area, target_mem_type_set - ): - # Check if overwriting the inputs can be allowed - if elem_op.type not in (Op.SHL, Op.SHR): - inps = [] - if ( - elem_op.ifm is not None - and elem_op.ifm.shape != [] - and elem_op.ifm.mem_area == target_mem_area - and elem_op.ifm.mem_type in target_mem_type_set - ): - inps.append(elem_op.ifm) - if ( - elem_op.ifm2 is not None - and elem_op.ifm2.shape != [] - and elem_op.ifm2.mem_area == target_mem_area - and elem_op.ifm.mem_type in target_mem_type_set - ): - inps.append(elem_op.ifm2) - - if len(inps) > 0: - for i, inp in enumerate(inps): - # check input format, dtype, broadcasting or if there are more input consumers - if ( - inp.format == elem_op.ofm.format - and inp.dtype == elem_op.ofm.dtype - and elem_op.ifm_shapes[i] == elem_op.ofm_shapes[0] - and (len(inp.consumer_list) == 1 and len(inp.ops) == 1) - ): - lr_graph.fuse_ranges(inp, elem_op.ofm) - break - - -def extract_live_ranges_from_passes( - sg, target_mem_area, target_mem_type_set=None, ignore_subgraph_input_output_tensors=False, -): - lr_graph = LiveRangeGraph() - - if ignore_subgraph_input_output_tensors: - lr_graph.ignore_tensors.update(sg.input_tensors) - lr_graph.ignore_tensors.update(sg.output_tensors) - - if target_mem_type_set is None: - target_mem_type_set = set((MemType.Scratch, MemType.Scratch_fast)) - - # Try to merge live ranges of operations in the NPU subgraphs - if sg.placement == PassPlacement.Npu: - merge_op_ranges(sg, lr_graph, target_mem_area, target_mem_type_set) - - for idx, ps in enumerate(sg.passes): - ps.time = 2 * idx - - time_for_pass = ps.time - - for tens in ps.inputs + ps.intermediates + ps.outputs: - if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set): - continue - rng = lr_graph.get_or_create_range(tens) - rng.mark_usage(time_for_pass) - - end_time = len(sg.passes) * 2 - for tens in sg.output_tensors: - if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set): - continue - rng = lr_graph.get_or_create_range(tens) - rng.mark_usage(end_time) +def merge_elementwise_op_ranges(sched_op, lr_graph, target_mem_area, target_mem_type_set): + # Tries to merge ifm/ofm live ranges of elementwise op + if sched_op.op_type.is_elementwise_op(): + elem_op = sched_op.parent_op + if not tensor_should_be_ignored(lr_graph, elem_op.ofm, target_mem_area, target_mem_type_set): + # Check if overwriting the inputs can be allowed + if elem_op.type not in (Op.SHL, Op.SHR): + inps = [] + if ( + elem_op.ifm is not None + and elem_op.ifm.shape != [] + and elem_op.ifm.mem_area == target_mem_area + and elem_op.ifm.mem_type in target_mem_type_set + ): + inps.append(elem_op.ifm) + if ( + elem_op.ifm2 is not None + and elem_op.ifm2.shape != [] + and elem_op.ifm2.mem_area == target_mem_area + and elem_op.ifm.mem_type in target_mem_type_set + ): + inps.append(elem_op.ifm2) - return lr_graph + if len(inps) > 0: + for i, inp in enumerate(inps): + # check input format, dtype, broadcasting or if there are more input consumers + if ( + inp.format == elem_op.ofm.format + and inp.dtype == elem_op.ofm.dtype + and elem_op.ifm_shapes[i] == elem_op.ofm_shapes[0] + and (len(inp.consumer_list) == 1 and len(inp.ops) == 1) + ): + lr_graph.fuse_ranges(inp, elem_op.ofm) + break def extract_live_ranges_from_cascaded_passes( @@ -280,10 +221,6 @@ def extract_live_ranges_from_cascaded_passes( lr_graph.ignore_tensors.update(sg.input_tensors) lr_graph.ignore_tensors.update(sg.output_tensors) - # Try to merge live ranges of operations in the NPU subgraphs - if sg.placement == PassPlacement.Npu: - merge_op_ranges(sg, lr_graph, target_mem_area, target_mem_type_set) - for cps in sg.cascaded_passes: cps.time = lr_graph.current_time @@ -347,7 +284,7 @@ def create_linear_live_range_graph(sg, target_mem_area, target_mem_type_set, lr_ rng = lr_graph.get_or_create_range(tens) rng.mark_usage(sg_time) - for sched_op, op_info in sg.schedule.cost_map.items(): + for _, op_info in sg.schedule.cost_map.items(): for tensor in [op_info.npu_weights_tensor, op_info.npu_scales_tensor]: if tensor and not (tensor_should_be_ignored(lr_graph, tensor, target_mem_area, target_mem_type_set)): rng = lr_graph.get_or_create_range(tensor) @@ -360,6 +297,8 @@ def create_linear_live_range_graph(sg, target_mem_area, target_mem_type_set, lr_ def _extract_live_ranges_from_schedule(sg, target_mem_area, target_mem_type_set, lr_graph): time_for_cascade = {} for sched_op in sg.sched_ops: + merge_elementwise_op_ranges(sched_op, lr_graph, target_mem_area, target_mem_type_set) + op_info = sg.schedule.cost_map[sched_op] cascade = op_info.cascade cascade_info = sg.schedule.cascades.get(cascade, None) -- cgit v1.2.1