author     Johan Alfvén <johan.alfven@arm.com>   2022-10-11 20:41:41 +0200
committer  Johan Alfvén <johan.alfven@arm.com>   2022-10-20 10:45:52 +0200
commit     fba0a7dc43373a69f3c0792587d3d9b0cc010ccf (patch)
tree       bf1a497d8ca6c080988ab920110981df2f8badff
parent     673683bb828cd552f1970922e3c61079607332b2 (diff)
download   ethos-u-vela-fba0a7dc43373a69f3c0792587d3d9b0cc010ccf.tar.gz
MLBEDSW-6931: Refactoring merge elementwise ops
Change code in cascade builder to instead use common functionality
in live range.

Signed-off-by: Johan Alfven <johan.alfven@arm.com>
Change-Id: I7bbd7ea3d1e7e085813e9d93256a54e6bab2267b
-rw-r--r--  ethosu/vela/cascade_builder.py  52
-rw-r--r--  ethosu/vela/live_range.py       22
2 files changed, 24 insertions(+), 50 deletions(-)
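In outline, the duplicated merge check moves out of the cascade builder and into live_range.py, where both call sites share it. A condensed sketch of the new shape, pieced together from the diff below (illustration only, not the complete listing):

# live_range.py: one shared check; the target memory arguments default to None so
# a caller without scheduling context (the cascade builder) can use it as well
def ofm_can_reuse_ifm(sched_op, target_mem_area=None, target_mem_type_set=None):
    return _get_ifm_to_fuse(sched_op, target_mem_area, target_mem_type_set) is not None

def merge_elementwise_op_ranges(sg, sched_op, lr_graph, target_mem_area, target_mem_type_set):
    ifm = _get_ifm_to_fuse(sched_op, target_mem_area, target_mem_type_set)
    if ifm:
        lr_graph.fuse_ranges(ifm, sched_op.parent_op.ofm)

# cascade_builder.py: the SRAM estimate now reuses the same check instead of its own copy
if ofm_can_reuse_ifm(sched_op):
    ofm_size = 0  # the OFM overwrites the IFM buffer, so it adds no extra SRAM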
diff --git a/ethosu/vela/cascade_builder.py b/ethosu/vela/cascade_builder.py
index ebe2f133..9c84ba8d 100644
--- a/ethosu/vela/cascade_builder.py
+++ b/ethosu/vela/cascade_builder.py
@@ -16,9 +16,8 @@
#
# Description:
# Groups Operators in a schedule together to form Cascades.
-from collections import namedtuple
-
from .high_level_command_to_npu_op import ifm_ifm2_correct_order
+from .live_range import ofm_can_reuse_ifm
from .numeric_util import round_up
from .operation import NpuBlockType
from .operation import Op
@@ -105,46 +104,6 @@ class CascadeBuilder:
and sched_op.parent_op.attrs.get("padding", None) != Padding.TILE
)
- def _is_mergeable(self, sched_op) -> bool:
- # Code based on merge_elementwise_op_ranges from live_range.py
-
- if not sched_op.op_type.is_elementwise_op():
- return False
-
- elem_op = sched_op.parent_op
-
- # Check if overwriting the inputs can be allowed
- OpShapeTens = namedtuple("OpShapeTens", ["op_shape", "tens"])
- outp = OpShapeTens(elem_op.ofm_shapes[0], elem_op.ofm)
-
- # check output tensor only has one producer
- if len(outp.tens.ops) != 1:
- return False
-
- inps = []
- if elem_op.ifm is not None:
- inps.append(OpShapeTens(elem_op.ifm_shapes[0], elem_op.ifm))
- if elem_op.ifm2 is not None:
- inps.append(OpShapeTens(elem_op.ifm_shapes[1], elem_op.ifm2))
-
- # find an input tensor that can be overwritten by the output
- for inp in inps:
- if (
- # check op input and output shapes allow overlapping
- inp.op_shape == outp.op_shape
- # check input tensor is valid
- and inp.tens is not None
- and inp.tens.shape != []
- # check input and output tensors are compatible
- and inp.tens.format == outp.tens.format
- and inp.tens.dtype == outp.tens.dtype
- # check input tensor only has one consumer
- and len(inp.tens.consumer_list) == 1
- ):
- return True
-
- return False
-
def _estimate_sram_usage(self, sched_op, cost) -> int:
"""Estimate the SRAM required for the Op if all FeatureMaps are in SRAM"""
ifm2_size = sched_op.ifm2_size_in_bytes()
@@ -155,17 +114,16 @@ class CascadeBuilder:
cost.stripe_input.with_depth(round_up(cost.stripe_input.depth, 16)).elements()
* sched_op.ifm.dtype.size_in_bytes()
)
- if sched_op.requires_full_ofm:
+ if ofm_can_reuse_ifm(sched_op):
+ # ofm will use the ifm buffer to reduce SRAM usage, hence ofm_size = 0
+ ofm_size = 0
+ elif sched_op.requires_full_ofm:
ofm_size = sched_op.ofm_size_in_bytes()
else:
ofm_size = (
cost.stripe.with_depth(round_up(cost.stripe.depth, 16)).elements() * sched_op.ofm.dtype.size_in_bytes()
)
- if self._is_mergeable(sched_op):
- # ofm will use the ifm buffer to reduce SRAM usage, hence ofm_size = 0
- ofm_size = 0
-
return ifm_size + ifm2_size + ofm_size + self.non_local_mem_usage.get(sched_op, 0)
@staticmethod
diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py
index 9b6fe63d..fbb48ecd 100644
--- a/ethosu/vela/live_range.py
+++ b/ethosu/vela/live_range.py
@@ -154,18 +154,21 @@ class LiveRangeGraph:
def tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
+ if target_mem_area is None or target_mem_type_set is None:
+ return False
if tens.mem_area != target_mem_area or tens.mem_type not in target_mem_type_set:
return True
return False
-def merge_elementwise_op_ranges(sg, sched_op, lr_graph, target_mem_area, target_mem_type_set):
+def _get_ifm_to_fuse(sched_op, target_mem_area=None, target_mem_type_set=None):
def _tensor_should_be_ignored(tens):
if tens.ifm_write_protected:
return True
return tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set)
- # Tries to merge ifm/ofm live ranges of elementwise op
+ # Check if possible to merge ifm/ofm live ranges of elementwise op
+ ifm_tens = None
if sched_op.op_type.is_elementwise_op():
elem_op = sched_op.parent_op
if not _tensor_should_be_ignored(elem_op.ofm):
@@ -195,9 +198,22 @@ def merge_elementwise_op_ranges(sg, sched_op, lr_graph, target_mem_area, target_mem_type_set):
# check output tensor only has one producer
and len(outp.tens.ops) == 1
):
- lr_graph.fuse_ranges(inp.tens, outp.tens)
+ ifm_tens = inp.tens
break
+ return ifm_tens
+
+
+def ofm_can_reuse_ifm(sched_op, target_mem_area=None, target_mem_type_set=None):
+ ifm = _get_ifm_to_fuse(sched_op, target_mem_area, target_mem_type_set)
+ return ifm is not None
+
+
+def merge_elementwise_op_ranges(sg, sched_op, lr_graph, target_mem_area, target_mem_type_set):
+ ifm = _get_ifm_to_fuse(sched_op, target_mem_area, target_mem_type_set)
+ if ifm:
+ lr_graph.fuse_ranges(ifm, sched_op.parent_op.ofm)
+
def extract_live_ranges_from_cascaded_passes(
sg,
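Usage note (a reading of the patch rather than code from the commit): tensor_should_be_ignored now treats a missing target memory area/type set as "ignore nothing", which is what lets the cascade builder call the new helper with its defaults while the live-range pass keeps filtering. The merge_elementwise_op_ranges call site is not shown in this patch, so that line below is an assumption about how it is invoked:

# Cascade builder's SRAM estimate: no memory-area filtering is available here,
# so the defaults (None/None) apply and no tensor is ignored by the check
if ofm_can_reuse_ifm(sched_op):
    ofm_size = 0

# Live-range extraction (assumed call site): the filtering arguments are still
# passed, and ranges are fused only when _get_ifm_to_fuse finds a suitable input
merge_elementwise_op_ranges(sg, sched_op, lr_graph, target_mem_area, target_mem_type_set)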