aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ethosu/vela/cascade_builder.py94
-rw-r--r--ethosu/vela/scheduler.py27
2 files changed, 117 insertions, 4 deletions
diff --git a/ethosu/vela/cascade_builder.py b/ethosu/vela/cascade_builder.py
index 09c36b9e..09ee73e8 100644
--- a/ethosu/vela/cascade_builder.py
+++ b/ethosu/vela/cascade_builder.py
@@ -16,6 +16,8 @@
#
# Description:
# Groups Operators in a schedule together to form Cascades.
+from collections import namedtuple
+
from .numeric_util import round_up
from .operation import NpuBlockType
from .operation import Op
@@ -98,6 +100,46 @@ class CascadeBuilder:
and self.element_wise_cascading_conformity(sched_op)
)
+ def _is_mergeable(self, sched_op) -> bool:
+ # Code based on merge_elementwise_op_ranges from live_range.py
+
+ if not sched_op.op_type.is_elementwise_op():
+ return False
+
+ elem_op = sched_op.parent_op
+
+ # Check if overwriting the inputs can be allowed
+ OpShapeTens = namedtuple("OpShapeTens", ["op_shape", "tens"])
+ outp = OpShapeTens(elem_op.ofm_shapes[0], elem_op.ofm)
+
+ # check output tensor only has one producer
+ if len(outp.tens.ops) != 1:
+ return False
+
+ inps = []
+ if elem_op.ifm is not None:
+ inps.append(OpShapeTens(elem_op.ifm_shapes[0], elem_op.ifm))
+ if elem_op.ifm2 is not None:
+ inps.append(OpShapeTens(elem_op.ifm_shapes[1], elem_op.ifm2))
+
+ # find an input tensor that can be overwritten by the output
+ for inp in inps:
+ if (
+ # check op input and output shapes allow overlapping
+ inp.op_shape == outp.op_shape
+ # check input tensor is valid
+ and inp.tens is not None
+ and inp.tens.shape != []
+ # check input and output tensors are compatible
+ and inp.tens.format == outp.tens.format
+ and inp.tens.dtype == outp.tens.dtype
+ # check input tensor only has one consumer
+ and len(inp.tens.consumer_list) == 1
+ ):
+ return True
+
+ return False
+
def _estimate_sram_usage(self, sched_op, cost) -> int:
"""Estimate the SRAM required for the Op if all FeatureMaps are in SRAM"""
ifm2_size = sched_op.ifm2_size_in_bytes()
@@ -115,6 +157,10 @@ class CascadeBuilder:
cost.stripe.with_depth(round_up(cost.stripe.depth, 16)).elements() * sched_op.ofm.dtype.size_in_bytes()
)
+ if self._is_mergeable(sched_op):
+ # ofm will use the ifm buffer to reduce SRAM usage, hence ofm_size = 0
+ ofm_size = 0
+
return ifm_size + ifm2_size + ofm_size + self.non_local_mem_usage.get(sched_op, 0)
@staticmethod
@@ -245,8 +291,52 @@ class CascadeBuilder:
# Both the existing cascade and current Op fits
break
- # Determine if current cascade is the best so far
- if cascade_size < best_cascade_size:
+ """
+ One of two conditions will update the best cascade:
+
+ - cascade_size < best_cascade_size or
+ - cascade_size < uncascaded_sram_usage
+
+ The last condition is illustrated below, showing an example where it is
+ better to choose a larger cascade_size (with more OPs) because it will
+ use less total SRAM usage.
+
+ For simplicity, all featuremaps have same size.
+
+ Cascade OP1-OP2, OP3 is standalone
+
+ -> |OP1| -> roll buffer -> |OP2| -> FM -> |OP3| -> FM
+ /
+ |OP0| -> FM
+ \
+ -> ....
+
+
+ best_cascade_size : FM + roll buffer + FM
+ uncascaded_sram_usage: FM + FM + FM
+
+ compared with:
+
+ Cascade OP1-OP3
+
+ -> |OP1| -> roll buffer -> |OP2| -> roll buffer -> |OP3| -> FM
+ /
+ |OP0| -> FM
+ \
+ -> ....
+
+
+ cascade_size : FM + roll buffer + roll buffer + FM
+
+
+ So, for this use case the comparison will be
+
+ (FM + roll buffer + roll buffer + FM) < (FM + roll buffer + FM) or
+ (FM + roll buffer + roll buffer + FM) < (FM + FM + FM)
+
+ hence, better to choose Cascade OP1-OP3 in this case.
+ """
+ if cascade_size < best_cascade_size or cascade_size < uncascaded_sram_usage:
best_cascade_size = cascade_ifm_size + cascade_buffers + op_full_ofm
ops_in_best_cascade = [op for op in ops_in_cascade]
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index d01942bb..e9f38b4d 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -941,6 +941,29 @@ class Scheduler:
return peak_mem_usage
+ def build_cascades_for_min_schedule(self, min_schedule: Schedule, max_template: Schedule, memory_limit: int):
+ # Update memory snapshot
+ self.sg.schedule = min_schedule
+ self.update_op_memory_snapshot(min_schedule)
+
+ # Calculate residual memory for Min schedule
+ non_local_mem_usage = {}
+ for sched_op in self.sched_ops:
+ time_index = min_schedule.cost_map[sched_op].time_index
+
+ if self.arch.is_spilling_enabled():
+ # For Dedicated SRAM only the intermediate buffers are in SRAM, hence op_mem_usage is 0
+ op_mem_usage = 0
+ else:
+            # Min schedule only has ifm and ofm in SRAM (no buffered weight tensors)
+ op_mem_usage = sched_op.ifm_size_in_bytes() + sched_op.ofm_size_in_bytes()
+
+ non_local_mem_usage[sched_op] = min_schedule.memory_snapshot[time_index] - op_mem_usage
+
+        # Create cascades for Min schedule
+ cascade_builder = CascadeBuilder(self.sched_ops, self.arch.is_spilling_enabled(), non_local_mem_usage)
+ cascade_builder.build_cascades(min_schedule, max_template, memory_limit)
+
def optimize_sub_schedule(
self, cascade_info: CascadeInfo, ref_schedule: Schedule, max_template: Schedule, memory_limit: int
) -> Schedule:
@@ -1545,8 +1568,8 @@ def schedule_passes(nng: Graph, arch: ArchitectureFeatures, options, scheduler_o
if scheduler_options.optimization_strategy == OptimizationStrategy.Size:
initial_sram_limit = scheduler.min_memory_req
- cascade_builder = CascadeBuilder(scheduler.sched_ops, arch.is_spilling_enabled())
- cascade_builder.build_cascades(min_schedule, max_schedule_template, initial_sram_limit)
+ # Build cascades for Min schedule
+ scheduler.build_cascades_for_min_schedule(min_schedule, max_schedule_template, initial_sram_limit)
sg.schedule = min_schedule
scheduler.update_op_memory_snapshot(min_schedule)