diff options
Diffstat (limited to 'ethosu/vela')
-rw-r--r-- | ethosu/vela/scheduler.py | 25 |
1 files changed, 15 insertions, 10 deletions
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py index 47f8a47f..24453d8c 100644 --- a/ethosu/vela/scheduler.py +++ b/ethosu/vela/scheduler.py @@ -254,7 +254,11 @@ class DynamicProgrammingScheduler: self.pareto_max_candidates = 16 self.ifm_stream_npu_blocks = set( - (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise, NpuBlockType.Pooling,) + ( + NpuBlockType.ConvolutionMxN, + NpuBlockType.ConvolutionDepthWise, + NpuBlockType.Pooling, + ) ) num_pareto_metrics = 4 @@ -519,7 +523,7 @@ class DynamicProgrammingScheduler: if self.verbose_pareto_frontier_schedules: print( "Scheduler searched %d combinations and found %d candidate schedules along the pareto frontier" - % (self.n_combinations_searched, len(strat_data,)) + % (self.n_combinations_searched, len(strat_data)) ) for idx, (_, strat_set) in enumerate(strat_data): extra = "" @@ -645,13 +649,13 @@ class DynamicProgrammingScheduler: res = self.filter_pareto_frontier(res, remove_equally_good_candidates=True) return res - def avoid_for_spilling(self, pred_candidate): - if self.arch.feature_map_storage_mem_area == self.arch.fast_storage_mem_area: - return False - - # For SRAM spilling, concat op is avoided as predecessor + def avoid_for_cascading(self, pred_candidate): for op in pred_candidate.ops: - if op.type == "ConcatSliceWrite": + if ( + op.type == "ConcatSliceWrite" + and self.arch.feature_map_storage_mem_area != self.arch.fast_storage_mem_area + ): + # For SRAM spilling, concat op is avoided as predecessor return True if len(op.outputs) > 1 or len(op.outputs[0].consumer_list) > 1: # The op has consumers in other subgraphs @@ -685,7 +689,7 @@ class DynamicProgrammingScheduler: if pred_candidate.placement == PassPlacement.Npu: if pred_candidate.npu_block_type in self.ifm_stream_npu_blocks: # and it is on the Npu - if not self.avoid_for_spilling(pred_candidate): + if not self.avoid_for_cascading(pred_candidate): # and fusable - it's a candidate pred_pass_list.append(pred_candidate) @@ -896,10 +900,11 @@ class DynamicProgrammingScheduler: ) assert ps.shared_buffer is not None + sram_used = max(self.non_local_mem_usage[ps.time], 0) for op in ps.ops: subgraph = op.attrs.get("subgraph") if subgraph: - subgraph.base_sram_used = cascaded_pass.sram_used + subgraph.base_sram_used = sram_used # all passes should have a cascaded pass now if len(pass_to_cascaded_pass) != len(self.sg.passes): |