From 79d07d2cbf1c5013ab40bb46a6ccd4c569966536 Mon Sep 17 00:00:00 2001 From: Tim Hall Date: Mon, 27 Apr 2020 18:20:16 +0100 Subject: Add Vela codebase - Added modules ethosu.vela and ethosu.mlw_codec. - Added README and various configuration files. Change-Id: I3690f8c8f5966306ecddaeb2793c30ca9c6e2eee --- ethosu/vela/scheduler.py | 949 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 949 insertions(+) create mode 100644 ethosu/vela/scheduler.py (limited to 'ethosu/vela/scheduler.py') diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py new file mode 100644 index 00000000..c35c1566 --- /dev/null +++ b/ethosu/vela/scheduler.py @@ -0,0 +1,949 @@ +# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Description: +# The scheduler costs various strategies for scheduling the network in order to select the block configuration. + +import enum +from .nn_graph import ( + TensorPurpose, + TensorSubPurpose, + TensorFormat, + MemArea, + SchedulingStrategy, + CascadedPass, + PassPlacement, + SchedulerRewrite, + Operation, + NpuBlockType, +) +from . import live_range +import numpy as np +from . import npu_performance +from . import stats_writer +from .npu_performance import make_bandwidth_array, make_macs_array, make_cycles_array, make_metrics_arrays, PassCycles +import time, copy +from .high_level_command_stream_generator import calc_allowed_ofm_ifm_overlap_for_pass_list +from .shared_buffer_allocation import ( + find_block_configs_suitable_for_pass_and_shared_buffer, + shared_buffer_allocation_for_pass_and_block_config, +) +from functools import lru_cache + + +class ParetoMetric(enum.Enum): + BwCycMem = 1 + BwCycMemBlkH = 2 + + def __str__(self): + return self.name + + +class SchedulerOptions: + def __init__( + self, + use_cascading=True, + use_ifm_ofm_overlap=True, + verbose_schedule=False, + verbose_pareto_frontier_schedules=False, + use_ifm_streaming=True, + pareto_metric=ParetoMetric.BwCycMem, + ): + self.use_cascading = use_cascading + self.use_ifm_ofm_overlap = use_ifm_ofm_overlap + self.verbose_schedule = verbose_schedule + self.verbose_pareto_frontier_schedules = verbose_pareto_frontier_schedules + self.use_ifm_streaming = use_ifm_streaming + self.pareto_metric = pareto_metric + + def __str__(self): + return type(self).__name__ + ": " + str(self.__dict__) + + __repr__ = __str__ + + +class Strategy: + __slots__ = "strat", "param", "passes", "block_configs", "rewrite_list", "bws", "macs", "cycles", "sram_used" + + def __init__(self, strat, param, passes, block_configs, rewrite_list, bws, macs, cycles, sram_used): + self.strat = strat + self.param = param + self.passes = passes + self.block_configs = block_configs + self.rewrite_list = ( + rewrite_list # list of (SchedulerRewrite, Tensor, new sub purpose, purpose param a, purpose param b, pass) + ) + self.bws = bws + self.macs = macs + self.cycles = cycles + self.sram_used = sram_used + + def __eq__(self, other): + if 
self.strat != other.strat:
+            return False
+        if self.param != other.param:
+            return False
+        if self.block_configs != other.block_configs:
+            return False
+        if self.passes != other.passes:
+            return False
+        if (self.bws != other.bws).any():
+            return False
+        if (self.macs != other.macs).any():
+            return False
+        if (self.cycles != other.cycles).any():
+            return False
+        if self.sram_used != other.sram_used:
+            return False
+        return True
+
+    def empty(self):
+        return not self.passes
+
+    def key(self):
+        return self.passes[-1]
+
+    def clone(self):
+        return Strategy(
+            self.strat,
+            self.param,
+            self.passes,
+            self.block_configs,
+            self.rewrite_list,
+            self.bws,
+            self.macs,
+            self.cycles,
+            self.sram_used,
+        )
+
+    def __str__(self):
+        return "<scheduler.Strategy: %s %s %s %s %s %s %s>" % (
+            self.strat,
+            self.passes,
+            self.rewrite_list,
+            self.bws,
+            self.macs,
+            self.cycles,
+            self.sram_used,
+        )
+
+    __repr__ = __str__
+
+
+class StrategySet:
+    __slots__ = "strats", "bws", "macs", "cycles", "max_sram_used", "total_sram_used"
+
+    def __init__(self, strats=None):
+        if strats is None:
+            strats = dict()
+        self.strats = strats  # final pass in packed pass -> Strategy
+        self.bws, self.macs, self.cycles = make_metrics_arrays()
+        self.max_sram_used = 0
+        self.total_sram_used = 0
+
+    def update_statistics(self):
+        self.bws = make_bandwidth_array()
+        self.max_sram_used = 0
+        for ps, strat in self.strats.items():
+            self.bws += strat.bws
+            self.macs += strat.macs
+            self.cycles += strat.cycles
+            self.max_sram_used = max(self.max_sram_used, strat.sram_used)
+            self.total_sram_used += strat.sram_used
+
+    def clone_add_strategy(self, new_strat):
+        key = new_strat.key()
+        if key in self.strats:
+            assert new_strat == self.strats[key]
+            return self
+        else:
+            new_strats = dict(self.strats)
+            new_strats[key] = new_strat
+            new_set = StrategySet(new_strats)
+            new_set.bws = self.bws + new_strat.bws
+            new_set.macs = self.macs + new_strat.macs
+            new_set.cycles = self.cycles + new_strat.cycles
+            new_set.max_sram_used = max(self.max_sram_used, new_strat.sram_used)
+            new_set.total_sram_used = self.total_sram_used + new_strat.sram_used
+            return new_set
+
+    def __eq__(self, other):
+        if (self.bws != other.bws).any():
+            return False
+        if (self.macs != other.macs).any():
+            return False
+        if (self.cycles != other.cycles).any():
+            return False
+        if self.max_sram_used != other.max_sram_used:
+            return False
+        if self.total_sram_used != other.total_sram_used:
+            return False
+        if self.strats != other.strats:
+            return False
+        return True
+
+    def __str__(self):
+        return "<scheduler.StrategySet: max_sram_used=%s passes_covered=%s>" % (
+            self.max_sram_used,
+            list(ps.name for ps in self.strats),
+        )
+
+    __repr__ = __str__
+
+
+empty_strategy = Strategy(
+    SchedulingStrategy.Unknown, None, [], [], [], make_bandwidth_array(), make_macs_array(), make_cycles_array(), 0
+)
+INFINITY = 1e30
+
+ABORT_SEARCH = []
+
+
+def flatten_list_of_lists(lstlst):
+    lst = []
+    for v in lstlst:
+        lst.extend(v)
+    return lst
+
+
+class DynamicProgrammingScheduler:
+    def __init__(self, nng, sg, arch, sram_limit, options: SchedulerOptions):
+        self.nng = nng
+        self.sg = sg
+        self.arch = arch
+        self.sram_limit = sram_limit
+        self.options = copy.copy(options)
+        self.use_cascading = options.use_cascading
+        self.use_ifm_ofm_overlap = options.use_ifm_ofm_overlap
+
+        if self.arch.feature_map_storage_mem_area != MemArea.Sram:
+            # force off IFM/OFM overlap if IFMs and OFMs are not in SRAM
+            self.use_ifm_ofm_overlap = False
+
+        self.verbose_schedule = options.verbose_schedule
+        self.verbose_pareto_frontier_schedules = options.verbose_pareto_frontier_schedules
+
self.mem_area = MemArea.Sram + + self.bandwidth_weights = arch.bandwidth_weights + self.cycles_weight = arch.cycles_weight + self.max_sram_used_weight = arch.max_sram_used_weight + + self.n_combinations_searched = 0 + + self.feature_maps_not_in_fast_storage = ( + arch.tensor_storage_mem_area[TensorPurpose.FeatureMap] != arch.fast_storage_mem_area + ) + + self.pareto_max_candidates = 16 + + self.ifm_stream_npu_blocks = set( + (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise, NpuBlockType.Pooling,) + ) + + num_pareto_metrics = 4 + view_values = ",".join(["d"] * num_pareto_metrics) + order_values = ["f%d" % (idx,) for idx in range(num_pareto_metrics)] + + def pareto_metric(self, candidate): + strat, strat_set = candidate + total_cycles = strat.cycles[PassCycles.Total] + strat_set.cycles[PassCycles.Total] + bws = strat.bws + strat_set.bws + last_block_height = 0 + if self.options.pareto_metric == ParetoMetric.BwCycMemBlkH and len(strat.block_configs) > 0: + last_block_height = strat.block_configs[-1][0] + + return ( + np.tensordot(bws, self.bandwidth_weights, axes=3) + total_cycles * self.cycles_weight, + strat_set.max_sram_used, + strat.sram_used, + last_block_height, + ) + + def filter_pareto_frontier(self, candidates, remove_equally_good_candidates): + + candidates = [cand for cand in candidates if max(cand[0].sram_used, cand[1].max_sram_used) <= self.sram_limit] + + if len(candidates) <= 1: + return candidates + assert remove_equally_good_candidates + start = time.time() + pareto_vals = np.zeros((len(candidates), DynamicProgrammingScheduler.num_pareto_metrics)) + ids = np.arange(len(candidates), dtype=np.int32) + for idx, cand in enumerate(candidates): + pareto_vals[idx] = self.pareto_metric(cand) + + sort_order = np.argsort( + pareto_vals.view(DynamicProgrammingScheduler.view_values), + order=DynamicProgrammingScheduler.order_values, + axis=0, + kind="stable", + ).flatten() + pareto_vals = pareto_vals[sort_order] + ids = ids[sort_order] + + pareto_frontier = [] + while len(ids) > 0: + pareto_frontier.append(candidates[ids[0]]) + not_dominated_by_first = (pareto_vals < pareto_vals[0]).any(axis=1) + ids = ids[not_dominated_by_first] + pareto_vals = pareto_vals[not_dominated_by_first] + + if len(pareto_frontier) > self.pareto_max_candidates: + pareto_frontier = self.sort_by_candidate_metric(pareto_frontier) + pareto_frontier = pareto_frontier[: self.pareto_max_candidates] + + return pareto_frontier + + def candidate_metric(self, candidate): + strat, strat_set = candidate + max_sram_used = max(strat_set.max_sram_used, strat.sram_used) + bws = strat.bws + strat_set.bws + total_cycles = strat.cycles[PassCycles.Total] + strat_set.cycles[PassCycles.Total] + + return ( + max_sram_used * self.max_sram_used_weight + + np.tensordot(bws, self.bandwidth_weights, axes=3) + + total_cycles * self.cycles_weight + ) + + def sort_by_candidate_metric(self, candidate_list): + sorted_list = list(sorted(candidate_list, key=self.candidate_metric)) + return sorted_list + + def best_candidate(self, candidate_list): + if len(candidate_list) == 0: + return ABORT_SEARCH + if len(candidate_list) == 1: + return candidate_list[0] + sorted_list = self.sort_by_candidate_metric(candidate_list) + return sorted_list[0] + + def graduate_strat(self, strat_type, sram_used, old_strat_data): + res = [] + for old_strat, old_strat_set in old_strat_data: + if old_strat.sram_used + sram_used > self.sram_limit: + continue # This strategy is bad, drop it + if old_strat_set.max_sram_used > self.sram_limit: + continue # 
This strategy is bad, drop it + assert old_strat.strat == SchedulingStrategy.Unknown + + new_strat = old_strat.clone() + new_strat.strat = strat_type + new_strat.sram_used = old_strat.sram_used + sram_used + + if self.use_ifm_ofm_overlap: + overlap = calc_allowed_ofm_ifm_overlap_for_pass_list( + new_strat.strat, new_strat.passes, new_strat.block_configs + ) + new_strat.sram_used -= overlap + + new_strat_set = old_strat_set.clone_add_strategy(new_strat) + res.append((empty_strategy, new_strat_set)) + return self.filter_pareto_frontier(res, remove_equally_good_candidates=True) + + def append_sram(self, sram_used, old_strat_data): + res = [] + for old_strat, strat_set in old_strat_data: + assert old_strat.strat == SchedulingStrategy.Unknown + assert old_strat.sram_used == 0 + new_strat = old_strat.clone() + new_strat.sram_used = old_strat.sram_used + sram_used + + res.append((new_strat, strat_set)) + return res + + def append_sram_block_config_performance_metrics(self, sram_used, block_config, metrics, old_strat_data): + res = [] + for old_strat, strat_set in old_strat_data: + assert old_strat.strat == SchedulingStrategy.Unknown + new_strat = old_strat.clone() + bws, macs, cycles = metrics[:3] + + new_strat.sram_used = old_strat.sram_used + sram_used + new_strat.block_configs = old_strat.block_configs + [block_config] + new_strat.bws = old_strat.bws + bws + new_strat.macs = old_strat.macs + macs + new_strat.cycles = old_strat.cycles + cycles + new_strat.bws, new_strat.macs, new_strat.cycles = npu_performance.collate_stats_for_cascaded_pass( + self.arch, new_strat.bws, new_strat.macs, new_strat.cycles + ) + + res.append((new_strat, strat_set)) + return res + + def append_sram_pass_block_config_performance_metrics_rewrite_list( + self, sram_used, new_pass, block_config, metrics, rewrite_list, old_strat_data + ): + res = [] + for old_strat, strat_set in old_strat_data: + assert old_strat.strat == SchedulingStrategy.Unknown + new_strat = old_strat.clone() + bws, macs, cycles = metrics[:3] + new_strat.sram_used = old_strat.sram_used + sram_used + new_strat.block_configs = old_strat.block_configs + [block_config] + new_strat.bws = old_strat.bws + bws + new_strat.macs = old_strat.macs + macs + new_strat.cycles = old_strat.cycles + cycles + new_strat.passes = old_strat.passes + [new_pass] + new_strat.bws, new_strat.macs, new_strat.cycles = npu_performance.collate_stats_for_cascaded_pass( + self.arch, new_strat.bws, new_strat.macs, new_strat.cycles + ) + new_strat.rewrite_list = old_strat.rewrite_list + rewrite_list + res.append((new_strat, strat_set)) + return res + + def append_sram_rewrite_list(self, sram_used, rewrite_list, old_strat_data): + res = [] + for old_strat, strat_set in old_strat_data: + assert old_strat.strat == SchedulingStrategy.Unknown + new_strat = old_strat.clone() + new_strat.sram_used = old_strat.sram_used + sram_used + new_strat.rewrite_list = old_strat.rewrite_list + rewrite_list + res.append((new_strat, strat_set)) + return res + + def pass_to_strat(self, strat_data): + res = {} + for strat in strat_data[1].strats.values(): + for ps in strat.passes: + res[ps] = strat + return res + + def compatible_strats(self, a, b): + intersection = a.keys() & b.keys() + for k in intersection: + if a[k] != b[k]: + return False + return True + + def collate_strats_for_passes(self, all_passes): + if len(all_passes) == 0: + return [(empty_strategy, StrategySet(dict()))] + if len(all_passes) == 1: + return all_passes[0] # save some space in the common case + all_strands = 
[[self.pass_to_strat(strat_data) for strat_data in strand] for strand in all_passes] + prev_combos = [dict()] + for j, strand in enumerate(all_strands): + new_combos = [] + for i, alt in enumerate(strand): + for prev in prev_combos: + if self.compatible_strats(prev, alt): + cmb = dict(prev) + cmb.update(all_passes[j][i][1].strats) + new_combos.append(cmb) + prev_combos = new_combos + + res = [] + for d in prev_combos: + s = StrategySet(d) + s.update_statistics() + res.append((empty_strategy, s)) + return res + + def search_all_but_one_predecessor(self, ps, pred_pass, pred_pass_data): + # get the rest of the predecessors + other_predecessors = [pred for pred in ps.dag_predecessors if pred != pred_pass] + other_predecessor_data = self.search_pass_list(other_predecessors) + + # pred strat data has an incomplete strategy, which we need + # to continue on, whereas the other ones have completed strategies. + # we need to merge these, but keep the incomplete strategy too. + + res = [] + for pred_pass_strat, pred_pass_strat_set in pred_pass_data: + all_strats = [ + [(empty_strategy, pred_pass_strat_set)], # pred strat data but with a dummy empty strategy + other_predecessor_data, # this one is fine to use as-is + ] + collated_strat_data = self.collate_strats_for_passes(all_strats) + strat_data = [(pred_pass_strat, strat_set) for _, strat_set in collated_strat_data] + res.extend(strat_data) + return res + + def calc_non_local_mem_usage(self): + ignore_subgraph_input_output_tensors = self.sg.placement == PassPlacement.Cpu + range_set = live_range.extract_live_ranges_from_passes( + self.sg, + self.mem_area, + mark_output_tensors_overlapping_with_input_tensors=True, + ignore_subgraph_input_output_tensors=ignore_subgraph_input_output_tensors, + ) + range_dict = range_set.ranges + + # find which ranges overlap passes but aren't input/outputs of the passes. + # these won't be counted by the dynamic programming search and must be counted in manually. 
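+        # Build a per-timestep SRAM usage curve from the live ranges (plus the subgraph's
+        # base SRAM usage), then subtract each pass' own input/output/intermediate storage.
+        # Whatever remains at ps.time is SRAM held by tensors that are merely live across
+        # the pass (for example a feature map kept for a later consumer); every candidate
+        # schedule for that pass has to budget for this non-local amount on top of its own
+        # working set.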
+ end_pos = max(ps.time for ps in self.sg.passes) + 2 + mem_usage = np.zeros(end_pos) + self.sg.base_sram_used + non_local_mem_usage = np.zeros(end_pos, dtype=np.int64) + + for tens, rng in range_dict.items(): + storage_size = tens.storage_size() + assert tens.mem_area == self.mem_area + mem_usage[rng.start_time : rng.end_time] += storage_size + + for ps in self.sg.passes: + local_mem_usage = 0 + for tens in ps.inputs + ps.outputs + ps.intermediates: + if tens.mem_area != self.mem_area: + continue + + local_mem_usage += tens.storage_size() + + non_local_mem_usage[ps.time] = mem_usage[ps.time] - local_mem_usage + + self.non_local_mem_usage = non_local_mem_usage + + def search(self): + self.calc_non_local_mem_usage() + starting_passes = [ps for ps in self.sg.passes if not ps.successors] + strat_data = self.search_pass_list(starting_passes) + + _, best_set = self.best_candidate(strat_data) + + if self.verbose_pareto_frontier_schedules: + print( + "Scheduler searched %d combinations and found %d candidate schedules along the pareto frontier" + % (self.n_combinations_searched, len(strat_data,)) + ) + for idx, (_, strat_set) in enumerate(strat_data): + extra = "" + if strat_set == best_set: + extra = "(Best candidate)" + print("Candidate", idx, extra) + memory_used = {MemArea.Sram: strat_set.max_sram_used} + stats_writer.print_performance_metrics_for_strat( + self.arch, + "", + strat_set.cycles, + strat_set.macs, + strat_set.bws, + self.nng.batch_size, + memory_used, + len(self.sg.passes), + len(strat_set.strats), + ) + + return best_set + + def search_pass_list(self, pass_list): + all_strats = [] + for ps in pass_list: + strat = self.search_output(ps) + all_strats.append(strat) + strat_data = self.collate_strats_for_passes(all_strats) + for strd in strat_data: + for ps in pass_list: + assert ps in strd[1].strats # should have strategies for everything we asked to search + return strat_data + + def search_predecessors(self, ps): + + # protect against graphs with loops. 
collate_strats_for_passes will sort this out later so that + # we have strats for all passes + + pass_list = ps.dag_predecessors + strat_data = self.search_pass_list(pass_list) + + return strat_data + + @lru_cache(maxsize=None) + def search_output(self, ps): + + assert ps in self.sg.passes + candidate_list = [] + + candidate_list.extend(self.search_weight_streaming_output(ps)) + + if self.options.use_ifm_streaming: + candidate_list.extend(self.search_ifm_streaming_output(ps)) + + best = self.filter_pareto_frontier(candidate_list, remove_equally_good_candidates=True) + + if not best: + print( + "Warning: Dynamic search programming algorithm failed for pass %s, invoking fallback strategy" + % (ps.name,) + ) + return self.search_predecessors(ps) + + return best + + def search_ifm_streaming_output(self, ps): + if ps.placement != PassPlacement.Npu: + return ABORT_SEARCH + if ps.npu_block_type not in self.ifm_stream_npu_blocks: + return ABORT_SEARCH + strat_data = self.search_ifm_streaming_body(ps, False) + + sram_used = self.non_local_mem_usage[ps.time] + for tens in ps.outputs: + if tens.mem_area == self.mem_area: + sram_used += tens.storage_size() + + return self.graduate_strat(SchedulingStrategy.IfmStream, sram_used, strat_data) + + @lru_cache(maxsize=None) + def search_ifm_streaming_body(self, ps, force_outputs_to_fast_storage): + if ps.placement != PassPlacement.Npu: + return ABORT_SEARCH + if ps.npu_block_type not in self.ifm_stream_npu_blocks: + return ABORT_SEARCH + ifm_input_search_resuls = self.search_ifm_streaming_input(ps) + res = [] + + base_sram_used = 0 + for tens in ps.intermediates: + if tens.mem_area == self.mem_area: + base_sram_used += tens.storage_size() + + all_block_configs = self.get_block_configs(ps) + for block_config in all_block_configs: + all_strats = [] + + if self.use_cascading: + all_strats.extend(self.search_ifm_streaming_partial(ps, block_config)) + + all_strats.extend(ifm_input_search_resuls) + + rewrite_list = [] + sram_used = base_sram_used + + metrics = npu_performance.performance_metrics_for_pass( + self.arch, + ps, + block_config, + rewrite_list=rewrite_list, + force_outputs_to_fast_storage=force_outputs_to_fast_storage, + ) + + res.extend( + self.append_sram_pass_block_config_performance_metrics_rewrite_list( + sram_used, ps, block_config, metrics, rewrite_list, all_strats + ) + ) + + self.n_combinations_searched += len(res) + res = self.filter_pareto_frontier(res, remove_equally_good_candidates=True) + return res + + def search_ifm_streaming_partial(self, ps, block_config): + if ps.placement != PassPlacement.Npu: + return ABORT_SEARCH + + if len(ps.inputs) < 1: + return ABORT_SEARCH + + ifm_tensor = ps.ifm_tensor + + if ifm_tensor is None: + return ABORT_SEARCH + if ifm_tensor.purpose != TensorPurpose.FeatureMap: + return ABORT_SEARCH + if not ifm_tensor.storage_shape or len(ifm_tensor.storage_shape) != 4: + return ABORT_SEARCH + + pred_pass_list = [] + for pred_candidate in ps.dag_predecessors: + if len(pred_candidate.outputs) == 1 and pred_candidate.outputs[0] == ifm_tensor: + # we found a predecessor that produces this IFM tensor + if len(pred_candidate.successors) == 1 and pred_candidate.successors[0] == ps: + # and it only has one successor, namely us + if pred_candidate.placement == PassPlacement.Npu: + if pred_candidate.npu_block_type in self.ifm_stream_npu_blocks: + # and it is on the Npu and fusable - it's a candidate + pred_pass_list.append(pred_candidate) + + if not pred_pass_list: + return ABORT_SEARCH + + all_candidates = [] + for pred_pass 
in pred_pass_list: + # recurse into the next pass + ifm_strat_data = self.search_ifm_streaming_body(pred_pass, self.feature_maps_not_in_fast_storage) + + strat_data = self.search_all_but_one_predecessor(ps, pred_pass, ifm_strat_data) + for strat_opt in strat_data: + + pred_pass_block_config = strat_opt[0].block_configs[-1] + rolling_buffer_dims = npu_performance.rolling_buffer_dims_from_passes( + self.arch, pred_pass, pred_pass_block_config, ps, block_config + ) + if rolling_buffer_dims is None: + continue # this does not pack properly, skip it. + + sram_used = 0 + for tens in ps.inputs: + if tens != ifm_tensor: + if tens.mem_area == self.mem_area: + sram_used += tens.storage_size() + + rolling_buffer_y, rolling_buffer_x = rolling_buffer_dims + + rewrite_list = [ + ( + SchedulerRewrite.ChangeTensorSubPurpose, + ifm_tensor, + TensorSubPurpose.RollingBufferY, + rolling_buffer_y, + None, + ps, + ) + ] + sram_used += ifm_tensor.storage_size_for_sub_purpose( + TensorSubPurpose.RollingBufferY, rolling_buffer_y, None + ) + + all_candidates.extend(self.append_sram_rewrite_list(sram_used, rewrite_list, [strat_opt])) + + self.n_combinations_searched += len(all_candidates) + return all_candidates + + def get_block_configs(self, ps): + if ps.placement != PassPlacement.Npu: + return [(1, 1, 1, 1)] # default + + block_configs = find_block_configs_suitable_for_pass_and_shared_buffer(self.arch, ps) + + # Take a limited number of the largest blocks + if self.arch.block_config_limit > 0: + # Sort by block area, followed by depth + block_configs.sort(key=lambda cfg: (cfg[0] * cfg[1]) << 8 | cfg[3], reverse=True) + bound = min(len(block_configs), self.arch.block_config_limit) + # We take 'n' from the fat end of the list, and 'n' from the thin end of the list. + tmp = block_configs[:bound] + tmp.extend(block_configs[max(bound, len(block_configs) - bound) :]) + block_configs = tmp + + return block_configs + + def search_ifm_streaming_input(self, ps): + sram_used = 0 + for tens in ps.inputs: + if tens.mem_area == self.mem_area: + sram_used += tens.storage_size() + + return self.append_sram(sram_used, self.search_predecessors(ps)) + + def search_weight_streaming_output(self, ps): + strat_data = self.search_weight_streaming_body(ps) + + sram_used = self.non_local_mem_usage[ps.time] + for tens in ps.outputs: + if tens.mem_area == self.mem_area: + sram_used += tens.storage_size() + + return self.graduate_strat(SchedulingStrategy.WeightStream, sram_used, strat_data) + + @lru_cache(maxsize=None) + def search_weight_streaming_body(self, ps): + + strat_data = self.search_weight_streaming_input(ps) + + res = [] + + all_block_configs = self.get_block_configs(ps) + + for block_config in all_block_configs: + + sram_used = 0 + rewrite_list = [] + + for tens in ps.intermediates: + if tens.mem_area == self.mem_area: + if tens.purpose == TensorPurpose.Weights: + sram_used += tens.storage_size_for_sub_purpose( + TensorSubPurpose.DoubleBuffer, block_config[3] + ) + rewrite_list.append( + ( + SchedulerRewrite.ChangeTensorSubPurpose, + tens, + TensorSubPurpose.DoubleBuffer, + block_config[3], + None, + ps, + ) + ) + else: + sram_used += tens.storage_size() + + metrics = npu_performance.performance_metrics_for_pass( + self.arch, ps, block_config, rewrite_list=rewrite_list + ) + + res.extend( + self.append_sram_pass_block_config_performance_metrics_rewrite_list( + sram_used, ps, block_config, metrics, rewrite_list, strat_data + ) + ) + + self.n_combinations_searched += len(res) + res = self.filter_pareto_frontier(res, 
remove_equally_good_candidates=True) + return res + + def search_weight_streaming_input(self, ps): + sram_used = 0 + for tens in ps.inputs: + if tens.mem_area == self.mem_area: + sram_used += tens.storage_size() + + return self.append_sram(sram_used, self.search_predecessors(ps)) + + def apply_result(self, strat_set, arch): + pass_to_cascaded_pass = dict() + for _, strat in strat_set.strats.items(): + # rewrite the tensors that need this first. e.g. make rolling buffers + inputs = [] + intermediates = [] + outputs = [] + + for ps in strat.passes: + inputs += ps.inputs + intermediates += ps.intermediates + outputs += ps.outputs + + for tens in set(inputs) & set(outputs): + # tensors that are in both sets are intermediates + + # find pass with input/output tensor, and check if they are both placed on NPU + input_placement = None + output_placement = None + for ps in strat.passes: + if tens in ps.inputs: + input_placement = ps.placement + if tens in ps.outputs: + output_placement = ps.placement + if input_placement == output_placement == PassPlacement.Npu: + tens.set_format(TensorFormat.NHCWB16, arch) + + intermediates.append(tens) + inputs.remove(tens) + outputs.remove(tens) + + for rewrite_op, tens, sub_purpose, param_a, param_b, ps in strat.rewrite_list: + if rewrite_op == SchedulerRewrite.ChangeTensorSubPurpose: + tens.mem_area = self.arch.fast_storage_mem_area + tens.set_new_sub_purpose(sub_purpose, param_a, param_b) + else: + assert 0, "unknown rewrite_op " + str(rewrite_op) + + is_element_wise = True + for ps in strat.passes: + assert ps.placement == strat.passes[0].placement + if not ps.is_element_wise: + is_element_wise = False + break + + cascaded_pass = CascadedPass( + strat.passes[0].name, + strat.strat, + inputs, + intermediates, + outputs, + strat.passes, + strat.passes[0].placement, + is_element_wise, + ) + assert strat.sram_used >= 0 + cascaded_pass.sram_used = strat.sram_used + + for idx, ps in enumerate(strat.passes): + assert ps not in pass_to_cascaded_pass + pass_to_cascaded_pass[ps] = cascaded_pass + ps.cascade = cascaded_pass + ps.block_config = strat.block_configs[idx] + + if ps.placement == PassPlacement.Npu: + ps.shared_buffer = shared_buffer_allocation_for_pass_and_block_config( + self.arch, ps, ps.block_config + ) + assert ps.shared_buffer is not None + + for op in ps.ops: + subgraph = op.attrs.get("subgraph") + if subgraph: + subgraph.base_sram_used = cascaded_pass.sram_used + + # all passes should have a cascaded pass now + if len(pass_to_cascaded_pass) != len(self.sg.passes): + print( + "mismatch: we have %d passes, but only %d have cascaded passes associated" + % (len(self.sg.passes), len(pass_to_cascaded_pass)) + ) + for ps in self.sg.passes: + if not ps in pass_to_cascaded_pass: + print("%3d pass missing cascaded pass %s" % (ps.time, ps)) + + assert len(pass_to_cascaded_pass) == len(self.sg.passes) + # we have all the passes, but we need to put them in order and build predecessor/successor links. 
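+        # The nested visit_pass helper below walks depth-first from the output passes
+        # (those without successors), following predecessor links but never re-entering
+        # passes that belong to the same cascaded pass. Appending each cascaded pass only
+        # after its predecessors have been visited yields a topological order, which is
+        # then adjusted so that StartupInit cascaded passes come first.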
+ + visit_pass_set = set() + cascaded_passes = [] + + def visit_pass(ps): + if ps in visit_pass_set: + return + visit_pass_set.add(ps) + + cps = ps.cascade + dont_traverse = set(cps.passes) + + for ps in cps.passes: + for pred in ps.predecessors: + if pred in dont_traverse: + continue + visit_pass(pred) + + cascaded_passes.append(cps) + + starting_passes = [ps for ps in self.sg.passes if not ps.successors] + for ps in starting_passes: + visit_pass(ps) + + # reorder so startup init cascaded passes come first + def is_startup_cascaded_pass(cps): + if not cps.passes: + return False + return cps.placement == PassPlacement.StartupInit + + cascaded_passes = [cps for cps in cascaded_passes if is_startup_cascaded_pass(cps)] + [ + cps for cps in cascaded_passes if not is_startup_cascaded_pass(cps) + ] + + self.sg.cascaded_passes = cascaded_passes + self.sg.build_cascaded_pass_links() + + +def schedule_passes(nng, arch, options: SchedulerOptions): + + for sg in nng.subgraphs: + sg.base_sram_used = 0 + + for sg in nng.subgraphs: + # re-entering the same nodes from different contexts requires us to + # build a simplified directed acyclic (DAG) version of the graph to + # use for traversal, rather than using a visit dictionary. this avoids + # recursing infinitely due to loops. + sg.build_pass_dag_predecessors() + + dps = DynamicProgrammingScheduler(nng, sg, arch, arch.sram_size, options) + + strat_set = dps.search() + + dps.apply_result(strat_set, arch) + + if options.verbose_schedule: + sg.print_cascaded_passes() -- cgit v1.2.1
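The Pareto pruning used by DynamicProgrammingScheduler.filter_pareto_frontier above can be
summarised independently of the Vela data structures: sort the candidates by their metric
vector, keep the best remaining candidate, drop every candidate it dominates, and repeat.
The standalone sketch below shows that idea on plain tuples; the name pareto_filter and the
toy metrics are illustrative only (not part of the Vela code base), and np.lexsort stands in
for the structured-array argsort used in the scheduler.

import numpy as np


def pareto_filter(candidates, metric):
    # Keep only candidates that are not dominated (worse or equal in every metric)
    # by another candidate; all metrics are minimised.
    vals = np.array([metric(c) for c in candidates], dtype=float)
    ids = np.lexsort(vals.T[::-1])  # stable sort: the first metric is the primary key
    vals = vals[ids]
    frontier = []
    while len(ids) > 0:
        frontier.append(candidates[ids[0]])  # best remaining candidate is on the frontier
        # survivors must be strictly better than the chosen candidate in at least one metric
        not_dominated = (vals < vals[0]).any(axis=1)
        ids = ids[not_dominated]
        vals = vals[not_dominated]
    return frontier


# Toy usage: minimise (bandwidth, sram) over three hypothetical schedules.
schedules = [("a", 10.0, 5.0), ("b", 12.0, 4.0), ("c", 11.0, 6.0)]
print(pareto_filter(schedules, metric=lambda s: (s[1], s[2])))  # "c" is dominated by "a"

In the scheduler the same loop runs over (strategy, strategy-set) pairs, with the metric
vector produced by pareto_metric and an SRAM-limit filter applied beforehand.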