diff options
Diffstat (limited to 'ethosu/vela/cascade_builder.py')
-rw-r--r-- | ethosu/vela/cascade_builder.py | 52 |
1 files changed, 5 insertions, 47 deletions
diff --git a/ethosu/vela/cascade_builder.py b/ethosu/vela/cascade_builder.py index ebe2f133..9c84ba8d 100644 --- a/ethosu/vela/cascade_builder.py +++ b/ethosu/vela/cascade_builder.py @@ -16,9 +16,8 @@ # # Description: # Groups Operators in a schedule together to form Cascades. -from collections import namedtuple - from .high_level_command_to_npu_op import ifm_ifm2_correct_order +from .live_range import ofm_can_reuse_ifm from .numeric_util import round_up from .operation import NpuBlockType from .operation import Op @@ -105,46 +104,6 @@ class CascadeBuilder: and sched_op.parent_op.attrs.get("padding", None) != Padding.TILE ) - def _is_mergeable(self, sched_op) -> bool: - # Code based on merge_elementwise_op_ranges from live_range.py - - if not sched_op.op_type.is_elementwise_op(): - return False - - elem_op = sched_op.parent_op - - # Check if overwriting the inputs can be allowed - OpShapeTens = namedtuple("OpShapeTens", ["op_shape", "tens"]) - outp = OpShapeTens(elem_op.ofm_shapes[0], elem_op.ofm) - - # check output tensor only has one producer - if len(outp.tens.ops) != 1: - return False - - inps = [] - if elem_op.ifm is not None: - inps.append(OpShapeTens(elem_op.ifm_shapes[0], elem_op.ifm)) - if elem_op.ifm2 is not None: - inps.append(OpShapeTens(elem_op.ifm_shapes[1], elem_op.ifm2)) - - # find an input tensor that can be overwritten by the output - for inp in inps: - if ( - # check op input and output shapes allow overlapping - inp.op_shape == outp.op_shape - # check input tensor is valid - and inp.tens is not None - and inp.tens.shape != [] - # check input and output tensors are compatible - and inp.tens.format == outp.tens.format - and inp.tens.dtype == outp.tens.dtype - # check input tensor only has one consumer - and len(inp.tens.consumer_list) == 1 - ): - return True - - return False - def _estimate_sram_usage(self, sched_op, cost) -> int: """Estimate the SRAM required for the Op if all FeatureMaps are in SRAM""" ifm2_size = sched_op.ifm2_size_in_bytes() @@ -155,17 +114,16 @@ class CascadeBuilder: cost.stripe_input.with_depth(round_up(cost.stripe_input.depth, 16)).elements() * sched_op.ifm.dtype.size_in_bytes() ) - if sched_op.requires_full_ofm: + if ofm_can_reuse_ifm(sched_op): + # ofm will use the ifm buffer to reduce SRAM usage, hence ofm_size = 0 + ofm_size = 0 + elif sched_op.requires_full_ofm: ofm_size = sched_op.ofm_size_in_bytes() else: ofm_size = ( cost.stripe.with_depth(round_up(cost.stripe.depth, 16)).elements() * sched_op.ofm.dtype.size_in_bytes() ) - if self._is_mergeable(sched_op): - # ofm will use the ifm buffer to reduce SRAM usage, hence ofm_size = 0 - ofm_size = 0 - return ifm_size + ifm2_size + ofm_size + self.non_local_mem_usage.get(sched_op, 0) @staticmethod |