diff options
Diffstat (limited to 'ethosu/vela/scheduler.py')
-rw-r--r-- | ethosu/vela/scheduler.py | 27 |
1 files changed, 25 insertions, 2 deletions
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py index d01942bb..e9f38b4d 100644 --- a/ethosu/vela/scheduler.py +++ b/ethosu/vela/scheduler.py @@ -941,6 +941,29 @@ class Scheduler: return peak_mem_usage + def build_cascades_for_min_schedule(self, min_schedule: Schedule, max_template: Schedule, memory_limit: int): + # Update memory snapshot + self.sg.schedule = min_schedule + self.update_op_memory_snapshot(min_schedule) + + # Calculate residual memory for Min schedule + non_local_mem_usage = {} + for sched_op in self.sched_ops: + time_index = min_schedule.cost_map[sched_op].time_index + + if self.arch.is_spilling_enabled(): + # For Dedicated SRAM only the intermediate buffers are in SRAM, hence op_mem_usage is 0 + op_mem_usage = 0 + else: + # Min schedule only have ifm and ofm in SRAM (no buffered weigth tensors) + op_mem_usage = sched_op.ifm_size_in_bytes() + sched_op.ofm_size_in_bytes() + + non_local_mem_usage[sched_op] = min_schedule.memory_snapshot[time_index] - op_mem_usage + + # Crate cascades for Min schedule + cascade_builder = CascadeBuilder(self.sched_ops, self.arch.is_spilling_enabled(), non_local_mem_usage) + cascade_builder.build_cascades(min_schedule, max_template, memory_limit) + def optimize_sub_schedule( self, cascade_info: CascadeInfo, ref_schedule: Schedule, max_template: Schedule, memory_limit: int ) -> Schedule: @@ -1545,8 +1568,8 @@ def schedule_passes(nng: Graph, arch: ArchitectureFeatures, options, scheduler_o if scheduler_options.optimization_strategy == OptimizationStrategy.Size: initial_sram_limit = scheduler.min_memory_req - cascade_builder = CascadeBuilder(scheduler.sched_ops, arch.is_spilling_enabled()) - cascade_builder.build_cascades(min_schedule, max_schedule_template, initial_sram_limit) + # Build cascades for Min schedule + scheduler.build_cascades_for_min_schedule(min_schedule, max_schedule_template, initial_sram_limit) sg.schedule = min_schedule scheduler.update_op_memory_snapshot(min_schedule) |