aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/scheduler.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/scheduler.py')
-rw-r--r--ethosu/vela/scheduler.py949
1 files changed, 949 insertions, 0 deletions
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
new file mode 100644
index 00000000..c35c1566
--- /dev/null
+++ b/ethosu/vela/scheduler.py
@@ -0,0 +1,949 @@
+# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# Description:
+# The scheduler costs various strategies for scheduling the network in order to select the block configuration.
+
+import enum
+from .nn_graph import (
+ TensorPurpose,
+ TensorSubPurpose,
+ TensorFormat,
+ MemArea,
+ SchedulingStrategy,
+ CascadedPass,
+ PassPlacement,
+ SchedulerRewrite,
+ Operation,
+ NpuBlockType,
+)
+from . import live_range
+import numpy as np
+from . import npu_performance
+from . import stats_writer
+from .npu_performance import make_bandwidth_array, make_macs_array, make_cycles_array, make_metrics_arrays, PassCycles
+import time, copy
+from .high_level_command_stream_generator import calc_allowed_ofm_ifm_overlap_for_pass_list
+from .shared_buffer_allocation import (
+ find_block_configs_suitable_for_pass_and_shared_buffer,
+ shared_buffer_allocation_for_pass_and_block_config,
+)
+from functools import lru_cache
+
+
+class ParetoMetric(enum.Enum):
+ BwCycMem = 1
+ BwCycMemBlkH = 2
+
+ def __str__(self):
+ return self.name
+
+
+class SchedulerOptions:
+ def __init__(
+ self,
+ use_cascading=True,
+ use_ifm_ofm_overlap=True,
+ verbose_schedule=False,
+ verbose_pareto_frontier_schedules=False,
+ use_ifm_streaming=True,
+ pareto_metric=ParetoMetric.BwCycMem,
+ ):
+ self.use_cascading = use_cascading
+ self.use_ifm_ofm_overlap = use_ifm_ofm_overlap
+ self.verbose_schedule = verbose_schedule
+ self.verbose_pareto_frontier_schedules = verbose_pareto_frontier_schedules
+ self.use_ifm_streaming = use_ifm_streaming
+ self.pareto_metric = pareto_metric
+
+ def __str__(self):
+ return type(self).__name__ + ": " + str(self.__dict__)
+
+ __repr__ = __str__
+
+
+class Strategy:
+ __slots__ = "strat", "param", "passes", "block_configs", "rewrite_list", "bws", "macs", "cycles", "sram_used"
+
+ def __init__(self, strat, param, passes, block_configs, rewrite_list, bws, macs, cycles, sram_used):
+ self.strat = strat
+ self.param = param
+ self.passes = passes
+ self.block_configs = block_configs
+ self.rewrite_list = (
+ rewrite_list # list of (SchedulerRewrite, Tensor, new sub purpose, purpose param a, purpose param b, pass)
+ )
+ self.bws = bws
+ self.macs = macs
+ self.cycles = cycles
+ self.sram_used = sram_used
+
+ def __eq__(self, other):
+ if self.strat != other.strat:
+ return False
+ if self.param != other.param:
+ return False
+ if self.block_configs != other.block_configs:
+ return False
+ if self.passes != other.passes:
+ return False
+ if (self.bws != other.bws).any():
+ return False
+ if (self.macs != other.macs).any():
+ return False
+ if (self.cycles != other.cycles).any():
+ return False
+ if self.sram_used != other.sram_used:
+ return False
+ return True
+
+ def empty(self):
+ return not self.passes
+
+ def key(self):
+ return self.passes[-1]
+
+ def clone(self):
+ return Strategy(
+ self.strat,
+ self.param,
+ self.passes,
+ self.block_configs,
+ self.rewrite_list,
+ self.bws,
+ self.macs,
+ self.cycles,
+ self.sram_used,
+ )
+
+ def __str__(self):
+ return "<scheduler.Strategy: %s %s %s %s %s %s %s>" % (
+ self.strat,
+ self.passes,
+ self.rewrite_list,
+ self.bws,
+ self.macs,
+ self.cycles,
+ self.sram_used,
+ )
+
+ __repr__ = __str__
+
+
+class StrategySet:
+ __slots__ = "strats", "bws", "macs", "cycles", "max_sram_used", "total_sram_used"
+
+ def __init__(self, strats=None):
+ if strats is None:
+ strats = dict()
+ self.strats = strats # final pass in packed pass -> Strategy
+ self.bws, self.macs, self.cycles = make_metrics_arrays()
+ self.max_sram_used = 0
+ self.total_sram_used = 0
+
+ def update_statistics(self):
+ self.bws = make_bandwidth_array()
+ self.max_sram_used = 0
+ for ps, strat in self.strats.items():
+ self.bws += strat.bws
+ self.macs += strat.macs
+ self.cycles += strat.cycles
+ self.max_sram_used = max(self.max_sram_used, strat.sram_used)
+ self.total_sram_used += strat.sram_used
+
+ def clone_add_strategy(self, new_strat):
+ key = new_strat.key()
+ if key in self.strats:
+ assert new_strat == self.strats[key]
+ return self
+ else:
+ new_strats = dict(self.strats)
+ new_strats[key] = new_strat
+ new_set = StrategySet(new_strats)
+ new_set.bws = self.bws + new_strat.bws
+ new_set.macs = self.macs + new_strat.macs
+ new_set.cycles = self.cycles + new_strat.cycles
+ new_set.max_sram_used = max(self.max_sram_used, new_strat.sram_used)
+ new_set.total_sram_used = self.total_sram_used + new_strat.sram_used
+ return new_set
+
+ def __eq__(self, other):
+ if (self.bws != other.bws).any():
+ return False
+ if (self.macs != other.macs).any():
+ return False
+ if (self.cycles != other.cycles).any():
+ return False
+ if self.max_sram_used != other.max_sram_used:
+ return False
+ if self.total_sram_used != other.total_sram_used:
+ return False
+ if self.strats != other.strats:
+ return False
+ return True
+
+ def __str__(self):
+ return "<scheduler.StrategySet: max_sram_used=%s passes_covered=%s>" % (
+ self.max_sram_used,
+ list(ps.name for ps in self.strats),
+ )
+
+ __repr__ = __str__
+
+
+empty_strategy = Strategy(
+ SchedulingStrategy.Unknown, None, [], [], [], make_bandwidth_array(), make_macs_array(), make_cycles_array(), 0
+)
+INFINITY = 1e30
+
+ABORT_SEARCH = []
+
+
+def flatten_list_of_lists(lstlst):
+ lst = []
+ for v in lstlst:
+ lst.extend(v)
+ return lst
+
+
+class DynamicProgrammingScheduler:
+ def __init__(self, nng, sg, arch, sram_limit, options: SchedulerOptions):
+ self.nng = nng
+ self.sg = sg
+ self.arch = arch
+ self.sram_limit = sram_limit
+ self.options = copy.copy(options)
+ self.use_cascading = options.use_cascading
+
+ if self.arch.feature_map_storage_mem_area != MemArea.Sram:
+ self.use_ifm_ofm_overlap = False # force off IFM/OFM overlap if IFMs and OFMs are not in the SRAM
+ self.use_ifm_ofm_overlap = options.use_ifm_ofm_overlap
+
+ self.verbose_schedule = options.verbose_schedule
+ self.verbose_pareto_frontier_schedules = options.verbose_pareto_frontier_schedules
+ self.mem_area = MemArea.Sram
+
+ self.bandwidth_weights = arch.bandwidth_weights
+ self.cycles_weight = arch.cycles_weight
+ self.max_sram_used_weight = arch.max_sram_used_weight
+
+ self.n_combinations_searched = 0
+
+ self.feature_maps_not_in_fast_storage = (
+ arch.tensor_storage_mem_area[TensorPurpose.FeatureMap] != arch.fast_storage_mem_area
+ )
+
+ self.pareto_max_candidates = 16
+
+ self.ifm_stream_npu_blocks = set(
+ (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise, NpuBlockType.Pooling,)
+ )
+
+ num_pareto_metrics = 4
+ view_values = ",".join(["d"] * num_pareto_metrics)
+ order_values = ["f%d" % (idx,) for idx in range(num_pareto_metrics)]
+
+ def pareto_metric(self, candidate):
+ strat, strat_set = candidate
+ total_cycles = strat.cycles[PassCycles.Total] + strat_set.cycles[PassCycles.Total]
+ bws = strat.bws + strat_set.bws
+ last_block_height = 0
+ if self.options.pareto_metric == ParetoMetric.BwCycMemBlkH and len(strat.block_configs) > 0:
+ last_block_height = strat.block_configs[-1][0]
+
+ return (
+ np.tensordot(bws, self.bandwidth_weights, axes=3) + total_cycles * self.cycles_weight,
+ strat_set.max_sram_used,
+ strat.sram_used,
+ last_block_height,
+ )
+
+ def filter_pareto_frontier(self, candidates, remove_equally_good_candidates):
+
+ candidates = [cand for cand in candidates if max(cand[0].sram_used, cand[1].max_sram_used) <= self.sram_limit]
+
+ if len(candidates) <= 1:
+ return candidates
+ assert remove_equally_good_candidates
+ start = time.time()
+ pareto_vals = np.zeros((len(candidates), DynamicProgrammingScheduler.num_pareto_metrics))
+ ids = np.arange(len(candidates), dtype=np.int32)
+ for idx, cand in enumerate(candidates):
+ pareto_vals[idx] = self.pareto_metric(cand)
+
+ sort_order = np.argsort(
+ pareto_vals.view(DynamicProgrammingScheduler.view_values),
+ order=DynamicProgrammingScheduler.order_values,
+ axis=0,
+ kind="stable",
+ ).flatten()
+ pareto_vals = pareto_vals[sort_order]
+ ids = ids[sort_order]
+
+ pareto_frontier = []
+ while len(ids) > 0:
+ pareto_frontier.append(candidates[ids[0]])
+ not_dominated_by_first = (pareto_vals < pareto_vals[0]).any(axis=1)
+ ids = ids[not_dominated_by_first]
+ pareto_vals = pareto_vals[not_dominated_by_first]
+
+ if len(pareto_frontier) > self.pareto_max_candidates:
+ pareto_frontier = self.sort_by_candidate_metric(pareto_frontier)
+ pareto_frontier = pareto_frontier[: self.pareto_max_candidates]
+
+ return pareto_frontier
+
+ def candidate_metric(self, candidate):
+ strat, strat_set = candidate
+ max_sram_used = max(strat_set.max_sram_used, strat.sram_used)
+ bws = strat.bws + strat_set.bws
+ total_cycles = strat.cycles[PassCycles.Total] + strat_set.cycles[PassCycles.Total]
+
+ return (
+ max_sram_used * self.max_sram_used_weight
+ + np.tensordot(bws, self.bandwidth_weights, axes=3)
+ + total_cycles * self.cycles_weight
+ )
+
+ def sort_by_candidate_metric(self, candidate_list):
+ sorted_list = list(sorted(candidate_list, key=self.candidate_metric))
+ return sorted_list
+
+ def best_candidate(self, candidate_list):
+ if len(candidate_list) == 0:
+ return ABORT_SEARCH
+ if len(candidate_list) == 1:
+ return candidate_list[0]
+ sorted_list = self.sort_by_candidate_metric(candidate_list)
+ return sorted_list[0]
+
+ def graduate_strat(self, strat_type, sram_used, old_strat_data):
+ res = []
+ for old_strat, old_strat_set in old_strat_data:
+ if old_strat.sram_used + sram_used > self.sram_limit:
+ continue # This strategy is bad, drop it
+ if old_strat_set.max_sram_used > self.sram_limit:
+ continue # This strategy is bad, drop it
+ assert old_strat.strat == SchedulingStrategy.Unknown
+
+ new_strat = old_strat.clone()
+ new_strat.strat = strat_type
+ new_strat.sram_used = old_strat.sram_used + sram_used
+
+ if self.use_ifm_ofm_overlap:
+ overlap = calc_allowed_ofm_ifm_overlap_for_pass_list(
+ new_strat.strat, new_strat.passes, new_strat.block_configs
+ )
+ new_strat.sram_used -= overlap
+
+ new_strat_set = old_strat_set.clone_add_strategy(new_strat)
+ res.append((empty_strategy, new_strat_set))
+ return self.filter_pareto_frontier(res, remove_equally_good_candidates=True)
+
+ def append_sram(self, sram_used, old_strat_data):
+ res = []
+ for old_strat, strat_set in old_strat_data:
+ assert old_strat.strat == SchedulingStrategy.Unknown
+ assert old_strat.sram_used == 0
+ new_strat = old_strat.clone()
+ new_strat.sram_used = old_strat.sram_used + sram_used
+
+ res.append((new_strat, strat_set))
+ return res
+
+ def append_sram_block_config_performance_metrics(self, sram_used, block_config, metrics, old_strat_data):
+ res = []
+ for old_strat, strat_set in old_strat_data:
+ assert old_strat.strat == SchedulingStrategy.Unknown
+ new_strat = old_strat.clone()
+ bws, macs, cycles = metrics[:3]
+
+ new_strat.sram_used = old_strat.sram_used + sram_used
+ new_strat.block_configs = old_strat.block_configs + [block_config]
+ new_strat.bws = old_strat.bws + bws
+ new_strat.macs = old_strat.macs + macs
+ new_strat.cycles = old_strat.cycles + cycles
+ new_strat.bws, new_strat.macs, new_strat.cycles = npu_performance.collate_stats_for_cascaded_pass(
+ self.arch, new_strat.bws, new_strat.macs, new_strat.cycles
+ )
+
+ res.append((new_strat, strat_set))
+ return res
+
+ def append_sram_pass_block_config_performance_metrics_rewrite_list(
+ self, sram_used, new_pass, block_config, metrics, rewrite_list, old_strat_data
+ ):
+ res = []
+ for old_strat, strat_set in old_strat_data:
+ assert old_strat.strat == SchedulingStrategy.Unknown
+ new_strat = old_strat.clone()
+ bws, macs, cycles = metrics[:3]
+ new_strat.sram_used = old_strat.sram_used + sram_used
+ new_strat.block_configs = old_strat.block_configs + [block_config]
+ new_strat.bws = old_strat.bws + bws
+ new_strat.macs = old_strat.macs + macs
+ new_strat.cycles = old_strat.cycles + cycles
+ new_strat.passes = old_strat.passes + [new_pass]
+ new_strat.bws, new_strat.macs, new_strat.cycles = npu_performance.collate_stats_for_cascaded_pass(
+ self.arch, new_strat.bws, new_strat.macs, new_strat.cycles
+ )
+ new_strat.rewrite_list = old_strat.rewrite_list + rewrite_list
+ res.append((new_strat, strat_set))
+ return res
+
+ def append_sram_rewrite_list(self, sram_used, rewrite_list, old_strat_data):
+ res = []
+ for old_strat, strat_set in old_strat_data:
+ assert old_strat.strat == SchedulingStrategy.Unknown
+ new_strat = old_strat.clone()
+ new_strat.sram_used = old_strat.sram_used + sram_used
+ new_strat.rewrite_list = old_strat.rewrite_list + rewrite_list
+ res.append((new_strat, strat_set))
+ return res
+
+ def pass_to_strat(self, strat_data):
+ res = {}
+ for strat in strat_data[1].strats.values():
+ for ps in strat.passes:
+ res[ps] = strat
+ return res
+
+ def compatible_strats(self, a, b):
+ intersection = a.keys() & b.keys()
+ for k in intersection:
+ if a[k] != b[k]:
+ return False
+ return True
+
+ def collate_strats_for_passes(self, all_passes):
+ if len(all_passes) == 0:
+ return [(empty_strategy, StrategySet(dict()))]
+ if len(all_passes) == 1:
+ return all_passes[0] # save some space in the common case
+ all_strands = [[self.pass_to_strat(strat_data) for strat_data in strand] for strand in all_passes]
+ prev_combos = [dict()]
+ for j, strand in enumerate(all_strands):
+ new_combos = []
+ for i, alt in enumerate(strand):
+ for prev in prev_combos:
+ if self.compatible_strats(prev, alt):
+ cmb = dict(prev)
+ cmb.update(all_passes[j][i][1].strats)
+ new_combos.append(cmb)
+ prev_combos = new_combos
+
+ res = []
+ for d in prev_combos:
+ s = StrategySet(d)
+ s.update_statistics()
+ res.append((empty_strategy, s))
+ return res
+
+ def search_all_but_one_predecessor(self, ps, pred_pass, pred_pass_data):
+ # get the rest of the predecessors
+ other_predecessors = [pred for pred in ps.dag_predecessors if pred != pred_pass]
+ other_predecessor_data = self.search_pass_list(other_predecessors)
+
+ # pred strat data has an incomplete strategy, which we need
+ # to continue on, whereas the other ones have completed strategies.
+ # we need to merge these, but keep the incomplete strategy too.
+
+ res = []
+ for pred_pass_strat, pred_pass_strat_set in pred_pass_data:
+ all_strats = [
+ [(empty_strategy, pred_pass_strat_set)], # pred strat data but with a dummy empty strategy
+ other_predecessor_data, # this one is fine to use as-is
+ ]
+ collated_strat_data = self.collate_strats_for_passes(all_strats)
+ strat_data = [(pred_pass_strat, strat_set) for _, strat_set in collated_strat_data]
+ res.extend(strat_data)
+ return res
+
+ def calc_non_local_mem_usage(self):
+ ignore_subgraph_input_output_tensors = self.sg.placement == PassPlacement.Cpu
+ range_set = live_range.extract_live_ranges_from_passes(
+ self.sg,
+ self.mem_area,
+ mark_output_tensors_overlapping_with_input_tensors=True,
+ ignore_subgraph_input_output_tensors=ignore_subgraph_input_output_tensors,
+ )
+ range_dict = range_set.ranges
+
+ # find which ranges overlap passes but aren't input/outputs of the passes.
+ # these won't be counted by the dynamic programming search and must be counted in manually.
+ end_pos = max(ps.time for ps in self.sg.passes) + 2
+ mem_usage = np.zeros(end_pos) + self.sg.base_sram_used
+ non_local_mem_usage = np.zeros(end_pos, dtype=np.int64)
+
+ for tens, rng in range_dict.items():
+ storage_size = tens.storage_size()
+ assert tens.mem_area == self.mem_area
+ mem_usage[rng.start_time : rng.end_time] += storage_size
+
+ for ps in self.sg.passes:
+ local_mem_usage = 0
+ for tens in ps.inputs + ps.outputs + ps.intermediates:
+ if tens.mem_area != self.mem_area:
+ continue
+
+ local_mem_usage += tens.storage_size()
+
+ non_local_mem_usage[ps.time] = mem_usage[ps.time] - local_mem_usage
+
+ self.non_local_mem_usage = non_local_mem_usage
+
+ def search(self):
+ self.calc_non_local_mem_usage()
+ starting_passes = [ps for ps in self.sg.passes if not ps.successors]
+ strat_data = self.search_pass_list(starting_passes)
+
+ _, best_set = self.best_candidate(strat_data)
+
+ if self.verbose_pareto_frontier_schedules:
+ print(
+ "Scheduler searched %d combinations and found %d candidate schedules along the pareto frontier"
+ % (self.n_combinations_searched, len(strat_data,))
+ )
+ for idx, (_, strat_set) in enumerate(strat_data):
+ extra = ""
+ if strat_set == best_set:
+ extra = "(Best candidate)"
+ print("Candidate", idx, extra)
+ memory_used = {MemArea.Sram: strat_set.max_sram_used}
+ stats_writer.print_performance_metrics_for_strat(
+ self.arch,
+ "",
+ strat_set.cycles,
+ strat_set.macs,
+ strat_set.bws,
+ self.nng.batch_size,
+ memory_used,
+ len(self.sg.passes),
+ len(strat_set.strats),
+ )
+
+ return best_set
+
+ def search_pass_list(self, pass_list):
+ all_strats = []
+ for ps in pass_list:
+ strat = self.search_output(ps)
+ all_strats.append(strat)
+ strat_data = self.collate_strats_for_passes(all_strats)
+ for strd in strat_data:
+ for ps in pass_list:
+ assert ps in strd[1].strats # should have strategies for everything we asked to search
+ return strat_data
+
+ def search_predecessors(self, ps):
+
+ # protect against graphs with loops. collate_strats_for_passes will sort this out later so that
+ # we have strats for all passes
+
+ pass_list = ps.dag_predecessors
+ strat_data = self.search_pass_list(pass_list)
+
+ return strat_data
+
+ @lru_cache(maxsize=None)
+ def search_output(self, ps):
+
+ assert ps in self.sg.passes
+ candidate_list = []
+
+ candidate_list.extend(self.search_weight_streaming_output(ps))
+
+ if self.options.use_ifm_streaming:
+ candidate_list.extend(self.search_ifm_streaming_output(ps))
+
+ best = self.filter_pareto_frontier(candidate_list, remove_equally_good_candidates=True)
+
+ if not best:
+ print(
+ "Warning: Dynamic search programming algorithm failed for pass %s, invoking fallback strategy"
+ % (ps.name,)
+ )
+ return self.search_predecessors(ps)
+
+ return best
+
+ def search_ifm_streaming_output(self, ps):
+ if ps.placement != PassPlacement.Npu:
+ return ABORT_SEARCH
+ if ps.npu_block_type not in self.ifm_stream_npu_blocks:
+ return ABORT_SEARCH
+ strat_data = self.search_ifm_streaming_body(ps, False)
+
+ sram_used = self.non_local_mem_usage[ps.time]
+ for tens in ps.outputs:
+ if tens.mem_area == self.mem_area:
+ sram_used += tens.storage_size()
+
+ return self.graduate_strat(SchedulingStrategy.IfmStream, sram_used, strat_data)
+
+ @lru_cache(maxsize=None)
+ def search_ifm_streaming_body(self, ps, force_outputs_to_fast_storage):
+ if ps.placement != PassPlacement.Npu:
+ return ABORT_SEARCH
+ if ps.npu_block_type not in self.ifm_stream_npu_blocks:
+ return ABORT_SEARCH
+ ifm_input_search_resuls = self.search_ifm_streaming_input(ps)
+ res = []
+
+ base_sram_used = 0
+ for tens in ps.intermediates:
+ if tens.mem_area == self.mem_area:
+ base_sram_used += tens.storage_size()
+
+ all_block_configs = self.get_block_configs(ps)
+ for block_config in all_block_configs:
+ all_strats = []
+
+ if self.use_cascading:
+ all_strats.extend(self.search_ifm_streaming_partial(ps, block_config))
+
+ all_strats.extend(ifm_input_search_resuls)
+
+ rewrite_list = []
+ sram_used = base_sram_used
+
+ metrics = npu_performance.performance_metrics_for_pass(
+ self.arch,
+ ps,
+ block_config,
+ rewrite_list=rewrite_list,
+ force_outputs_to_fast_storage=force_outputs_to_fast_storage,
+ )
+
+ res.extend(
+ self.append_sram_pass_block_config_performance_metrics_rewrite_list(
+ sram_used, ps, block_config, metrics, rewrite_list, all_strats
+ )
+ )
+
+ self.n_combinations_searched += len(res)
+ res = self.filter_pareto_frontier(res, remove_equally_good_candidates=True)
+ return res
+
+ def search_ifm_streaming_partial(self, ps, block_config):
+ if ps.placement != PassPlacement.Npu:
+ return ABORT_SEARCH
+
+ if len(ps.inputs) < 1:
+ return ABORT_SEARCH
+
+ ifm_tensor = ps.ifm_tensor
+
+ if ifm_tensor is None:
+ return ABORT_SEARCH
+ if ifm_tensor.purpose != TensorPurpose.FeatureMap:
+ return ABORT_SEARCH
+ if not ifm_tensor.storage_shape or len(ifm_tensor.storage_shape) != 4:
+ return ABORT_SEARCH
+
+ pred_pass_list = []
+ for pred_candidate in ps.dag_predecessors:
+ if len(pred_candidate.outputs) == 1 and pred_candidate.outputs[0] == ifm_tensor:
+ # we found a predecessor that produces this IFM tensor
+ if len(pred_candidate.successors) == 1 and pred_candidate.successors[0] == ps:
+ # and it only has one successor, namely us
+ if pred_candidate.placement == PassPlacement.Npu:
+ if pred_candidate.npu_block_type in self.ifm_stream_npu_blocks:
+ # and it is on the Npu and fusable - it's a candidate
+ pred_pass_list.append(pred_candidate)
+
+ if not pred_pass_list:
+ return ABORT_SEARCH
+
+ all_candidates = []
+ for pred_pass in pred_pass_list:
+ # recurse into the next pass
+ ifm_strat_data = self.search_ifm_streaming_body(pred_pass, self.feature_maps_not_in_fast_storage)
+
+ strat_data = self.search_all_but_one_predecessor(ps, pred_pass, ifm_strat_data)
+ for strat_opt in strat_data:
+
+ pred_pass_block_config = strat_opt[0].block_configs[-1]
+ rolling_buffer_dims = npu_performance.rolling_buffer_dims_from_passes(
+ self.arch, pred_pass, pred_pass_block_config, ps, block_config
+ )
+ if rolling_buffer_dims is None:
+ continue # this does not pack properly, skip it.
+
+ sram_used = 0
+ for tens in ps.inputs:
+ if tens != ifm_tensor:
+ if tens.mem_area == self.mem_area:
+ sram_used += tens.storage_size()
+
+ rolling_buffer_y, rolling_buffer_x = rolling_buffer_dims
+
+ rewrite_list = [
+ (
+ SchedulerRewrite.ChangeTensorSubPurpose,
+ ifm_tensor,
+ TensorSubPurpose.RollingBufferY,
+ rolling_buffer_y,
+ None,
+ ps,
+ )
+ ]
+ sram_used += ifm_tensor.storage_size_for_sub_purpose(
+ TensorSubPurpose.RollingBufferY, rolling_buffer_y, None
+ )
+
+ all_candidates.extend(self.append_sram_rewrite_list(sram_used, rewrite_list, [strat_opt]))
+
+ self.n_combinations_searched += len(all_candidates)
+ return all_candidates
+
+ def get_block_configs(self, ps):
+ if ps.placement != PassPlacement.Npu:
+ return [(1, 1, 1, 1)] # default
+
+ block_configs = find_block_configs_suitable_for_pass_and_shared_buffer(self.arch, ps)
+
+ # Take a limited number of the largest blocks
+ if self.arch.block_config_limit > 0:
+ # Sort by block area, followed by depth
+ block_configs.sort(key=lambda cfg: (cfg[0] * cfg[1]) << 8 | cfg[3], reverse=True)
+ bound = min(len(block_configs), self.arch.block_config_limit)
+ # We take 'n' from the fat end of the list, and 'n' from the thin end of the list.
+ tmp = block_configs[:bound]
+ tmp.extend(block_configs[max(bound, len(block_configs) - bound) :])
+ block_configs = tmp
+
+ return block_configs
+
+ def search_ifm_streaming_input(self, ps):
+ sram_used = 0
+ for tens in ps.inputs:
+ if tens.mem_area == self.mem_area:
+ sram_used += tens.storage_size()
+
+ return self.append_sram(sram_used, self.search_predecessors(ps))
+
+ def search_weight_streaming_output(self, ps):
+ strat_data = self.search_weight_streaming_body(ps)
+
+ sram_used = self.non_local_mem_usage[ps.time]
+ for tens in ps.outputs:
+ if tens.mem_area == self.mem_area:
+ sram_used += tens.storage_size()
+
+ return self.graduate_strat(SchedulingStrategy.WeightStream, sram_used, strat_data)
+
+ @lru_cache(maxsize=None)
+ def search_weight_streaming_body(self, ps):
+
+ strat_data = self.search_weight_streaming_input(ps)
+
+ res = []
+
+ all_block_configs = self.get_block_configs(ps)
+
+ for block_config in all_block_configs:
+
+ sram_used = 0
+ rewrite_list = []
+
+ for tens in ps.intermediates:
+ if tens.mem_area == self.mem_area:
+ if tens.purpose == TensorPurpose.Weights:
+ sram_used += tens.storage_size_for_sub_purpose(
+ TensorSubPurpose.DoubleBuffer, block_config[3]
+ )
+ rewrite_list.append(
+ (
+ SchedulerRewrite.ChangeTensorSubPurpose,
+ tens,
+ TensorSubPurpose.DoubleBuffer,
+ block_config[3],
+ None,
+ ps,
+ )
+ )
+ else:
+ sram_used += tens.storage_size()
+
+ metrics = npu_performance.performance_metrics_for_pass(
+ self.arch, ps, block_config, rewrite_list=rewrite_list
+ )
+
+ res.extend(
+ self.append_sram_pass_block_config_performance_metrics_rewrite_list(
+ sram_used, ps, block_config, metrics, rewrite_list, strat_data
+ )
+ )
+
+ self.n_combinations_searched += len(res)
+ res = self.filter_pareto_frontier(res, remove_equally_good_candidates=True)
+ return res
+
+ def search_weight_streaming_input(self, ps):
+ sram_used = 0
+ for tens in ps.inputs:
+ if tens.mem_area == self.mem_area:
+ sram_used += tens.storage_size()
+
+ return self.append_sram(sram_used, self.search_predecessors(ps))
+
+ def apply_result(self, strat_set, arch):
+ pass_to_cascaded_pass = dict()
+ for _, strat in strat_set.strats.items():
+ # rewrite the tensors that need this first. e.g. make rolling buffers
+ inputs = []
+ intermediates = []
+ outputs = []
+
+ for ps in strat.passes:
+ inputs += ps.inputs
+ intermediates += ps.intermediates
+ outputs += ps.outputs
+
+ for tens in set(inputs) & set(outputs):
+ # tensors that are in both sets are intermediates
+
+ # find pass with input/output tensor, and check if they are both placed on NPU
+ input_placement = None
+ output_placement = None
+ for ps in strat.passes:
+ if tens in ps.inputs:
+ input_placement = ps.placement
+ if tens in ps.outputs:
+ output_placement = ps.placement
+ if input_placement == output_placement == PassPlacement.Npu:
+ tens.set_format(TensorFormat.NHCWB16, arch)
+
+ intermediates.append(tens)
+ inputs.remove(tens)
+ outputs.remove(tens)
+
+ for rewrite_op, tens, sub_purpose, param_a, param_b, ps in strat.rewrite_list:
+ if rewrite_op == SchedulerRewrite.ChangeTensorSubPurpose:
+ tens.mem_area = self.arch.fast_storage_mem_area
+ tens.set_new_sub_purpose(sub_purpose, param_a, param_b)
+ else:
+ assert 0, "unknown rewrite_op " + str(rewrite_op)
+
+ is_element_wise = True
+ for ps in strat.passes:
+ assert ps.placement == strat.passes[0].placement
+ if not ps.is_element_wise:
+ is_element_wise = False
+ break
+
+ cascaded_pass = CascadedPass(
+ strat.passes[0].name,
+ strat.strat,
+ inputs,
+ intermediates,
+ outputs,
+ strat.passes,
+ strat.passes[0].placement,
+ is_element_wise,
+ )
+ assert strat.sram_used >= 0
+ cascaded_pass.sram_used = strat.sram_used
+
+ for idx, ps in enumerate(strat.passes):
+ assert ps not in pass_to_cascaded_pass
+ pass_to_cascaded_pass[ps] = cascaded_pass
+ ps.cascade = cascaded_pass
+ ps.block_config = strat.block_configs[idx]
+
+ if ps.placement == PassPlacement.Npu:
+ ps.shared_buffer = shared_buffer_allocation_for_pass_and_block_config(
+ self.arch, ps, ps.block_config
+ )
+ assert ps.shared_buffer is not None
+
+ for op in ps.ops:
+ subgraph = op.attrs.get("subgraph")
+ if subgraph:
+ subgraph.base_sram_used = cascaded_pass.sram_used
+
+ # all passes should have a cascaded pass now
+ if len(pass_to_cascaded_pass) != len(self.sg.passes):
+ print(
+ "mismatch: we have %d passes, but only %d have cascaded passes associated"
+ % (len(self.sg.passes), len(pass_to_cascaded_pass))
+ )
+ for ps in self.sg.passes:
+ if not ps in pass_to_cascaded_pass:
+ print("%3d pass missing cascaded pass %s" % (ps.time, ps))
+
+ assert len(pass_to_cascaded_pass) == len(self.sg.passes)
+ # we have all the passes, but we need to put them in order and build predecessor/successor links.
+
+ visit_pass_set = set()
+ cascaded_passes = []
+
+ def visit_pass(ps):
+ if ps in visit_pass_set:
+ return
+ visit_pass_set.add(ps)
+
+ cps = ps.cascade
+ dont_traverse = set(cps.passes)
+
+ for ps in cps.passes:
+ for pred in ps.predecessors:
+ if pred in dont_traverse:
+ continue
+ visit_pass(pred)
+
+ cascaded_passes.append(cps)
+
+ starting_passes = [ps for ps in self.sg.passes if not ps.successors]
+ for ps in starting_passes:
+ visit_pass(ps)
+
+ # reorder so startup init cascaded passes come first
+ def is_startup_cascaded_pass(cps):
+ if not cps.passes:
+ return False
+ return cps.placement == PassPlacement.StartupInit
+
+ cascaded_passes = [cps for cps in cascaded_passes if is_startup_cascaded_pass(cps)] + [
+ cps for cps in cascaded_passes if not is_startup_cascaded_pass(cps)
+ ]
+
+ self.sg.cascaded_passes = cascaded_passes
+ self.sg.build_cascaded_pass_links()
+
+
+def schedule_passes(nng, arch, options: SchedulerOptions):
+
+ for sg in nng.subgraphs:
+ sg.base_sram_used = 0
+
+ for sg in nng.subgraphs:
+ # re-entering the same nodes from different contexts requires us to
+ # build a simplified directed acyclic (DAG) version of the graph to
+ # use for traversal, rather than using a visit dictionary. this avoids
+ # recursing infinitely due to loops.
+ sg.build_pass_dag_predecessors()
+
+ dps = DynamicProgrammingScheduler(nng, sg, arch, arch.sram_size, options)
+
+ strat_set = dps.search()
+
+ dps.apply_result(strat_set, arch)
+
+ if options.verbose_schedule:
+ sg.print_cascaded_passes()