From fba0a7dc43373a69f3c0792587d3d9b0cc010ccf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johan=20Alfv=C3=A9n?= Date: Tue, 11 Oct 2022 20:41:41 +0200 Subject: MLBEDSW-6931: Refactoring merge elementwise ops Change code in cascade builder to instead use common functionality in live range. Signed-off-by: Johan Alfven Change-Id: I7bbd7ea3d1e7e085813e9d93256a54e6bab2267b --- ethosu/vela/cascade_builder.py | 52 ++++-------------------------------------- ethosu/vela/live_range.py | 22 +++++++++++++++--- 2 files changed, 24 insertions(+), 50 deletions(-) (limited to 'ethosu') diff --git a/ethosu/vela/cascade_builder.py b/ethosu/vela/cascade_builder.py index ebe2f133..9c84ba8d 100644 --- a/ethosu/vela/cascade_builder.py +++ b/ethosu/vela/cascade_builder.py @@ -16,9 +16,8 @@ # # Description: # Groups Operators in a schedule together to form Cascades. -from collections import namedtuple - from .high_level_command_to_npu_op import ifm_ifm2_correct_order +from .live_range import ofm_can_reuse_ifm from .numeric_util import round_up from .operation import NpuBlockType from .operation import Op @@ -105,46 +104,6 @@ class CascadeBuilder: and sched_op.parent_op.attrs.get("padding", None) != Padding.TILE ) - def _is_mergeable(self, sched_op) -> bool: - # Code based on merge_elementwise_op_ranges from live_range.py - - if not sched_op.op_type.is_elementwise_op(): - return False - - elem_op = sched_op.parent_op - - # Check if overwriting the inputs can be allowed - OpShapeTens = namedtuple("OpShapeTens", ["op_shape", "tens"]) - outp = OpShapeTens(elem_op.ofm_shapes[0], elem_op.ofm) - - # check output tensor only has one producer - if len(outp.tens.ops) != 1: - return False - - inps = [] - if elem_op.ifm is not None: - inps.append(OpShapeTens(elem_op.ifm_shapes[0], elem_op.ifm)) - if elem_op.ifm2 is not None: - inps.append(OpShapeTens(elem_op.ifm_shapes[1], elem_op.ifm2)) - - # find an input tensor that can be overwritten by the output - for inp in inps: - if ( - # check op input and output shapes allow overlapping - inp.op_shape == outp.op_shape - # check input tensor is valid - and inp.tens is not None - and inp.tens.shape != [] - # check input and output tensors are compatible - and inp.tens.format == outp.tens.format - and inp.tens.dtype == outp.tens.dtype - # check input tensor only has one consumer - and len(inp.tens.consumer_list) == 1 - ): - return True - - return False - def _estimate_sram_usage(self, sched_op, cost) -> int: """Estimate the SRAM required for the Op if all FeatureMaps are in SRAM""" ifm2_size = sched_op.ifm2_size_in_bytes() @@ -155,17 +114,16 @@ class CascadeBuilder: cost.stripe_input.with_depth(round_up(cost.stripe_input.depth, 16)).elements() * sched_op.ifm.dtype.size_in_bytes() ) - if sched_op.requires_full_ofm: + if ofm_can_reuse_ifm(sched_op): + # ofm will use the ifm buffer to reduce SRAM usage, hence ofm_size = 0 + ofm_size = 0 + elif sched_op.requires_full_ofm: ofm_size = sched_op.ofm_size_in_bytes() else: ofm_size = ( cost.stripe.with_depth(round_up(cost.stripe.depth, 16)).elements() * sched_op.ofm.dtype.size_in_bytes() ) - if self._is_mergeable(sched_op): - # ofm will use the ifm buffer to reduce SRAM usage, hence ofm_size = 0 - ofm_size = 0 - return ifm_size + ifm2_size + ofm_size + self.non_local_mem_usage.get(sched_op, 0) @staticmethod diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py index 9b6fe63d..fbb48ecd 100644 --- a/ethosu/vela/live_range.py +++ b/ethosu/vela/live_range.py @@ -154,18 +154,21 @@ class LiveRangeGraph: def tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set): + if target_mem_area is None or target_mem_type_set is None: + return False if tens.mem_area != target_mem_area or tens.mem_type not in target_mem_type_set: return True return False -def merge_elementwise_op_ranges(sg, sched_op, lr_graph, target_mem_area, target_mem_type_set): +def _get_ifm_to_fuse(sched_op, target_mem_area=None, target_mem_type_set=None): def _tensor_should_be_ignored(tens): if tens.ifm_write_protected: return True return tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set) - # Tries to merge ifm/ofm live ranges of elementwise op + # Check if possible to merge ifm/ofm live ranges of elementwise op + ifm_tens = None if sched_op.op_type.is_elementwise_op(): elem_op = sched_op.parent_op if not _tensor_should_be_ignored(elem_op.ofm): @@ -195,9 +198,22 @@ def merge_elementwise_op_ranges(sg, sched_op, lr_graph, target_mem_area, target_ # check output tensor only has one producer and len(outp.tens.ops) == 1 ): - lr_graph.fuse_ranges(inp.tens, outp.tens) + ifm_tens = inp.tens break + return ifm_tens + + +def ofm_can_reuse_ifm(sched_op, target_mem_area=None, target_mem_type_set=None): + ifm = _get_ifm_to_fuse(sched_op, target_mem_area, target_mem_type_set) + return ifm is not None + + +def merge_elementwise_op_ranges(sg, sched_op, lr_graph, target_mem_area, target_mem_type_set): + ifm = _get_ifm_to_fuse(sched_op, target_mem_area, target_mem_type_set) + if ifm: + lr_graph.fuse_ranges(ifm, sched_op.parent_op.ofm) + def extract_live_ranges_from_cascaded_passes( sg, -- cgit v1.2.1