path: root/ethosu/vela/cascade_builder.py
author    Johan Alfvén <johan.alfven@arm.com>  2022-10-11 20:41:41 +0200
committer Johan Alfvén <johan.alfven@arm.com>  2022-10-20 10:45:52 +0200
commit    fba0a7dc43373a69f3c0792587d3d9b0cc010ccf (patch)
tree      bf1a497d8ca6c080988ab920110981df2f8badff /ethosu/vela/cascade_builder.py
parent    673683bb828cd552f1970922e3c61079607332b2 (diff)
download  ethos-u-vela-fba0a7dc43373a69f3c0792587d3d9b0cc010ccf.tar.gz
MLBEDSW-6931: Refactoring merge elementwise ops
Change the cascade builder to use the common functionality in live range instead.

Signed-off-by: Johan Alfven <johan.alfven@arm.com>
Change-Id: I7bbd7ea3d1e7e085813e9d93256a54e6bab2267b
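For context, the check being deduplicated looks roughly as follows. This is a
sketch reconstructed from the _is_mergeable body removed below; the real
ofm_can_reuse_ifm lives in ethosu/vela/live_range.py, and anything beyond the
sched_op parameter shown here is an assumption:

    def ofm_can_reuse_ifm(sched_op) -> bool:
        # Only elementwise ops can write the ofm on top of an ifm
        if not sched_op.op_type.is_elementwise_op():
            return False
        op = sched_op.parent_op
        # The ofm must have a single producer for the overwrite to be safe
        if len(op.ofm.ops) != 1:
            return False
        candidates = []
        if op.ifm is not None:
            candidates.append((op.ifm_shapes[0], op.ifm))
        if op.ifm2 is not None:
            candidates.append((op.ifm_shapes[1], op.ifm2))
        # Look for an input tensor that the output may overwrite
        for shape, tens in candidates:
            if (
                shape == op.ofm_shapes[0]  # op shapes allow overlap
                and tens is not None
                and tens.shape != []  # skip scalar/broadcast inputs
                and tens.format == op.ofm.format
                and tens.dtype == op.ofm.dtype
                and len(tens.consumer_list) == 1  # ifm has no other consumers
            ):
                return True
        return False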
Diffstat (limited to 'ethosu/vela/cascade_builder.py')
-rw-r--r--  ethosu/vela/cascade_builder.py | 52
1 file changed, 5 insertions(+), 47 deletions(-)
diff --git a/ethosu/vela/cascade_builder.py b/ethosu/vela/cascade_builder.py
index ebe2f133..9c84ba8d 100644
--- a/ethosu/vela/cascade_builder.py
+++ b/ethosu/vela/cascade_builder.py
@@ -16,9 +16,8 @@
 #
 # Description:
 # Groups Operators in a schedule together to form Cascades.
-from collections import namedtuple
-
 from .high_level_command_to_npu_op import ifm_ifm2_correct_order
+from .live_range import ofm_can_reuse_ifm
 from .numeric_util import round_up
 from .operation import NpuBlockType
 from .operation import Op
@@ -105,46 +104,6 @@ class CascadeBuilder:
             and sched_op.parent_op.attrs.get("padding", None) != Padding.TILE
         )
 
-    def _is_mergeable(self, sched_op) -> bool:
-        # Code based on merge_elementwise_op_ranges from live_range.py
-
-        if not sched_op.op_type.is_elementwise_op():
-            return False
-
-        elem_op = sched_op.parent_op
-
-        # Check if overwriting the inputs can be allowed
-        OpShapeTens = namedtuple("OpShapeTens", ["op_shape", "tens"])
-        outp = OpShapeTens(elem_op.ofm_shapes[0], elem_op.ofm)
-
-        # check output tensor only has one producer
-        if len(outp.tens.ops) != 1:
-            return False
-
-        inps = []
-        if elem_op.ifm is not None:
-            inps.append(OpShapeTens(elem_op.ifm_shapes[0], elem_op.ifm))
-        if elem_op.ifm2 is not None:
-            inps.append(OpShapeTens(elem_op.ifm_shapes[1], elem_op.ifm2))
-
-        # find an input tensor that can be overwritten by the output
-        for inp in inps:
-            if (
-                # check op input and output shapes allow overlapping
-                inp.op_shape == outp.op_shape
-                # check input tensor is valid
-                and inp.tens is not None
-                and inp.tens.shape != []
-                # check input and output tensors are compatible
-                and inp.tens.format == outp.tens.format
-                and inp.tens.dtype == outp.tens.dtype
-                # check input tensor only has one consumer
-                and len(inp.tens.consumer_list) == 1
-            ):
-                return True
-
-        return False
-
     def _estimate_sram_usage(self, sched_op, cost) -> int:
         """Estimate the SRAM required for the Op if all FeatureMaps are in SRAM"""
         ifm2_size = sched_op.ifm2_size_in_bytes()
@@ -155,17 +114,16 @@
             cost.stripe_input.with_depth(round_up(cost.stripe_input.depth, 16)).elements()
             * sched_op.ifm.dtype.size_in_bytes()
         )
-        if sched_op.requires_full_ofm:
+        if ofm_can_reuse_ifm(sched_op):
+            # ofm will use the ifm buffer to reduce SRAM usage, hence ofm_size = 0
+            ofm_size = 0
+        elif sched_op.requires_full_ofm:
             ofm_size = sched_op.ofm_size_in_bytes()
         else:
             ofm_size = (
                 cost.stripe.with_depth(round_up(cost.stripe.depth, 16)).elements() * sched_op.ofm.dtype.size_in_bytes()
             )
-        if self._is_mergeable(sched_op):
-            # ofm will use the ifm buffer to reduce SRAM usage, hence ofm_size = 0
-            ofm_size = 0
-
         return ifm_size + ifm2_size + ofm_size + self.non_local_mem_usage.get(sched_op, 0)
 
     @staticmethod
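As a worked example of the effect on the estimate (illustrative numbers, not
taken from the commit): for an elementwise add on a 1x64x64x32 int8 feature
map whose ofm qualifies for reuse,

    ifm_size = 1 * 64 * 64 * 32   # 131072 bytes for the int8 ifm
    ifm2_size = 131072            # second elementwise input, same shape
    ofm_size = 0                  # ofm_can_reuse_ifm() holds: ofm shares the ifm buffer
    total = ifm_size + ifm2_size + ofm_size  # 262144 bytes instead of 393216

so the SRAM estimate for the op drops by the full ofm size, giving the cascade
builder more headroom when it compares the estimate against the available SRAM.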