From d8339a75c9b655c0507e34238078fdad068b4023 Mon Sep 17 00:00:00 2001 From: Tim Hall Date: Thu, 27 May 2021 18:49:40 +0100 Subject: MLBEDSW-4034: New Scheduler Size or Performance Optimisation - Merged dev/scheduler at 83639f90e8c828f70de6e29142355a940224959b Signed-off-by: Tim Hall Change-Id: I0050529d4b42da93768c7264296434dd877fb5b4 --- ethosu/vela/scheduler.py | 2013 ++++++++++++++++++++++------------------------ 1 file changed, 958 insertions(+), 1055 deletions(-) (limited to 'ethosu/vela/scheduler.py') diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py index 65d3313b..00a4dfc7 100644 --- a/ethosu/vela/scheduler.py +++ b/ethosu/vela/scheduler.py @@ -1,4 +1,4 @@ -# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved. +# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved. # # SPDX-License-Identifier: Apache-2.0 # @@ -13,1156 +13,1059 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# # Description: -# The scheduler costs various strategies for scheduling the network in order to select the block configuration. +# The scheduler creates and searches for an optimal plan for the network, selecting block configurations and +# subdivisions for the Operators import copy -import enum -from functools import lru_cache - -import numpy as np +from enum import auto +from enum import IntEnum +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple from . import live_range from . import npu_performance -from . import stats_writer +from . import tensor_allocation +from . import weight_compressor +from .architecture_allocator import ArchitectureBlockConfig +from .architecture_allocator import find_block_config +from .architecture_allocator import get_ifm_area_required +from .architecture_allocator import to_upscale +from .architecture_features import ArchitectureFeatures +from .architecture_features import Block +from .cascade_builder import CascadeBuilder +from .cascade_builder import CascadeInfo from .data_type import DataType -from .high_level_command_stream_generator import calc_allowed_ofm_ifm_overlap_for_pass_list from .nn_graph import CascadedPass +from .nn_graph import Graph +from .nn_graph import Pass from .nn_graph import PassPlacement -from .nn_graph import SchedulerRewrite from .nn_graph import SchedulingStrategy -from .npu_performance import make_bandwidth_array -from .npu_performance import make_cycles_array -from .npu_performance import make_metrics_arrays -from .npu_performance import PassCycles +from .nn_graph import Subgraph +from .numeric_util import round_down +from .numeric_util import round_up from .operation import NpuBlockType from .operation import Op -from .operation import Operation -from .shared_buffer_allocation import find_block_configs_suitable_for_pass_and_shared_buffer -from .shared_buffer_allocation import shared_buffer_allocation_for_pass_and_block_config +from .shape4d import Shape4D from .tensor import MemArea from .tensor import MemType +from .tensor import Tensor from .tensor import TensorFormat from .tensor import TensorPurpose from .tensor import TensorSubPurpose -class ParetoMetric(enum.Enum): - BwCycMem = 1 - BwCycMemBlkH = 2 +def shape_for_format(shape: Shape4D, tensor_format: TensorFormat) -> Shape4D: + if tensor_format == TensorFormat.NHCWB16: + return shape.with_depth(round_up(shape.depth, 16)) - def __str__(self): - return self.name 
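# A minimal standalone sketch of the NHCWB16 depth rounding that shape_for_format
# (added above) performs, assuming round_up(x, m) from numeric_util rounds x up to
# the nearest multiple of m; _round_up below is a hypothetical stand-in, not a Vela API.
def _round_up(x: int, multiple: int) -> int:
    return ((x + multiple - 1) // multiple) * multiple

assert _round_up(35, 16) == 48  # a 35-channel NHCWB16 feature map is stored as 48 channels
assert _round_up(32, 16) == 32  # depths already aligned to 16 are unchanged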
+ return shape -class SchedulerOptions: - def __init__( - self, - use_cascading=True, - verbose_schedule=False, - verbose_pareto_frontier_schedules=False, - use_ifm_streaming=True, - pareto_metric=ParetoMetric.BwCycMem, - use_nhcwb16_between_cascaded_passes=True, - cache_bias_scale_tensor=True, - ): - self.use_cascading = use_cascading - self.verbose_schedule = verbose_schedule - self.verbose_pareto_frontier_schedules = verbose_pareto_frontier_schedules - self.use_ifm_streaming = use_ifm_streaming - self.pareto_metric = pareto_metric - self.use_nhcwb16_between_cascaded_passes = use_nhcwb16_between_cascaded_passes - self.cache_bias_scale_tensor = cache_bias_scale_tensor +class OptimizationStrategy(IntEnum): + """Enum defining the different optimization strategies for the Scheduler""" - def __str__(self): - return type(self).__name__ + ": " + str(self.__dict__) + Size = auto() + Performance = auto() - __repr__ = __str__ + def __str__(self): + return self.name -class Strategy: - __slots__ = "strat", "param", "passes", "block_configs", "rewrite_list", "bws", "macs", "cycles", "sram_used" +class SchedulerOpInfo: + """Contains metadata about a SchedulerOperation that is unique to one Schedule""" - def __init__(self, strat, param, passes, block_configs, rewrite_list, bws, macs, cycles, sram_used): - self.strat = strat - self.param = param - self.passes = passes - self.block_configs = block_configs - self.rewrite_list = ( - rewrite_list # list of (SchedulerRewrite, Tensor, new sub purpose, purpose param a, purpose param b, pass) - ) - self.bws = bws - self.macs = macs - self.cycles = cycles - self.sram_used = sram_used - - def __eq__(self, other): - if self.strat != other.strat: - return False - if self.param != other.param: - return False - if self.block_configs != other.block_configs: - return False - if self.passes != other.passes: - return False - if (self.bws != other.bws).any(): - return False - if self.macs != other.macs: - return False - if (self.cycles != other.cycles).any(): - return False - if self.sram_used != other.sram_used: - return False - return True - - def empty(self): - return not self.passes - - def key(self): - return self.passes[-1] - - def clone(self): - return Strategy( - self.strat, - self.param, - self.passes, - self.block_configs, - self.rewrite_list, - self.bws, - self.macs, - self.cycles, - self.sram_used, - ) + def __init__( + self, + block_config: ArchitectureBlockConfig, + weights_size: int, + stripe_input: Shape4D, + stripe_input2: Optional[Shape4D], + stripe: Shape4D, + ): + self.block_config = block_config + self.weights_size = weights_size + self.stripe_input = stripe_input + self.stripe_input2 = stripe_input2 + self.stripe = stripe + self.cascade = 0 # Assigned by CascadeBuilder. 
0 means not part of a cascade + self.time_index = None # Set by update_op_memory_snapshot + self.ofm_depth_slices: List[int] = [0, stripe.depth] + self.npu_weights_tensor = None + self.buffered_weight_tensor = None + self.cycles = None + self.slack_buffering_cycles = 0 + self.slack_buffering_memory = 0 + self.full_weight_transfer_cycles = 0 + + def copy(self): + res = SchedulerOpInfo(self.block_config, self.weights_size, self.stripe_input, self.stripe_input2, self.stripe,) + res.cascade = self.cascade + return res def __str__(self): - return "" % ( - self.strat, - self.passes, - self.rewrite_list, - self.bws, - self.macs, - self.cycles, - self.sram_used, + res = f"\t\tBlock Config = {self.block_config}\n" + res += f"\t\tOFM Block = {self.block_config.ofm_block}\n" + res += f"\t\tIFM Stripe = {self.stripe_input}\n" + res += f"\t\tIFM2 Stripe = {self.stripe_input2}\n" + res += f"\t\tOFM Stripe = {self.stripe}\n" + res += f"\t\tEncoded Weights = {self.npu_weights_tensor and len(self.npu_weights_tensor.buffer)} bytes\n" + res += ( + f"\t\tWeight buffer = {self.buffered_weight_tensor and self.buffered_weight_tensor.storage_size()} bytes\n" ) + res += f"\t\tDepth slices = {self.ofm_depth_slices}\n" + res += f"\t\tAssigned Cascade = {self.cascade}" + return res - __repr__ = __str__ +class SchedulerOptions: + """Contains options for the Scheduler""" -class StrategySet: - __slots__ = "strats", "bws", "macs", "cycles", "max_sram_used", "total_sram_used" - - def __init__(self, strats=None): - if strats is None: - strats = dict() - self.strats = strats # final pass in packed pass -> Strategy - self.bws, self.macs, self.cycles = make_metrics_arrays() - self.max_sram_used = 0 - self.total_sram_used = 0 - - def update_statistics(self): - self.bws = make_bandwidth_array() - self.max_sram_used = 0 - for ps, strat in self.strats.items(): - self.bws += strat.bws - self.macs += strat.macs - self.cycles += strat.cycles - self.max_sram_used = max(self.max_sram_used, strat.sram_used) - self.total_sram_used += strat.sram_used - - def clone_add_strategy(self, new_strat): - key = new_strat.key() - if key in self.strats: - assert new_strat == self.strats[key] - return self - else: - new_strats = dict(self.strats) - new_strats[key] = new_strat - new_set = StrategySet(new_strats) - new_set.bws = self.bws + new_strat.bws - new_set.macs = self.macs + new_strat.macs - new_set.cycles = self.cycles + new_strat.cycles - new_set.max_sram_used = max(self.max_sram_used, new_strat.sram_used) - new_set.total_sram_used = self.total_sram_used + new_strat.sram_used - return new_set - - def __eq__(self, other): - if (self.bws != other.bws).any(): - return False - if self.macs != other.macs: - return False - if (self.cycles != other.cycles).any(): - return False - if self.max_sram_used != other.max_sram_used: - return False - if self.total_sram_used != other.total_sram_used: - return False - if self.strats != other.strats: - return False - return True + def __init__( + self, optimization_strategy, sram_target, verbose_schedule, + ): + self.optimization_strategy = optimization_strategy + self.optimization_sram_limit = sram_target + self.verbose_schedule = verbose_schedule - def __str__(self): - return "" % ( - self.max_sram_used, - list(ps.name for ps in self.strats), - ) + def __str__(self) -> str: + return f"{type(self).__name__}: {str(self.__dict__)}" __repr__ = __str__ -empty_strategy = Strategy( - SchedulingStrategy.Unknown, None, [], [], [], make_bandwidth_array(), 0, make_cycles_array(), 0 -) -INFINITY = 1e30 +class 
SchedulerTensor: + def __init__(self, shape, dt, mem_area, _format): + self.dtype = dt + self.mem_area = mem_area + self.shape = shape + self.format = _format + self.connection = None -ABORT_SEARCH = [] +class SchedulerOperation: + """Scheduler internal representation of 'Operation' + This class can be seen as a node within the Scheduler Graph representation + """ -def flatten_list_of_lists(lstlst): - lst = [] - for v in lstlst: - lst.extend(v) - return lst - - -class DynamicProgrammingScheduler: - def __init__(self, nng, sg, arch, sram_limit, options: SchedulerOptions): - self.nng = nng - self.sg = sg + def __init__(self, ps: Pass, arch: ArchitectureFeatures, nng: Graph): self.arch = arch - self.sram_limit = sram_limit - self.options = copy.copy(options) - self.use_cascading = options.use_cascading - - if self.arch.feature_map_storage_mem_area != MemArea.Sram: - self.use_ifm_ofm_overlap = False # force off IFM/OFM overlap if IFMs and OFMs are not in the SRAM - else: - self.use_ifm_ofm_overlap = True - - self.verbose_schedule = options.verbose_schedule - self.verbose_pareto_frontier_schedules = options.verbose_pareto_frontier_schedules - self.mem_area = MemArea.Sram - - self.bandwidth_weights = arch.bandwidth_weights - self.cycles_weight = arch.cycles_weight - self.max_sram_used_weight = arch.max_sram_used_weight - - self.n_combinations_searched = 0 - - self.pareto_max_candidates = 16 - - self.ifm_stream_npu_blocks = set( - (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise, NpuBlockType.Pooling,) - ) - - num_pareto_metrics = 4 - view_values = ",".join(["d"] * num_pareto_metrics) - order_values = ["f%d" % (idx,) for idx in range(num_pareto_metrics)] - - def pareto_metric(self, candidate): - strat, strat_set = candidate - total_cycles = strat.cycles[PassCycles.Total] + strat_set.cycles[PassCycles.Total] - bws = strat.bws + strat_set.bws - last_block_height = 0 - if self.options.pareto_metric == ParetoMetric.BwCycMemBlkH and len(strat.block_configs) > 0: - last_block_height = strat.block_configs[-1][0] - - return ( - np.tensordot(bws, self.bandwidth_weights, axes=3) + total_cycles * self.cycles_weight, - strat_set.max_sram_used, - strat.sram_used, - last_block_height, - ) - - def filter_pareto_frontier(self, candidates, remove_equally_good_candidates): - - candidates = [cand for cand in candidates if max(cand[0].sram_used, cand[1].max_sram_used) <= self.sram_limit] - - if len(candidates) <= 1: - return candidates - assert remove_equally_good_candidates - pareto_vals = np.zeros((len(candidates), DynamicProgrammingScheduler.num_pareto_metrics)) - ids = np.arange(len(candidates), dtype=np.int32) - for idx, cand in enumerate(candidates): - pareto_vals[idx] = self.pareto_metric(cand) - - sort_order = np.argsort( - pareto_vals.view(DynamicProgrammingScheduler.view_values), - order=DynamicProgrammingScheduler.order_values, - axis=0, - kind="stable", - ).flatten() - pareto_vals = pareto_vals[sort_order] - ids = ids[sort_order] - - pareto_frontier = [] - while len(ids) > 0: - pareto_frontier.append(candidates[ids[0]]) - not_dominated_by_first = (pareto_vals < pareto_vals[0]).any(axis=1) - ids = ids[not_dominated_by_first] - pareto_vals = pareto_vals[not_dominated_by_first] - - if len(pareto_frontier) > self.pareto_max_candidates: - pareto_frontier = self.sort_by_candidate_metric(pareto_frontier) - pareto_frontier = pareto_frontier[: self.pareto_max_candidates] - - return pareto_frontier - - def candidate_metric(self, candidate): - strat, strat_set = candidate - max_sram_used = 
max(strat_set.max_sram_used, strat.sram_used) - bws = strat.bws + strat_set.bws - total_cycles = strat.cycles[PassCycles.Total] + strat_set.cycles[PassCycles.Total] - - return ( - max_sram_used * self.max_sram_used_weight - + np.tensordot(bws, self.bandwidth_weights, axes=3) - + total_cycles * self.cycles_weight + self.parent_ps = ps + self.parent_op = ps.primary_op + self.name = ps.primary_op.name + self.op_type = ps.primary_op.type + self.activation = ps.primary_op.activation + self.kernel = ps.primary_op.kernel + self.resampling_mode = ps.primary_op.ifm.resampling_mode + self.uses_scalar = ps.primary_op.ifm2 is not None and ( + ps.primary_op.ifm.shape == [] or ps.primary_op.ifm2.shape == [] ) + self.ifm_ublock = arch.ifm_ublock - def sort_by_candidate_metric(self, candidate_list): - sorted_list = list(sorted(candidate_list, key=self.candidate_metric)) - return sorted_list - - def best_candidate(self, candidate_list): - if len(candidate_list) == 0: - return ABORT_SEARCH - if len(candidate_list) == 1: - return candidate_list[0] - sorted_list = self.sort_by_candidate_metric(candidate_list) - return sorted_list[0] - - def graduate_strat(self, strat_type, sram_used, old_strat_data): - res = [] - for old_strat, old_strat_set in old_strat_data: - if old_strat.sram_used + sram_used > self.sram_limit: - continue # This strategy is bad, drop it - if old_strat_set.max_sram_used > self.sram_limit: - continue # This strategy is bad, drop it - assert old_strat.strat == SchedulingStrategy.Unknown - - new_strat = old_strat.clone() - new_strat.strat = strat_type - new_strat.sram_used = old_strat.sram_used + sram_used - - if self.use_ifm_ofm_overlap: - overlap = calc_allowed_ofm_ifm_overlap_for_pass_list( - new_strat.strat, new_strat.passes, new_strat.block_configs - ) - new_strat.sram_used -= overlap - - new_strat_set = old_strat_set.clone_add_strategy(new_strat) - res.append((empty_strategy, new_strat_set)) - return self.filter_pareto_frontier(res, remove_equally_good_candidates=True) + self.ifm = SchedulerTensor(ps.ifm_shapes[0], ps.ifm_tensor.dtype, ps.ifm_tensor.mem_area, ps.ifm_tensor.format,) - def append_sram(self, sram_used, old_strat_data): - res = [] - for old_strat, strat_set in old_strat_data: - assert old_strat.strat == SchedulingStrategy.Unknown - assert old_strat.sram_used == 0 - new_strat = old_strat.clone() - new_strat.sram_used = old_strat.sram_used + sram_used - - res.append((new_strat, strat_set)) - return res - - def append_sram_block_config_performance_metrics(self, sram_used, block_config, metrics, old_strat_data): - res = [] - for old_strat, strat_set in old_strat_data: - assert old_strat.strat == SchedulingStrategy.Unknown - new_strat = old_strat.clone() - bws, macs, cycles = metrics[:3] - - new_strat.sram_used = old_strat.sram_used + sram_used - new_strat.block_configs = old_strat.block_configs + [block_config] - new_strat.bws = old_strat.bws + bws - new_strat.macs = old_strat.macs + macs - new_strat.cycles = old_strat.cycles + cycles - new_strat.bws, new_strat.macs, new_strat.cycles = npu_performance.collate_stats_for_cascaded_pass( - self.arch, new_strat.bws, new_strat.macs, new_strat.cycles + self.ifm2 = None + if ps.ifm2_tensor: + self.ifm2 = SchedulerTensor( + ps.ifm_shapes[1], ps.ifm2_tensor.dtype, ps.ifm2_tensor.mem_area, ps.ifm2_tensor.format, ) - res.append((new_strat, strat_set)) - return res + self.ofm = SchedulerTensor(ps.ofm_shapes[0], ps.ofm_tensor.dtype, ps.ofm_tensor.mem_area, ps.ofm_tensor.format,) - def 
append_sram_pass_block_config_performance_metrics_rewrite_list( - self, sram_used, new_pass, block_config, metrics, rewrite_list, old_strat_data - ): - res = [] - for old_strat, strat_set in old_strat_data: - assert old_strat.strat == SchedulingStrategy.Unknown - new_strat = old_strat.clone() - bws, macs, cycles = metrics[:3] - new_strat.sram_used = old_strat.sram_used + sram_used - new_strat.block_configs = old_strat.block_configs + [block_config] - new_strat.bws = old_strat.bws + bws - new_strat.macs = old_strat.macs + macs - new_strat.cycles = old_strat.cycles + cycles - new_strat.passes = old_strat.passes + [new_pass] - new_strat.bws, new_strat.macs, new_strat.cycles = npu_performance.collate_stats_for_cascaded_pass( - self.arch, new_strat.bws, new_strat.macs, new_strat.cycles - ) - new_strat.rewrite_list = old_strat.rewrite_list + rewrite_list - res.append((new_strat, strat_set)) - return res + # Input volume width and height required to produce the smallest possible stripe + self.min_stripe_input_w, self.min_stripe_input_h = self._calculate_min_stripe_input() - def append_sram_rewrite_list(self, sram_used, rewrite_list, old_strat_data): - res = [] - for old_strat, strat_set in old_strat_data: - assert old_strat.strat == SchedulingStrategy.Unknown - new_strat = old_strat.clone() - new_strat.sram_used = old_strat.sram_used + sram_used - new_strat.rewrite_list = old_strat.rewrite_list + rewrite_list - res.append((new_strat, strat_set)) - return res + # Flags that marks whether this SchedulerOperation requires full IFM/OFM + self.requires_full_ifm = False + self.requires_full_ifm2 = False + self.requires_full_ofm = False - def pass_to_strat(self, strat_data): - res = {} - for strat in strat_data[1].strats.values(): - for ps in strat.passes: - res[ps] = strat - return res + self.index = 0 - def compatible_strats(self, a, b): - intersection = a.keys() & b.keys() - for k in intersection: - if a[k] != b[k]: - return False - return True - - def collate_strats_for_passes(self, all_passes): - if len(all_passes) == 0: - return [(empty_strategy, StrategySet(dict()))] - if len(all_passes) == 1: - return all_passes[0] # save some space in the common case - all_strands = [[self.pass_to_strat(strat_data) for strat_data in strand] for strand in all_passes] - prev_combos = [dict()] - for j, strand in enumerate(all_strands): - new_combos = [] - for i, alt in enumerate(strand): - for prev in prev_combos: - if self.compatible_strats(prev, alt): - cmb = dict(prev) - cmb.update(all_passes[j][i][1].strats) - new_combos.append(cmb) - prev_combos = new_combos - - res = [] - for d in prev_combos: - s = StrategySet(d) - s.update_statistics() - res.append((empty_strategy, s)) - return res + def add_ifm_connection(self, conn: "Connection"): + """Add input connection to another SchedulerOperation or Subgraph Input""" + conn.consumers.append(self) + self.ifm.connection = conn - def search_all_but_one_predecessor(self, ps, pred_pass, pred_pass_data): - # get the rest of the predecessors - other_predecessors = [pred for pred in ps.dag_predecessors if pred != pred_pass] - other_predecessor_data = self.search_pass_list(other_predecessors) - - # pred strat data has an incomplete strategy, which we need - # to continue on, whereas the other ones have completed strategies. - # we need to merge these, but keep the incomplete strategy too. 
- - res = [] - for pred_pass_strat, pred_pass_strat_set in pred_pass_data: - all_strats = [ - [(empty_strategy, pred_pass_strat_set)], # pred strat data but with a dummy empty strategy - other_predecessor_data, # this one is fine to use as-is - ] - collated_strat_data = self.collate_strats_for_passes(all_strats) - strat_data = [(pred_pass_strat, strat_set) for _, strat_set in collated_strat_data] - res.extend(strat_data) - return res + def add_ifm2_connection(self, conn: "Connection"): + """Add input connection to another SchedulerOperation or Subgraph Input""" + if self.ifm2: + conn.consumers.append(self) + self.ifm2.connection = conn + else: + assert False, f"Trying to set an IFM2 Connection to {self} which has no IFM2" + + def add_ofm_connection(self, conn: "Connection"): + """Add output connection to another SchedulerOperation or Subgraph Output""" + conn.producers.append(self) + self.ofm.connection = conn + + def get_dependants(self): + """Returns a list of the Ops that depend on this Operation's OFM""" + return self.ofm.connection.consumers + + def ifm_size_in_bytes(self) -> int: + """Returns size of the IFM in bytes""" + ifm_storage_shape = shape_for_format(self.ifm.shape, self.ifm.format) + return round_up(ifm_storage_shape.elements() * self.ifm.dtype.size_in_bytes(), Tensor.AllocationQuantum) + + def ifm2_size_in_bytes(self) -> int: + """Returns size of the IFM2 in bytes""" + if self.ifm2: + ifm2_storage_shape = shape_for_format(self.ifm2.shape, self.ifm2.format) + return round_up(ifm2_storage_shape.elements() * self.ifm2.dtype.size_in_bytes(), Tensor.AllocationQuantum) + + return 0 + + def ofm_size_in_bytes(self) -> int: + """Returns size of the OFM in bytes""" + ofm_storage_shape = shape_for_format(self.ofm.shape, self.ofm.format) + return round_up(ofm_storage_shape.elements() * self.ofm.dtype.size_in_bytes(), Tensor.AllocationQuantum) + + def create_scheduler_info(self, nng: Graph, stripe: Shape4D) -> SchedulerOpInfo: + """Returns schedule info about this SchedulerOperation based on how many ofm elements it should produce""" + ifm_shape = self.ifm.shape + ifm2_shape = self.ifm2 and self.ifm2.shape + ofm_shape = stripe + + if ofm_shape != self.ofm.shape: + # Striped Op - Need to calculate stripe input volume + stripe_input_w, stripe_input_h = self._get_stripe_input_requirement(stripe) + # Ensure stripe input volume is within the full IFM volume + stripe_input_h = min(stripe_input_h, self.ifm.shape.height) + stripe_input_w = min(stripe_input_w, self.ifm.shape.width) + ifm_shape = ifm_shape.with_hw(stripe_input_h, stripe_input_w) + + if self.ifm2: + stripe_input2_h = min(stripe_input_h, self.ifm2.shape.height) + stripe_input2_w = min(stripe_input_w, self.ifm2.shape.width) + ifm2_shape = ifm2_shape.with_hw(stripe_input2_h, stripe_input2_w) + + block_config = self._get_block_config(ifm_shape, ifm2_shape, self.uses_scalar, ofm_shape) + + scheduler_op_info = SchedulerOpInfo(block_config, 0, ifm_shape, ifm2_shape, ofm_shape) + if self.parent_op.weights: + # Default full-depth weight encoding with no buffering + scheduler_op_info.npu_weights_tensor = weight_compressor.encode_weight_and_scale_tensor( + self.arch, + self.parent_op, + self.parent_op.weights, + self.parent_op.bias, + self.kernel, + block_config, + [0, self.ofm.shape.depth], + ) - def calc_non_local_mem_usage(self): - ignore_subgraph_input_output_tensors = self.sg.placement == PassPlacement.Cpu - range_set = live_range.extract_live_ranges_from_passes( - self.sg, self.mem_area, 
ignore_subgraph_input_output_tensors=ignore_subgraph_input_output_tensors, + self.parent_ps.block_config = block_config.old_style_representation() + return scheduler_op_info + + def _get_stripe_input_requirement(self, stripe_shape: Shape4D) -> Tuple[int, int]: + """Returns the amount of IFM required to produce the stripe with shape:'stripe_shape'""" + ofm_shape_to_produce = Block.from_shape(stripe_shape.as_list()) + + return get_ifm_area_required(ofm_shape_to_produce, self.kernel, to_upscale(self.resampling_mode)) + + def _calculate_min_stripe_input(self) -> Shape4D: + # Calculate the input volume required height and width for the smallest possible stripe (h,w = 1,1) + min_stripe = self.ofm.shape.with_hw(1, 1) + return self._get_stripe_input_requirement(min_stripe) + + def _get_block_config( + self, ifm_shape: Shape4D, ifm2_shape: Optional[Shape4D], uses_scalar: bool, ofm_shape: Shape4D + ) -> ArchitectureBlockConfig: + # Returns a block config and SHRAM layout + lut_banks = 2 if self.parent_op.activation_lut else 0 + return find_block_config( + self.arch, + self.op_type.npu_block_type, + ofm_shape, + ifm_shape, + ifm2_shape, + uses_scalar, + self.ifm.dtype.size_in_bits(), + self.kernel, + lut_banks, + self.parent_op.has_scaling(), + self.resampling_mode, ) - range_dict = range_set.ranges - - # find which ranges overlap passes but aren't input/outputs of the passes. - # these won't be counted by the dynamic programming search and must be counted in manually. - end_pos = max(ps.time for ps in self.sg.passes) + 2 - mem_usage = np.zeros(end_pos) + self.sg.base_sram_used - non_local_mem_usage = np.zeros(end_pos, dtype=np.int64) - - for tens, rng in range_dict.items(): - storage_size = tens.storage_size() - assert tens.mem_area == self.mem_area - mem_usage[rng.start_time : rng.end_time] += storage_size - - for ps in self.sg.passes: - local_mem_usage = 0 - for tens in ps.inputs + ps.outputs + ps.intermediates: - if tens.mem_area != self.mem_area: - continue - - local_mem_usage += tens.storage_size() - non_local_mem_usage[ps.time] = mem_usage[ps.time] - local_mem_usage - self.non_local_mem_usage = non_local_mem_usage +class Connection: + """Scheduler internal representation of a Tensor that connects two SchedulerOperations + This class can be seen as an edge within the Scheduler Graph representation + """ - def search(self): - self.calc_non_local_mem_usage() - starting_passes = [ps for ps in self.sg.passes if not ps.successors] - strat_data = self.search_pass_list(starting_passes) + def __init__(self, tensor: Tensor): + self.parent_tens = tensor - _, best_set = self.best_candidate(strat_data) + # SchedulerOperation relationships + self.producers: List[SchedulerOperation] = [] + self.consumers: List[SchedulerOperation] = [] - if self.verbose_pareto_frontier_schedules: - print( - "Scheduler searched %d combinations and found %d candidate schedules along the pareto frontier" - % (self.n_combinations_searched, len(strat_data)) - ) - for idx, (_, strat_set) in enumerate(strat_data): - extra = "" - if strat_set == best_set: - extra = "(Best candidate)" - print("Candidate", idx, extra) - memory_used = {MemArea.Sram: strat_set.max_sram_used} - stats_writer.print_performance_metrics_for_strat( - self.arch, - "", - strat_set.cycles, - strat_set.macs, - strat_set.bws, - self.nng.batch_size, - memory_used, - len(self.sg.passes), - len(strat_set.strats), - ) - - return best_set - - def search_pass_list(self, pass_list): - all_strats = [] - for ps in pass_list: - strat = self.search_output(ps) - 
all_strats.append(strat) - strat_data = self.collate_strats_for_passes(all_strats) - for strd in strat_data: - for ps in pass_list: - assert ps in strd[1].strats # should have strategies for everything we asked to search - return strat_data - - def search_predecessors(self, ps): + def __str__(self): + return f"" - # protect against graphs with loops. collate_strats_for_passes will sort this out later so that - # we have strats for all passes + __repr__ = __str__ - pass_list = ps.dag_predecessors - strat_data = self.search_pass_list(pass_list) - return strat_data +class Schedule: + """Class that contains a solution of how to schedule an NPU subgraph and its cost""" - @lru_cache(maxsize=None) - def search_output(self, ps): + def __init__(self, sg: Subgraph, label: str): + self.sg = sg + self.label = label + self.cost_map: Dict[SchedulerOperation, SchedulerOpInfo] = {} + self.cascades: Dict[int, CascadeInfo] = {} + self.fast_storage_peak_usage = 0 + self.memory_snapshot = None - assert ps in self.sg.passes - candidate_list = [] + @property + def name(self): + return f"{self.sg.name}_{self.label}" - candidate_list.extend(self.search_weight_streaming_output(ps)) - if self.options.use_ifm_streaming: - candidate_list.extend(self.search_ifm_streaming_output(ps)) +class Scheduler: + """Main class of the Vela Scheduling""" - best = self.filter_pareto_frontier(candidate_list, remove_equally_good_candidates=True) + def __init__(self, nng: Graph, sg: Subgraph, arch: ArchitectureFeatures, options: SchedulerOptions): + self.nng = nng + self.sg = sg + self.arch = arch + self.sched_ops: List(SchedulerOperation) = [] + self.max_schedule = None + self.scheduler_options = options + + def create_scheduler_representation(self, arch: ArchitectureFeatures): + """Creates a Scheduler Graph representation""" + # Temporary dict for creating connections between the Operations + connections: Dict[Tensor, Connection] = {} + # Memory required for the largest FeatureMap that has to be full + min_memory_req = 0 + for ps in self.sg.passes: + if ps.primary_op: + # Set tensor format to NHCWB16 for output FeatureMaps, if possible + for output in ps.outputs: + if output.purpose != TensorPurpose.FeatureMap: + continue + if not output.needs_linear_format: + output.set_format(TensorFormat.NHCWB16, arch) - if not best: - print( - "Warning: Dynamic search programming algorithm failed for pass %s, invoking fallback strategy" - % (ps.name,) - ) - return self.search_predecessors(ps) - - return best - - def search_ifm_streaming_output(self, ps): - if ps.placement != PassPlacement.Npu: - return ABORT_SEARCH - if ps.npu_block_type not in self.ifm_stream_npu_blocks: - return ABORT_SEARCH - strat_data = self.search_ifm_streaming_body(ps, False) - - sram_used = self.non_local_mem_usage[ps.time] - for tens in ps.outputs: - if tens.mem_area == self.mem_area: - sram_used += tens.storage_size() - - return self.graduate_strat(SchedulingStrategy.IfmStream, sram_used, strat_data) - - @lru_cache(maxsize=None) - def search_ifm_streaming_body(self, ps, force_outputs_to_fast_storage): - if ps.placement != PassPlacement.Npu: - return ABORT_SEARCH - if ps.npu_block_type not in self.ifm_stream_npu_blocks: - return ABORT_SEARCH - ifm_input_search_resuls = self.search_ifm_streaming_input(ps) - res = [] - - base_sram_used = 0 - for tens in ps.intermediates: - if tens.mem_area == self.mem_area: - if tens.purpose == TensorPurpose.Weights: - base_sram_used = tens.storage_size(self.arch.weight_estimation_scaling) - else: - base_sram_used += tens.storage_size() - 
- all_block_configs = self.get_block_configs(ps) - for block_config in all_block_configs: - all_strats = [] - - if self.use_cascading: - all_strats.extend(self.search_ifm_streaming_partial(ps, block_config)) - - all_strats.extend(ifm_input_search_resuls) - - rewrite_list = [] - sram_used = base_sram_used - - metrics = npu_performance.performance_metrics_for_pass( - self.arch, - ps, - block_config, - rewrite_list=rewrite_list, - force_outputs_to_fast_storage=force_outputs_to_fast_storage, + # Create SchedulerOperations + op = SchedulerOperation(ps, arch, self.nng) + op.index = len(self.sched_ops) + + # Make connections + if ps.ifm_tensor not in connections: + connections[ps.ifm_tensor] = Connection(ps.ifm_tensor) + if ps.ifm2_tensor and ps.ifm2_tensor not in connections: + connections[ps.ifm2_tensor] = Connection(ps.ifm2_tensor) + if ps.ofm_tensor not in connections: + connections[ps.ofm_tensor] = Connection(ps.ofm_tensor) + + op.add_ifm_connection(connections[ps.ifm_tensor]) + if ps.ifm2_tensor: + op.add_ifm2_connection(connections[ps.ifm2_tensor]) + op.add_ofm_connection(connections[ps.ofm_tensor]) + + # Set requirements on the ifm/ofm buffers + self.sched_ops.append(op) + if ps.ifm_tensor in self.sg.input_tensors: + # This Op consumes a subgraph input + op.requires_full_ifm = True + if ps.ifm2_tensor and ps.ifm2_tensor in self.sg.input_tensors: + # This Op consumes a subgraph input + op.requires_full_ifm2 = True + if ps.ofm_tensor in self.sg.output_tensors: + # This Op produces a subgraph output + op.requires_full_ofm = True + if ps.ifm_tensor.needs_linear_format: + op.requires_full_ifm = True + if ps.ifm2_tensor and ps.ifm2_tensor.needs_linear_format: + op.requires_full_ifm2 = True + if ps.ofm_tensor.needs_linear_format or ps.primary_op.memory_function == Op.ConcatSliceWrite: + op.requires_full_ofm = True + if len(ps.primary_op.outputs) > 1 or len(ps.primary_op.outputs[0].consumer_list) > 1: + # Op has multiple outputs or consumers - requires full OFM + op.requires_full_ofm = True + + # Check memory requirements if this Op requires any full FeatureMaps + op_memory_req = 0 + if op.requires_full_ifm: + op_memory_req += op.ifm_size_in_bytes() + if op.requires_full_ifm2: + op_memory_req += op.ifm2_size_in_bytes() + if op.requires_full_ofm: + op_memory_req += op.ofm_size_in_bytes() + + min_memory_req = max(op_memory_req, min_memory_req) + + # Theoretical minimum required memory - used to guide the cascade building + self.min_memory_req = min_memory_req + + def create_initial_schedule(self) -> Schedule: + """Creates an initial schedule with no cascading or buffering of any kind""" + schedule = Schedule(self.sg, "MAX") + + for op in self.sched_ops: + cost = op.create_scheduler_info(self.nng, op.ofm.shape) + cost.cycles = self.estimate_op_performance(op, cost.block_config, op.ofm.shape.depth) + schedule.cost_map[op] = cost + + return schedule + + def update_op_memory_snapshot(self, schedule: Schedule): + memories_list = [(self.arch.fast_storage_mem_area, set((MemType.Scratch, MemType.Scratch_fast)))] + + # Collect live ranges from tensors + lr_graph = live_range.LiveRangeGraph() + for mem_area, mem_type_set in memories_list: + live_range.extract_live_ranges_from_cascaded_passes( + self.nng.get_root_subgraph(), mem_area, mem_type_set, False, lr_graph, Tensor.AllocationQuantum, ) - res.extend( - self.append_sram_pass_block_config_performance_metrics_rewrite_list( - sram_used, ps, block_config, metrics, rewrite_list, all_strats - ) + # Populate time-array with memory used by live ranges + 
temporal_usage = lr_graph.get_temporal_memory_usage(self.arch.fast_storage_mem_area) + schedule.memory_snapshot = temporal_usage + + # Set the peak memory usage + schedule.fast_storage_peak_usage = max(temporal_usage, default=0) + + def estimate_op_performance(self, op: SchedulerOperation, block_config, ofm_depth): + query = npu_performance.PerformanceQuery(op.op_type.npu_block_type) + query.ifm_shape = op.ifm.shape + query.ifm_memory_area = op.ifm.mem_area + query.ifm_bits = op.ifm.dtype.size_in_bits() + query.ifm_format = op.ifm.format + query.ifm2_shape = op.ifm2 and op.ifm2.shape + query.ifm2_memory_area = op.ifm2 and op.ifm2.mem_area + query.ifm2_bits = op.ifm2 and op.ifm2.dtype.size_in_bits() + query.ifm2_format = op.ifm2 and op.ifm2.format + query.ofm_shape = op.ofm.shape.with_depth(ofm_depth) + query.ofm_memory_area = op.ofm.mem_area + query.ofm_bits = op.ofm.dtype.size_in_bits() + query.ofm_format = op.ofm.format + if op.parent_op.bias: + query.const_shape = Shape4D(1, 1, 1, op.ofm.shape.depth) + query.const_memory_area = self.arch.fast_storage_mem_area + + query.kernel = op.kernel + query.config = block_config + + return npu_performance.measure_cycle_cost(self.arch, op.op_type, op.activation and op.activation.op_type, query) + + def propose_schedule_buffering(self, ref_schedule: Schedule): + """Create a buffered schedule""" + buffered_schedule = Schedule(self.sg, f"{ref_schedule.label}_BUFFERED") + staging_limit_bytes = self.scheduler_options.optimization_sram_limit + + prev_op = None + for sched_op in self.sched_ops: + if sched_op not in ref_schedule.cost_map: + # sched_op is not part of this sub-schedule - skip + continue + + self.propose_operator_buffering(sched_op, prev_op, buffered_schedule, ref_schedule, staging_limit_bytes) + prev_op = sched_op + + return buffered_schedule + + def propose_operator_buffering( + self, + sched_op: SchedulerOperation, + prev_op: SchedulerOperation, + buffered_schedule: Schedule, + ref_schedule: Schedule, + staging_limit_bytes, + ): + # Mild recursion might mean this Op has already been seen + if sched_op in buffered_schedule.cost_map: + return + + # Take the reference schedule as default costings for this schedule + ref_cost = ref_schedule.cost_map[sched_op] + cost = copy.copy(ref_cost) + cost.slack_buffering_cycles = ref_cost.cycles.op_cycles + memory_snapshot = ref_schedule.memory_snapshot + ref_memory_usage = memory_snapshot[ref_cost.time_index] if ref_cost.time_index < len(memory_snapshot) else 0 + cost.slack_buffering_memory = staging_limit_bytes - ref_memory_usage + buffered_schedule.cost_map[sched_op] = cost + + # Attempt weight buffering on anything with a weights tensor + if sched_op.parent_op.weights: + self.propose_weight_buffering( + sched_op.parent_op.weights, + sched_op.parent_op.bias, + sched_op, + prev_op, + buffered_schedule, + ref_schedule, + cost.slack_buffering_memory, ) - self.n_combinations_searched += len(res) - res = self.filter_pareto_frontier(res, remove_equally_good_candidates=True) - return res + return cost - def avoid_for_cascading(self, pred_candidate): - for op in pred_candidate.ops: + def weights_needs_dma(self, weight_tensor): + if weight_tensor and weight_tensor.mem_type not in (MemType.Scratch, MemType.Scratch_fast): + # Weights are in permanent storage + # Only when permanent storage differs from feature map storage, there is a point moving the data if ( - op.memory_function == Op.ConcatSliceWrite - and self.arch.feature_map_storage_mem_area != self.arch.fast_storage_mem_area + weight_tensor.mem_area in 
(MemArea.Dram, MemArea.OffChipFlash) + and self.arch.permanent_storage_mem_area != self.arch.fast_storage_mem_area ): - # For SRAM spilling, concat op is avoided as predecessor - return True - if len(op.outputs) > 1 or len(op.outputs[0].consumer_list) > 1: - # The op has consumers in other subgraphs return True return False - def search_ifm_streaming_partial(self, ps, block_config): - if ps.placement != PassPlacement.Npu: - return ABORT_SEARCH - - if len(ps.inputs) < 1: - return ABORT_SEARCH - - ifm_tensor = ps.ifm_tensor - - if ifm_tensor is None: - return ABORT_SEARCH - if ifm_tensor.purpose != TensorPurpose.FeatureMap: - return ABORT_SEARCH - if not ifm_tensor.storage_shape or len(ifm_tensor.storage_shape) != 4: - return ABORT_SEARCH - - pred_pass_list = [] - for pred_candidate in ps.dag_predecessors: - if len(pred_candidate.outputs) == 1 and pred_candidate.outputs[0] == ifm_tensor: - # we found a predecessor that produces this IFM tensor - if not ifm_tensor.needs_linear_format: - # and NHCWB16 can be used - if len(pred_candidate.successors) == 1 and pred_candidate.successors[0] == ps: - # and it only has one successor, namely us - if pred_candidate.placement == PassPlacement.Npu: - if pred_candidate.npu_block_type in self.ifm_stream_npu_blocks: - # and it is on the Npu - if not self.avoid_for_cascading(pred_candidate): - # and fusable - it's a candidate - pred_pass_list.append(pred_candidate) - - if not pred_pass_list: - return ABORT_SEARCH - - all_candidates = [] - for pred_pass in pred_pass_list: - # recurse into the next pass - ifm_strat_data = self.search_ifm_streaming_body(pred_pass, self.arch.is_spilling_enabled()) - - strat_data = self.search_all_but_one_predecessor(ps, pred_pass, ifm_strat_data) - for strat_opt in strat_data: - - pred_pass_block_config = strat_opt[0].block_configs[-1] - rolling_buffer_dims = npu_performance.rolling_buffer_dims_from_passes( - self.arch, pred_pass, pred_pass_block_config, ps, block_config - ) - if rolling_buffer_dims is None: - continue # this does not pack properly, skip it. 
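# A standalone sketch of the weight prebuffering arithmetic used by
# propose_weight_buffering below: the prebuffer size is limited both by the NPU
# cycles left idle under the previous operator and by half of the available SRAM
# budget (so the other half can hold the next buffer). The numbers and the helper
# name are illustrative assumptions, not values taken from the patch.
def _prebuffer_bytes(full_weights_bytes, full_transfer_cycles, slack_cycles, buffer_limit_bytes):
    half_buffer_limit = buffer_limit_bytes // 2
    if full_transfer_cycles > slack_cycles:
        # Only the fraction of the DMA transfer that fits in the slack can be prebuffered
        prebuffer_ratio = slack_cycles / full_transfer_cycles
        return min(prebuffer_ratio * full_weights_bytes, half_buffer_limit)
    # The whole transfer fits in the slack - take as much as the double buffer allows
    return min(full_weights_bytes, half_buffer_limit)

# e.g. 100 kB of weights, a 20k-cycle DMA but only 5k cycles of slack, 64 kB of SRAM:
# 5k/20k * 100 kB = 25 kB, which is below the 32 kB half-buffer limit.
assert _prebuffer_bytes(100_000, 20_000, 5_000, 64_000) == 25_000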
- - sram_used = 0 - for tens in ps.inputs: - if tens != ifm_tensor: - if tens.mem_area == self.mem_area: - sram_used += tens.storage_size() - - rolling_buffer_y, rolling_buffer_x = rolling_buffer_dims - - rewrite_list = [ - ( - SchedulerRewrite.ChangeTensorSubPurpose, - ifm_tensor, - TensorSubPurpose.RollingBufferY, - rolling_buffer_y, - None, - ps, + def propose_weight_buffering( + self, + weight_tensor, + scale_tensor, + sched_op: SchedulerOperation, + prev_op: SchedulerOperation, + buffered_schedule: Schedule, + ref_schedule: Schedule, + buffer_limit_bytes, + ): + cost = buffered_schedule.cost_map[sched_op] + prev_cost = buffered_schedule.cost_map.get(prev_op) + ref_cost = ref_schedule.cost_map[sched_op] + assert cost and ref_cost + + needs_dma = self.weights_needs_dma(weight_tensor) + + ofm_full_depth_slices = [0, ref_cost.stripe.depth] + + # Encode weights for the full depth + full_weights = weight_compressor.encode_weight_and_scale_tensor( + self.arch, + sched_op.parent_op, + weight_tensor, + scale_tensor, + sched_op.kernel, + cost.block_config, + ofm_full_depth_slices, + ) + full_weights_bytes = len(full_weights.buffer) + cost.ofm_depth_slices = ofm_full_depth_slices + + # No buffering required - take all the weights from permanent storage + if sched_op.op_type == Op.FullyConnected or not needs_dma: + cost.npu_weights_tensor = full_weights + return + + encoded_weights = full_weights + + # How many NPU cycles are available under the previously executing + # operator and SRAM unused for performing buffered DMA transfers + slack_cycles = prev_cost.slack_buffering_cycles if prev_cost else 0 + slack_memory = prev_cost.slack_buffering_memory if prev_cost else 0 + + # Force full depth for cascaded Ops + if ref_cost.cascade != 0: + weight_tensor_purpose = TensorSubPurpose.Standard + weight_buffer_size = full_weights_bytes + # Update the memory snapshot to reflect the added size of the weights + ref_schedule.memory_snapshot[ref_cost.time_index] += weight_buffer_size + else: + # Estimate the buffering cycle time for the full set of weights + full_transfer_cycles = npu_performance.measure_mem2mem_cycles( + self.arch, weight_tensor.mem_area, self.arch.fast_storage_mem_area, full_weights_bytes + ) + cost.full_weight_transfer_cycles = full_transfer_cycles + + # Calculate the amount of prebuffering necessary (or what is possible with limited + # double buffer buffer size) + half_buffer_limit = buffer_limit_bytes // 2 + if full_transfer_cycles > slack_cycles: + prebuffer_ratio = slack_cycles / full_transfer_cycles + prebuffer_bytes = min(prebuffer_ratio * full_weights_bytes, half_buffer_limit) + else: + prebuffer_bytes = min(full_weights_bytes, half_buffer_limit) + prebuffer_ratio = prebuffer_bytes / full_weights_bytes + + # Have to split the weights if the initial buffering can't store + # all of the compressed weights + if prebuffer_bytes < full_weights_bytes: + prebuffer_depth = int(ref_cost.stripe.depth * prebuffer_ratio) + + # Round prebuffering down to nearest valid split depth + prebuffer_depth = int(max(16, round_down(prebuffer_depth, ArchitectureFeatures.OFMSplitDepth))) + + while True: + buffering_depth = max(cost.block_config.ofm_block.depth, prebuffer_depth) + + # Clamp buffering to the double buffering limit + buffering_bytes = (buffering_depth / ref_cost.stripe.depth) * full_weights_bytes + if buffering_bytes > half_buffer_limit: + buffering_depth = (half_buffer_limit / full_weights_bytes) * ref_cost.stripe.depth + buffering_depth = int(max(16, round_down(prebuffer_depth, 
ArchitectureFeatures.OFMSplitDepth))) + + # Create list of depth slices + depth_slices = [0] + if prebuffer_depth < ref_cost.stripe.depth: + depth_slices += list(range(prebuffer_depth, ref_cost.stripe.depth, buffering_depth)) + depth_slices.append(ref_cost.stripe.depth) + + # Encode weights based depth slices + cost.ofm_depth_slices = depth_slices + encoded_weights = weight_compressor.encode_weight_and_scale_tensor( + self.arch, + sched_op.parent_op, + weight_tensor, + scale_tensor, + sched_op.kernel, + cost.block_config, + cost.ofm_depth_slices, ) - ] - sram_used += ifm_tensor.storage_size_for_sub_purpose( - self.arch, TensorSubPurpose.RollingBufferY, rolling_buffer_y, None - ) - - all_candidates.extend(self.append_sram_rewrite_list(sram_used, rewrite_list, [strat_opt])) - - self.n_combinations_searched += len(all_candidates) - return all_candidates - - def get_block_configs(self, ps): - if ps.placement != PassPlacement.Npu: - return [(1, 1, 1, 1)] # default - - block_configs = find_block_configs_suitable_for_pass_and_shared_buffer(self.arch, ps) - - # Take a limited number of the largest blocks - if self.arch.block_config_limit > 0: - # Sort by block area, followed by depth - block_configs.sort(key=lambda cfg: (cfg[0] * cfg[1]) << 8 | cfg[3], reverse=True) - bound = min(len(block_configs), self.arch.block_config_limit) - # We take 'n' from the fat end of the list, and 'n' from the thin end of the list. - tmp = block_configs[:bound] - tmp.extend(block_configs[max(bound, len(block_configs) - bound) :]) - block_configs = tmp - - return block_configs - - def search_ifm_streaming_input(self, ps): - sram_used = 0 - for tens in ps.inputs: - if tens.mem_area == self.mem_area: - sram_used += tens.storage_size() - - return self.append_sram(sram_used, self.search_predecessors(ps)) - - def search_weight_streaming_output(self, ps): - strat_data = self.search_weight_streaming_body(ps) - - sram_used = self.non_local_mem_usage[ps.time] - for tens in ps.outputs: - if tens.mem_area == self.mem_area: - sram_used += tens.storage_size() - - return self.graduate_strat(SchedulingStrategy.WeightStream, sram_used, strat_data) - - @lru_cache(maxsize=None) - def search_weight_streaming_body(self, ps): - - strat_data = self.search_weight_streaming_input(ps) - - res = [] - - all_block_configs = self.get_block_configs(ps) - for block_config in all_block_configs: + # Chosen buffering might not fit at all, iterate until it does + # or until the minimum usable slice size is reached + if ( + encoded_weights.max_range_bytes <= half_buffer_limit + or prebuffer_depth == ArchitectureFeatures.OFMSplitDepth + ): + break - sram_used = 0 - rewrite_list = [] + prebuffer_depth = round_up(prebuffer_depth // 2, ArchitectureFeatures.OFMSplitDepth) - for tens in ps.intermediates: - if tens.mem_area == self.mem_area: - if tens.purpose == TensorPurpose.Weights: - sram_used += tens.storage_size_for_sub_purpose( - self.arch, TensorSubPurpose.DoubleBuffer, block_config[3] - ) - rewrite_list.append( - ( - SchedulerRewrite.ChangeTensorSubPurpose, - tens, - TensorSubPurpose.DoubleBuffer, - block_config[3], - None, - ps, - ) - ) - else: - sram_used += tens.storage_size() - - metrics = npu_performance.performance_metrics_for_pass( - self.arch, ps, block_config, rewrite_list=rewrite_list + # Calculate cycles required to run the last op for use as future slack + tail_cycles = self.estimate_op_performance( + sched_op, cost.block_config, depth_slices[-1] - depth_slices[-2] + ) + cost.slack_buffering_cycles = tail_cycles.op_cycles + + # Determine 
whether the weights need to be double buffered + weight_buffer_size = min(len(encoded_weights.buffer), encoded_weights.max_range_bytes) + + # Only buffer weights if there's still space left for the buffer + if weight_buffer_size <= buffer_limit_bytes: + assert weight_buffer_size % 16 == 0 + # Determine whether to double buffer or single buffer + if (weight_buffer_size * 2 <= buffer_limit_bytes) and (weight_buffer_size < len(encoded_weights.buffer)): + weight_buffer_size = weight_buffer_size * 2 + weight_tensor_purpose = TensorSubPurpose.DoubleBuffer + else: + weight_tensor_purpose = TensorSubPurpose.Standard + + cost.buffered_weight_tensor = Tensor( + [1, 1, 1, weight_buffer_size], DataType.uint8, weight_tensor.name + "_buffer" ) - - res.extend( - self.append_sram_pass_block_config_performance_metrics_rewrite_list( - sram_used, ps, block_config, metrics, rewrite_list, strat_data + cost.buffered_weight_tensor.src_tensor = encoded_weights + cost.buffered_weight_tensor.mem_area = self.arch.fast_storage_mem_area + cost.buffered_weight_tensor.mem_type = MemType.Scratch_fast + cost.buffered_weight_tensor.purpose = TensorPurpose.Weights + cost.buffered_weight_tensor.sub_purpose = weight_tensor_purpose + if ref_cost.cascade == 0: + # Determine if the lifetime can be extended and pre-buffer weights under the previous operation + cost.buffered_weight_tensor.pre_buffer = weight_buffer_size < slack_memory + + cost.slack_buffering_memory -= weight_buffer_size + else: + # Don't slice or buffer - use the whole depth from persistent storage + cost.ofm_depth_slices = ofm_full_depth_slices + encoded_weights = full_weights + + cost.npu_weights_tensor = encoded_weights + + def propose_minimal_schedule(self) -> Schedule: + """Proposes scheduling parameters where every operator is subdivided into the smallest stripe that satisfies the + next operators stride""" + min_schedule = Schedule(self.sg, "MIN") + cost_map = min_schedule.cost_map + + # Keep track of the previous Op - which consumes the current Op's OFM + prev_op = None + for sched_op in reversed(self.sched_ops): + min_stripe_height = prev_op.kernel.stride.y if prev_op else 1 + min_stripe = sched_op.ofm.shape.with_height(min_stripe_height) + + cost = sched_op.create_scheduler_info(self.nng, min_stripe) + cost.cycles = self.estimate_op_performance(sched_op, cost.block_config, sched_op.ofm.shape.depth) + cost_map[sched_op] = cost + + prev_op = sched_op + + return min_schedule + + def propose_schedule_striping(self, final_stripe: Shape4D, label: str, ref_schedule: Schedule) -> Schedule: + """Proposes new striping for a schedule. 
The stripe is derived from the ifm requirements of the next Op down""" + ref_cost = ref_schedule.cost_map + + striped_schedule = Schedule(self.sg, label) + stripe = final_stripe + for sched_op in reversed(self.sched_ops): + if sched_op not in ref_cost: + # sched_op is not part of the sub-schedule - skip + continue + + # Create a cost entry with the new stripe + cost = sched_op.create_scheduler_info(self.nng, stripe) + + # Copy the weight buffering from the reference schedule + cost.buffered_weight_tensor = ref_cost[sched_op].buffered_weight_tensor + + # Estimate performance + cost.cycles = self.estimate_op_performance(sched_op, cost.block_config, sched_op.ofm.shape.depth) + striped_schedule.cost_map[sched_op] = cost + + # Calculate the preceeding Op's stripe + stripe = sched_op.ifm.shape.with_height(stripe.height * sched_op.kernel.stride.y) + + return striped_schedule + + def estimate_schedule_memory_usage(self, schedule: Schedule, non_local_mem_usage: dict): + """Estimates the memory usage of a schedule""" + cost = schedule.cost_map + cascades = schedule.cascades + peak_mem_usage = 0 + for sched_op in self.sched_ops: + if sched_op not in cost: + # sched_op is not part of the sub-schedule - skip + continue + + if cost[sched_op].cascade: + # This Op is part of a cascade - use the cascade's memory usage + cascade_info = cascades[cost[sched_op].cascade] + # Non-local memory usage is already included in the cascade_info + peak_mem_usage = max(cascade_info.mem_usage, peak_mem_usage) + else: + # This Op is not part of a cascade - calculate the memory usage + op_weight_buffer = 0 + if cost[sched_op].buffered_weight_tensor: + op_weight_buffer = cost[sched_op].buffered_weight_tensor.storage_size() + + op_mem_usage = ( + sched_op.ifm_size_in_bytes() + + sched_op.ofm_size_in_bytes() + + op_weight_buffer + + non_local_mem_usage.get(sched_op, 0) ) + peak_mem_usage = max(op_mem_usage, peak_mem_usage) + + return peak_mem_usage + + def optimize_sub_schedule( + self, cascade_info: CascadeInfo, ref_schedule: Schedule, max_template: Schedule, memory_limit: int + ) -> Schedule: + """Extracts the Ops covered by the given cascade and creates a sub-schedule. 
The sub-schedule is optimized by + proposing weight buffering and then continously proposing new stripe sizes""" + ref_cost = ref_schedule.cost_map + # Extract the ops that are part of this sub-schedule + start = cascade_info.start + end = cascade_info.end + sub_schedule_ops = self.sched_ops[start : end + 1] + # Create a sub-schedule that contains only the costs for the Ops that are part of the sub-schedule + sub_schedule = Schedule(self.sg, f"SUB_{start}_{end}") + for sched_op in sub_schedule_ops: + sub_schedule.cost_map[sched_op] = ref_cost[sched_op] + + sub_schedule.cascades[end] = cascade_info + # Use the memory snapshot from the reference schedule + sub_schedule.memory_snapshot = ref_schedule.memory_snapshot + + # Calculate memory usage that is live during the sub-schedule but not part of it + time_for_cascade = ref_cost[sub_schedule_ops[0]].time_index + mem_usage_parallel_to_sub_schedule = ref_schedule.memory_snapshot[time_for_cascade] - cascade_info.mem_usage + # If the first Op's IFM has other consumers it has to live throughout the whole sub-schedule whether it's + # included in a cascade or not + persistent_initial_ifm = ( + sub_schedule_ops[0].ifm_size_in_bytes() if len(sub_schedule_ops[0].ifm.connection.consumers) > 1 else 0 + ) + # Calculate non-local-mem-usage per Operator + non_local_mem_usage = {} + for idx, sched_op in enumerate(sub_schedule_ops): + non_local_mem_usage[sched_op] = mem_usage_parallel_to_sub_schedule + if idx != 0: + non_local_mem_usage[sched_op] += persistent_initial_ifm + + cascade_builder = CascadeBuilder(sub_schedule_ops, self.arch.is_spilling_enabled(), non_local_mem_usage) + + # Start by adding buffering + buffered_sub_schedule = self.propose_schedule_buffering(sub_schedule) + # Copy the cascades over from the unbuffered-schedule + buffered_sub_schedule.cascades = sub_schedule.cascades + + # Generate the possible stripings for the final Op in the sub-schedule + final_ofm_shape = sub_schedule_ops[-1].ofm.shape + possible_stripes = [ + final_ofm_shape.with_height(stripe_h) for stripe_h in range(1, final_ofm_shape.height // 2 + 1) + ] + + # Propose different striping - the possible stripes are proposed similarly to a binary search + best_schedule = buffered_sub_schedule + iteration = 0 + while len(possible_stripes) > 1: + proposed_stripe = possible_stripes[len(possible_stripes) // 2] + proposed_schedule = self.propose_schedule_striping( + proposed_stripe, f"OPTIMIZED_{iteration}", buffered_sub_schedule ) - self.n_combinations_searched += len(res) - res = self.filter_pareto_frontier(res, remove_equally_good_candidates=True) - return res + cascade_builder.build_cascades(proposed_schedule, max_template, memory_limit) - def search_weight_streaming_input(self, ps): - sram_used = 0 - for tens in ps.inputs: - if tens.mem_area == self.mem_area: - sram_used += tens.storage_size() - - return self.append_sram(sram_used, self.search_predecessors(ps)) - - def apply_result(self, strat_set, arch): - pass_to_cascaded_pass = dict() - for _, strat in strat_set.strats.items(): - # rewrite the tensors that need this first. e.g. 
make rolling buffers - inputs = [] - intermediates = [] - outputs = [] - - for ps in strat.passes: - inputs += ps.inputs - intermediates += ps.intermediates - outputs += ps.outputs - - for tens in set(inputs) & set(outputs): - # tensors that are in both sets are intermediates - - # find pass with input/output tensor, and check if they are both placed on NPU - input_placement = None - output_placement = None - for ps in strat.passes: - if tens in ps.inputs: - input_placement = ps.placement - if tens in ps.outputs: - output_placement = ps.placement - if input_placement == output_placement == PassPlacement.Npu: - tens.set_format(TensorFormat.NHCWB16, arch) - - intermediates.append(tens) - inputs.remove(tens) - outputs.remove(tens) - - for rewrite_op, tens, sub_purpose, param_a, param_b, ps in strat.rewrite_list: - if rewrite_op == SchedulerRewrite.ChangeTensorSubPurpose: - tens.mem_area = self.arch.fast_storage_mem_area - tens.mem_type = MemType.Scratch_fast - tens.set_new_sub_purpose(sub_purpose, param_a, param_b) - else: - assert 0, "unknown rewrite_op " + str(rewrite_op) - - is_element_wise = True - for ps in strat.passes: - assert ps.placement == strat.passes[0].placement - if not ps.is_element_wise: - is_element_wise = False + # Check if proposal fits + proposed_schedule_mem_usage = self.estimate_schedule_memory_usage(proposed_schedule, non_local_mem_usage) + if (proposed_schedule_mem_usage) <= memory_limit: + # Remove all possible stripes smaller than this + possible_stripes = possible_stripes[len(possible_stripes) // 2 :] + best_schedule = proposed_schedule + if not proposed_schedule.cascades: + # No cascading required - early exit break - - cascaded_pass = CascadedPass( - strat.passes[0].name, - strat.strat, - inputs, - intermediates, - outputs, - strat.passes, - strat.passes[0].placement, - is_element_wise, - ) - assert strat.sram_used >= 0 - cascaded_pass.sram_used = strat.sram_used - - for idx, ps in enumerate(strat.passes): - assert ps not in pass_to_cascaded_pass - pass_to_cascaded_pass[ps] = cascaded_pass - ps.cascade = cascaded_pass - ps.block_config = strat.block_configs[idx] - - if ps.placement == PassPlacement.Npu: - ps.shared_buffer = shared_buffer_allocation_for_pass_and_block_config( - self.arch, ps, ps.block_config - ) - assert ps.shared_buffer is not None - - sram_used = max(self.non_local_mem_usage[ps.time], 0) - for op in ps.ops: - subgraph = op.attrs.get("subgraph") - if subgraph: - subgraph.base_sram_used = sram_used - - # all passes should have a cascaded pass now - if len(pass_to_cascaded_pass) != len(self.sg.passes): - print( - "mismatch: we have %d passes, but only %d have cascaded passes associated" - % (len(self.sg.passes), len(pass_to_cascaded_pass)) + else: + # Proposal doesn't fit within the limit - remove all possible stripes larger than this + possible_stripes = possible_stripes[: len(possible_stripes) // 2] + + iteration += 1 + + return best_schedule + + def optimize_schedule( + self, schedule: Schedule, max_sched: Schedule, max_template: Schedule, options: SchedulerOptions, + ) -> Schedule: + """Extracts sub-schedules based on the cascades and optimizes them and applies them to the final schedule""" + sram_limit = options.optimization_sram_limit + if max_sched.fast_storage_peak_usage < sram_limit and not self.arch.is_spilling_enabled(): + # Maximum performance schedule fits within the SRAM target + return max_sched + + # Extract the cascades + cascades = [cascade for cascade in schedule.cascades.values()] + for cascade_info in cascades: + # Remove 
existing cascade from schedule
+            del schedule.cascades[cascade_info.end]
+            # Optimize the sub-schedule in this cascade
+            opt_sub_schedule = self.optimize_sub_schedule(cascade_info, schedule, max_template, sram_limit)
+            # Update the sub-schedule Op and cascade costs to the full schedule
+            schedule.cost_map.update(opt_sub_schedule.cost_map)
+            schedule.cascades.update(opt_sub_schedule.cascades)
+
+        # Update memory snapshot
+        self.sg.schedule = schedule
+        self.update_op_memory_snapshot(schedule)
+        # Propose schedule buffering to the optimized schedule
+        optimized_sched = self.propose_schedule_buffering(schedule)
+        # Copy the cascade's metadata from the unbuffered schedule
+        optimized_sched.cascades = schedule.cascades
+        return optimized_sched
+
+    def apply_schedule(self, sched: Schedule):
+        """Applies the given schedule as a final solution"""
+        for sched_op in self.sched_ops:
+            op_info = sched.cost_map[sched_op]
+            cascade_info = sched.cascades.get(op_info.cascade, None)
+            if cascade_info and sched_op in cascade_info.buffers:
+                buffer_tens = sched_op.ifm.connection.parent_tens
+                # Apply memory area and type
+                buffer_tens.mem_area = self.arch.fast_storage_mem_area
+                buffer_tens.mem_type = MemType.Scratch_fast
+                # Apply Rolling buffer
+                buffer_tens.set_format(TensorFormat.NHCWB16, self.arch)
+                buffer_tens.set_new_sub_purpose(TensorSubPurpose.RollingBufferY, cascade_info.buffers[sched_op].height)
+
+            sched_op.parent_ps.block_config = op_info.block_config.old_style_representation()
+
+            # Ensure that the src_tensor reference is set correctly
+            if op_info.buffered_weight_tensor:
+                op_info.buffered_weight_tensor.src_tensor = op_info.npu_weights_tensor
+
+    def use_fast_storage_for_feature_maps(self, schedule: Schedule, memory_limit: int):
+        if self.arch.fast_storage_mem_area == self.arch.feature_map_storage_mem_area:
+            return
+
+        # Force all OFMs to fast-storage
+        for sched_op in self.sched_ops:
+            cost = schedule.cost_map[sched_op]
+            if cost.cascade == 0:
+                if sched_op.get_dependants():
+                    ofm_tens = sched_op.ofm.connection.parent_tens
+                    if not any(cons is None for cons in ofm_tens.consumer_list):
+                        ofm_tens.mem_area = self.arch.fast_storage_mem_area
+                        ofm_tens.mem_type = MemType.Scratch_fast
+
+        # Collect live ranges from tensors
+        memories_list = [(self.arch.fast_storage_mem_area, set((MemType.Scratch, MemType.Scratch_fast)))]
+        lr_graph = live_range.LiveRangeGraph()
+        for mem_area, mem_type_set in memories_list:
+            live_range.extract_live_ranges_from_cascaded_passes(
+                self.nng.get_root_subgraph(), mem_area, mem_type_set, False, lr_graph, Tensor.AllocationQuantum,
             )
-            for ps in self.sg.passes:
-                if ps not in pass_to_cascaded_pass:
-                    print("%3d pass missing cascaded pass %s" % (ps.time, ps))
-        assert len(pass_to_cascaded_pass) == len(self.sg.passes)
+        # Iterate over live ranges and evict tensors that don't fit
+        fast_storage_snapshot = lr_graph.get_temporal_memory_usage(self.arch.fast_storage_mem_area)
+        for lr in lr_graph.lrs:
+            if (
+                lr.mem_area == self.arch.fast_storage_mem_area
+                and max(fast_storage_snapshot[lr.start_time : lr.end_time + 1]) > memory_limit
+            ):
+                # Evict tensor to DRAM
+                for tens in lr.tensors:
+                    if tens.purpose == TensorPurpose.FeatureMap and tens.sub_purpose == TensorSubPurpose.Standard:
+                        # Can only evict unbuffered FeatureMaps
+                        tens.mem_area = self.arch.feature_map_storage_mem_area
+                        tens.mem_type = MemType.Scratch
+                        # Adjust the snapshot
+                        fast_storage_snapshot[lr.start_time : lr.end_time + 1] -= lr.size
+
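The eviction loop above is easier to follow in isolation. The sketch below is a self-contained toy model of the same idea, with plain (name, start, end, size) tuples standing in for Vela's LiveRangeGraph live ranges; all names here are hypothetical, not part of Vela's API. It builds a temporal usage snapshot, then greedily evicts any range whose live window overlaps a peak above the limit, adjusting the snapshot as it goes.

# Minimal sketch of live-range based eviction. Each live range is a
# (name, start_time, end_time, size) tuple; the snapshot is the summed size
# per time step. All names here are illustrative only.
from typing import List, Tuple

def build_snapshot(live_ranges: List[Tuple[str, int, int, int]], num_steps: int) -> List[int]:
    snapshot = [0] * num_steps
    for _, start, end, size in live_ranges:
        for t in range(start, end + 1):
            snapshot[t] += size
    return snapshot

def evict_over_limit(live_ranges, num_steps, limit):
    """Return the evicted range names and the adjusted snapshot."""
    snapshot = build_snapshot(live_ranges, num_steps)
    evicted = []
    for name, start, end, size in live_ranges:
        if max(snapshot[start : end + 1]) > limit:
            # Evict this range and adjust the snapshot, mirroring the
            # "fast_storage_snapshot[...] -= lr.size" update above
            for t in range(start, end + 1):
                snapshot[t] -= size
            evicted.append(name)
    return evicted, snapshot

if __name__ == "__main__":
    lrs = [("ifm0", 0, 2, 64), ("ofm0", 1, 3, 96), ("ofm1", 2, 4, 80)]
    print(evict_over_limit(lrs, num_steps=5, limit=160))

The real implementation additionally filters on memory area and only evicts unbuffered feature maps.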
+    def move_constant_data(self):
+        """Determine if data can be moved from permanent storage to another memory area. A move
+        will generate a DMA command in the high-level command stream"""
+        for sched_op in self.sched_ops:
+            parent_op = sched_op.parent_op
+            is_lut_used = any(inp.purpose == TensorPurpose.LUT for inp in parent_op.inputs)
+            max_ifm_shram_avail = (
+                (self.arch.available_shram_banks(is_lut_used) - self.arch.shram_reserved_output_banks)
+                * self.arch.shram_bank_size
+                // 2
+            )

-        cascaded_passes = []
-        if self.sg.placement == PassPlacement.Cpu:
-            # Retain the pass order for CPU subgraph
-            cascaded_passes = [ps.cascade for ps in self.sg.passes]
-        else:
-            # we have all the passes, but we need to put them in order and build predecessor/successor links.
-            visit_pass_set = set()

+            for idx, tens in enumerate(parent_op.inputs):
+                if tens.mem_type not in (MemType.Scratch, MemType.Scratch_fast):
+                    # Tensor is in permanent storage
+                    # Moving the data is only worthwhile when permanent storage differs from feature map storage
+                    if (
+                        tens.mem_area in self.arch.permanent_storage_mem_area
+                        and self.arch.permanent_storage_mem_area != self.arch.feature_map_storage_mem_area
+                    ) or tens.purpose == TensorPurpose.LUT:
+                        if tens.purpose == TensorPurpose.LUT or (
+                            tens.purpose == TensorPurpose.FeatureMap
+                            and sched_op.op_type.is_binary_elementwise_op()
+                            and tens.shape != []
+                            and sched_op.ifm.shape != sched_op.ofm.shape
+                            and tens.storage_size() > max_ifm_shram_avail
+                        ):
+                            only_vector_product_consumers = all(
+                                oper and oper.type.npu_block_type == NpuBlockType.VectorProduct
+                                for oper in tens.consumers()
+                            )

-            def visit_pass(ps):
-                if ps in visit_pass_set:
-                    return
-                visit_pass_set.add(ps)

+                            if (not only_vector_product_consumers) or tens.purpose == TensorPurpose.LUT:
+                                new_tens = tens.clone_into_fast_storage(self.arch)
+                                if tens.purpose == TensorPurpose.LUT:
+                                    new_tens.mem_area = MemArea.Shram
+
+                                new_tens.consumer_list.append(parent_op)
+                                parent_op.inputs[idx] = new_tens
+                                sched_op.parent_ps.inputs[idx] = new_tens
+
+    def print_schedule(self, schedule: Schedule):
+        print(f"Schedule: '{schedule.name}'")
+        for sched_op in self.sched_ops:
+            if sched_op not in schedule.cost_map:
+                # Sub-schedule printing
+                continue
+
+            op_info = schedule.cost_map[sched_op]
+            print(f"\t{sched_op.index}: Operation {sched_op.name} - OFM {sched_op.ofm.shape}")
+            print(f"\t\tType: {sched_op.op_type}")
+            print(f"\t\tKernel: {sched_op.kernel}")
+            print(f"{op_info}")
+            mem_usage = (
+                schedule.memory_snapshot[op_info.time_index]
+                if op_info.time_index < len(schedule.memory_snapshot)
+                else 0
+            )
+            print(f"\t\tSRAM Used: {mem_usage} bytes")
+
+        print(f"\tCascades:")
+        for i, cascade in enumerate(schedule.cascades.values()):
+            print(f"\t\t{i}: {cascade.start} -> {cascade.end}, size: {cascade.mem_usage}")
+
+
+def _update_tensor_allocation(nng: Graph, arch: ArchitectureFeatures, options):
+    """
+    Creates live ranges and runs tensor allocator for the current schedule
+    (i.e. sg.schedule for all subgraphs), returns the maximum memory usage
+    and updates SchedulerOpInfo.mem_usage for all operations in the schedule.
+ """ + root_sg = nng.get_root_subgraph() + + alloc_list = [] + if arch.is_spilling_enabled(): + mem_alloc_scratch_fast = (arch.fast_storage_mem_area, set((MemType.Scratch_fast,))) + mem_alloc_scratch = (arch.feature_map_storage_mem_area, set((MemType.Scratch,))) + # Order is important + alloc_list.append(mem_alloc_scratch_fast) + alloc_list.append(mem_alloc_scratch) + else: + mem_alloc_scratch = (arch.feature_map_storage_mem_area, set((MemType.Scratch, MemType.Scratch_fast))) + alloc_list.append(mem_alloc_scratch) + + for mem_area, mem_type_set in alloc_list: + tensor_allocation.allocate_tensors( + nng, + root_sg, + arch, + mem_area, + mem_type_set, + tensor_allocator=options.tensor_allocator, + verbose_allocation=options.verbose_allocation, + cpu_tensor_alignment=options.cpu_tensor_alignment, + ) - cps = ps.cascade - dont_traverse = set(cps.passes) - for ps in cps.passes: - for pred in ps.predecessors: - if pred in dont_traverse: - continue - visit_pass(pred) +def schedule_passes(nng: Graph, arch: ArchitectureFeatures, options, scheduler_options: SchedulerOptions): + """Entry point for the Scheduler""" + # Initialize CPU subgraphs + schedulers = dict() + # Initialize schedulers with max schedule. Only schedule NPU subgraphs + for sg in nng.subgraphs: + if sg.placement != PassPlacement.Npu: + # Create cascaded passes for CPU Ops + cascaded_passes = [] + for idx, ps in enumerate(sg.passes): + cps = CascadedPass( + ps.name, SchedulingStrategy.WeightStream, ps.inputs, [], ps.outputs, [ps], ps.placement, False, + ) + cps.time = idx + ps.cascade = cps cascaded_passes.append(cps) - starting_passes = [ps for ps in self.sg.passes if not ps.successors] - for ps in starting_passes: - visit_pass(ps) - - # reorder so startup init cascaded passes come first - def is_startup_cascaded_pass(cps): - if not cps.passes: - return False - return cps.placement == PassPlacement.StartupInit - - cascaded_passes = [cps for cps in cascaded_passes if is_startup_cascaded_pass(cps)] + [ - cps for cps in cascaded_passes if not is_startup_cascaded_pass(cps) - ] - - self.sg.cascaded_passes = cascaded_passes - self.sg.build_cascaded_pass_links() - - # Check if NHCWB16 and/or fast storage can be used in between cascaded passes - # (NHCWB16 within cascaded passes has been handled earlier in this function) - if self.sg.placement == PassPlacement.Npu: - # Dictionary tensor -> list of ops, containing feature maps that can be attempted - # to be moved to fast storage - fast_storage_tensor_rewrites = {} - last_op_in_subgraph = self.sg.cascaded_passes[-1].passes[-1].primary_op - # Memory only passes have no primary_op, so use the last op in ops - if last_op_in_subgraph is None: - last_op_in_subgraph = self.sg.cascaded_passes[-1].passes[-1].ops[-1] - for ps in self.sg.cascaded_passes: - if ps.placement != PassPlacement.Npu: - continue - for output in ps.outputs: - if output.purpose != TensorPurpose.FeatureMap: - continue - - use_NHCWB16 = not output.needs_linear_format - use_fast_storage = True - rewrites = [] - for op in output.consumer_list: - if op is None: - use_NHCWB16 = False - use_fast_storage = False - continue - if op.type == Op.ReduceSum and output.dtype == DataType.int32: - use_NHCWB16 = False - elif op.type == Op.Reshape: - # Using NHCWB16 format for a no-op reshape is only an option if subsequent - # consumers do not also need to perform a reshape or if the OFM is going to - # be processed by CPU operations. 
No-op reshape consumers with empty lists
-                        # (those that have no consumers, or null-consumers used as list terminators)
-                        # must use normal NHWC output.
-                        def incompatible_consumers(oper):
-                            if oper and oper.type == Op.Reshape:
-                                for consumer in oper.outputs[0].consumer_list:
-                                    yield from incompatible_consumers(consumer)
-                            yield not oper or not oper.run_on_npu or oper is last_op_in_subgraph
-
-                        if not any(incompatible_consumers(op)):
-
-                            def get_rewrites(oper):
-                                if oper and oper.type == Op.Reshape:
-                                    for consumer in oper.outputs[0].consumer_list:
-                                        yield from get_rewrites(consumer)
-                                    yield oper
-
-                            rewrites.extend(get_rewrites(op))
-                            # Detect no-op reshapes by comparing their full input and output tensor shapes.
-                            inshape = op.ifm_shapes[0]
-                            compatible_shape = [(inshape == oper.ofm_shapes[0]) for oper in get_rewrites(op)]
-                            use_NHCWB16 &= compatible_shape and all(compatible_shape)
-                        else:
-                            use_NHCWB16 = False
-                            use_fast_storage = False
-                        use_NHCWB16 &= op.run_on_npu
-                        use_fast_storage &= op.run_on_npu
-
-                    if use_fast_storage:
-                        fast_storage_tensor_rewrites[output] = rewrites
-                    if use_NHCWB16 and self.options.use_nhcwb16_between_cascaded_passes:
-                        output.set_format(TensorFormat.NHCWB16, arch)
-                        for rewrite_op in rewrites:
-                            rewrite_op.outputs[0].set_format(TensorFormat.NHCWB16, arch)
-            if arch.is_spilling_enabled():
-                # Remember feature maps that can be moved to fast storage for later use
-                # in use_fast_storage_for_feature_maps
-                self.sg.scheduling_info["feature_map_rewrites"] = fast_storage_tensor_rewrites
-
+            sg.cascaded_passes = cascaded_passes
+        else:
+            # Npu subgraph - create schedule
+            scheduler = Scheduler(nng, sg, arch, scheduler_options)
+            schedulers[sg] = scheduler
+
+            scheduler.create_scheduler_representation(arch)
+            sg.sched_ops = scheduler.sched_ops
+            scheduler.move_constant_data()
+
+            # Create the Max schedule template
+            max_schedule_template = scheduler.create_initial_schedule()
+            scheduler.max_schedule = max_schedule_template
+
+            # Create the optimised Max schedule
+            sg.schedule = max_schedule_template
+            scheduler.update_op_memory_snapshot(max_schedule_template)
+            opt_max_schedule = scheduler.propose_schedule_buffering(max_schedule_template)
+            sg.schedule = opt_max_schedule
+            scheduler.update_op_memory_snapshot(opt_max_schedule)
+
+            # Create Min schedule
+            min_schedule = scheduler.propose_minimal_schedule()
+            initial_sram_limit = scheduler_options.optimization_sram_limit
+            if scheduler_options.optimization_strategy == OptimizationStrategy.Size:
+                initial_sram_limit = scheduler.min_memory_req
+
+            cascade_builder = CascadeBuilder(scheduler.sched_ops, arch.is_spilling_enabled())
+            cascade_builder.build_cascades(min_schedule, max_schedule_template, initial_sram_limit)
+            sg.schedule = min_schedule
+            scheduler.update_op_memory_snapshot(min_schedule)
+
+            if scheduler_options.optimization_strategy == OptimizationStrategy.Performance:
+                # Create an optimized schedule
+                sg.schedule = scheduler.optimize_schedule(
+                    min_schedule, opt_max_schedule, max_schedule_template, scheduler_options
+                )
+                scheduler.update_op_memory_snapshot(sg.schedule)

-def move_scales_to_fast_storage(nng, arch):
-    for sg in nng.subgraphs:
-        # IFM streamed ops reads bias tensors several times, move these to fast storage
-        for cp in sg.cascaded_passes:
-            if cp.strategy == SchedulingStrategy.IfmStream:
-                # Calculate SRAM usage
-                new_size = 0
-                all_tens = []
-                for ps in cp.passes:
-                    pass_tens = np.array([ps.ifm_tensor, ps.ifm2_tensor, ps.ofm_tensor, ps.weight_tensor])
-                    pass_tens = np.append(pass_tens,
ps.intermediates) - for tens in pass_tens: - if tens and tens.mem_area == MemArea.Sram and tens not in all_tens: - all_tens.append(tens) - new_size += tens.storage_size() - - cp.sram_used = new_size - - for ps in cp.passes: - if ps.scale_tensor: - tens = ps.scale_tensor - - # Find op using scale tensor - op = next((op for op in ps.ops if tens in op.inputs), None) - assert op - - # Create fast storage tensor - new_tens = tens.clone_into_fast_storage(arch) - new_tens.consumer_list = tens.consumer_list.copy() - new_tens.purpose = TensorPurpose.FSBias - new_tens_size = new_tens.storage_size() - - if (cp.sram_used + new_tens_size) <= arch.sram_size: - # Create DMA cmd - dma_cmd = Operation(Op.DMA, tens.ops[0].name + "_dma") - dma_cmd.inputs = [tens] - dma_cmd.set_output_tensor(new_tens) - dma_cmd.attrs["source"] = tens.mem_area - dma_cmd.attrs["destination"] = new_tens.mem_area - dma_cmd.run_on_npu = True - - tens.consumer_list.clear() - tens.consumer_list.append(dma_cmd) - - # Replace tensor and op - idx = op.inputs.index(tens) - op.inputs[idx] = new_tens - - ps.ops.insert(0, dma_cmd) - ps.scale_tensor = new_tens - ps.intermediates.append(new_tens) - ps.cascade.intermediates.append(new_tens) - - cp.sram_used += new_tens_size - - -def schedule_passes(nng, arch, options: SchedulerOptions): + scheduler.apply_schedule(sg.schedule) + scheduler.use_fast_storage_for_feature_maps(sg.schedule, scheduler_options.optimization_sram_limit) - for sg in nng.subgraphs: - sg.base_sram_used = 0 + if scheduler_options.verbose_schedule: + scheduler.print_schedule(sg.schedule) - for sg in nng.subgraphs: - # re-entering the same nodes from different contexts requires us to - # build a simplified directed acyclic (DAG) version of the graph to - # use for traversal, rather than using a visit dictionary. this avoids - # recursing infinitely due to loops. - sg.build_pass_dag_predecessors() - - dps = DynamicProgrammingScheduler(nng, sg, arch, arch.sram_size, options) - - strat_set = dps.search() - - dps.apply_result(strat_set, arch) - - if options.verbose_schedule: - sg.print_cascaded_passes() - - -def _calc_tens_to_cps(sg, tensor_rewrites): - # Determines for each tensor the list of affected cascaded passes, in terms of SRAM consumption. - # Returns dictionary tensor -> list of cascaded passes - # Note: if cascaded passes are A, B, C, D, and a tensor is output - # of A and input to D, then it also consumes SRAM in passes B and C. - if "tens_to_cps" in sg.scheduling_info: - return sg.scheduling_info["tens_to_cps"] - # Determine life-time of tensors - min_index = {} - max_index = {} - index = 0 - cps_list = [cps for cps in sg.cascaded_passes if cps.placement == PassPlacement.Npu] - for cps in cps_list: - for tens in cps.inputs + cps.outputs: - if tens in tensor_rewrites: - min_index[tens] = min(index, min_index.get(tens, len(cps_list))) - max_index[tens] = index - index += 1 - # Convert to affected cps-es - tens_to_cps = {} - for tens in min_index: - tens_to_cps[tens] = cps_list[min_index[tens] : max_index[tens] + 1] - sg.scheduling_info["tens_to_cps"] = tens_to_cps - return tens_to_cps - - -def use_fast_storage_for_feature_maps(sg, sram_limit, arch): - # Attempts to use as much fast storage as possible for feature maps shared between cascaded passes. 
- tensor_rewrites = sg.scheduling_info.get("feature_map_rewrites", {}) - tens_to_cps = _calc_tens_to_cps(sg, tensor_rewrites) - # Sort tensors first on life-time (smallest first), then on size (biggest first) - tens_list = sorted([(len(tens_to_cps[tens]), -tens.storage_size(), tens.name, tens) for tens in tens_to_cps]) - for _, _, _, tens in tens_list: - cps_list = tens_to_cps[tens] - if len(cps_list) < 1: - continue - sz = tens.storage_size() - fits_in_fast_storage = all([cps.sram_used + sz <= sram_limit for cps in cps_list]) - if fits_in_fast_storage: - tens.mem_area = arch.fast_storage_mem_area - tens.mem_type = MemType.Scratch_fast - tens.set_new_sub_purpose(TensorSubPurpose.Standard, None, None) - assert tens in tensor_rewrites - # Also rewrite reshapes - for rewrite_op in tensor_rewrites[tens]: - tens2 = rewrite_op.outputs[0] - tens2.mem_area = arch.fast_storage_mem_area - tens2.mem_type = MemType.Scratch_fast - tens2.set_new_sub_purpose(TensorSubPurpose.Standard, None, None) - for cps in cps_list: - cps.sram_used += sz - - -def undo_use_fast_storage(sg, arch): - # Undoes the effects of a previous call to use_fast_storage_for_feature_maps - tensor_rewrites = sg.scheduling_info.get("feature_map_rewrites", {}) - tens_to_cps = _calc_tens_to_cps(sg, tensor_rewrites) - mem_area = arch.tensor_storage_mem_area[TensorPurpose.FeatureMap] - for tens, cps_list in tens_to_cps.items(): - if tens.mem_type == MemType.Scratch_fast: - sz = tens.storage_size() - tens.mem_area = mem_area - tens.mem_type = MemType.Scratch - # Also undo reshapes - for rewrite_op in tensor_rewrites[tens]: - tens2 = rewrite_op.outputs[0] - tens2.mem_area = mem_area - tens2.mem_type = MemType.Scratch - for cps in cps_list: - cps.sram_used -= sz + # Evaluate schedule + _update_tensor_allocation(nng, arch, options) -- cgit v1.2.1
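A note on the striping search introduced in optimize_sub_schedule above: the candidate stripe heights are narrowed down in a binary-search fashion, keeping the larger half whenever a proposal fits the memory limit. The stand-alone sketch below uses hypothetical names and a toy monotonic cost model in place of the real schedule proposal and cascade building; it only illustrates the search pattern itself.

# Stand-alone sketch of the stripe-height binary search used by the scheduler.
# 'memory_for_stripe' is a made-up monotonic cost model; the real scheduler
# builds a full striping proposal and estimates its memory usage instead.
from typing import Callable

def search_stripe_height(ofm_height: int, memory_limit: int,
                         memory_for_stripe: Callable[[int], int]) -> int:
    """Return the largest stripe height (up to ofm_height // 2) that fits memory_limit."""
    candidates = list(range(1, ofm_height // 2 + 1))
    best = 0  # 0 means no striping proposal fitted
    while len(candidates) > 1:
        proposed = candidates[len(candidates) // 2]
        if memory_for_stripe(proposed) <= memory_limit:
            # Fits: discard all smaller candidates, keep this and larger ones
            candidates = candidates[len(candidates) // 2 :]
            best = proposed
        else:
            # Does not fit: discard this and all larger candidates
            candidates = candidates[: len(candidates) // 2]
    return best

if __name__ == "__main__":
    # Toy cost model: memory grows linearly with stripe height
    print(search_stripe_height(ofm_height=64, memory_limit=24_000,
                               memory_for_stripe=lambda h: 1_000 * h))

Because memory use grows with stripe height, the loop converges on the largest stripe height whose proposal still fits, which is what the scheduler keeps as best_schedule.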
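Finally, the overall flow in schedule_passes amounts to choosing between a size-optimal and a performance-optimal schedule under an SRAM target. The toy skeleton below mirrors only that decision structure; ToySchedule, choose_schedule and the peak_sram field are invented for illustration, whereas the real scheduler works on cost maps, cascades and memory snapshots.

# Toy skeleton of the Size / Performance strategy choice in schedule_passes.
# ToySchedule and these function names are invented for illustration only.
from dataclasses import dataclass
from enum import Enum, auto

class Strategy(Enum):
    Size = auto()
    Performance = auto()

@dataclass
class ToySchedule:
    name: str
    peak_sram: int  # peak SRAM usage of this schedule, in bytes

def choose_schedule(max_sched: ToySchedule, min_sched: ToySchedule,
                    strategy: Strategy, sram_limit: int) -> ToySchedule:
    if strategy == Strategy.Size:
        # Size strategy: the minimal, fully cascaded schedule is used as-is
        return min_sched
    if max_sched.peak_sram <= sram_limit:
        # The maximum-performance schedule already fits the SRAM target
        return max_sched
    # Otherwise start from the minimal schedule; the real scheduler then grows
    # stripes and weight buffering back towards max performance within the limit
    return min_sched

if __name__ == "__main__":
    print(choose_schedule(ToySchedule("MAX", 512_000), ToySchedule("MIN", 96_000),
                          Strategy.Performance, sram_limit=384_000))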