From 504d6b6b81a564c45e970667e0ace71714bf02dc Mon Sep 17 00:00:00 2001
From: Diqing Zhong
Date: Thu, 17 Sep 2020 12:21:10 +0200
Subject: MLBEDSW-2816: Fix assert in scheduler

- Use non local memory as the base sram usage for a subgraph
- Make avoid_for_spilling more generic for all mem configs

Change-Id: I99cd30fe6a8ba075d5a70dc2138aa0635afaadb3
Signed-off-by: Diqing Zhong
---
 ethosu/vela/scheduler.py | 25 +++++++++++++++----------
 1 file changed, 15 insertions(+), 10 deletions(-)

diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index 47f8a47f..24453d8c 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -254,7 +254,11 @@ class DynamicProgrammingScheduler:
         self.pareto_max_candidates = 16
 
         self.ifm_stream_npu_blocks = set(
-            (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise, NpuBlockType.Pooling,)
+            (
+                NpuBlockType.ConvolutionMxN,
+                NpuBlockType.ConvolutionDepthWise,
+                NpuBlockType.Pooling,
+            )
         )
 
     num_pareto_metrics = 4
@@ -519,7 +523,7 @@ class DynamicProgrammingScheduler:
         if self.verbose_pareto_frontier_schedules:
             print(
                 "Scheduler searched %d combinations and found %d candidate schedules along the pareto frontier"
-                % (self.n_combinations_searched, len(strat_data,))
+                % (self.n_combinations_searched, len(strat_data))
             )
             for idx, (_, strat_set) in enumerate(strat_data):
                 extra = ""
@@ -645,13 +649,13 @@ class DynamicProgrammingScheduler:
         res = self.filter_pareto_frontier(res, remove_equally_good_candidates=True)
         return res
 
-    def avoid_for_spilling(self, pred_candidate):
-        if self.arch.feature_map_storage_mem_area == self.arch.fast_storage_mem_area:
-            return False
-
-        # For SRAM spilling, concat op is avoided as predecessor
+    def avoid_for_cascading(self, pred_candidate):
         for op in pred_candidate.ops:
-            if op.type == "ConcatSliceWrite":
+            if (
+                op.type == "ConcatSliceWrite"
+                and self.arch.feature_map_storage_mem_area != self.arch.fast_storage_mem_area
+            ):
+                # For SRAM spilling, concat op is avoided as predecessor
                 return True
             if len(op.outputs) > 1 or len(op.outputs[0].consumer_list) > 1:
                 # The op has consumers in other subgraphs
@@ -685,7 +689,7 @@ class DynamicProgrammingScheduler:
             if pred_candidate.placement == PassPlacement.Npu:
                 if pred_candidate.npu_block_type in self.ifm_stream_npu_blocks:
                     # and it is on the Npu
-                    if not self.avoid_for_spilling(pred_candidate):
+                    if not self.avoid_for_cascading(pred_candidate):
                         # and fusable - it's a candidate
                         pred_pass_list.append(pred_candidate)
 
@@ -896,10 +900,11 @@ class DynamicProgrammingScheduler:
             )
             assert ps.shared_buffer is not None
 
+            sram_used = max(self.non_local_mem_usage[ps.time], 0)
             for op in ps.ops:
                 subgraph = op.attrs.get("subgraph")
                 if subgraph:
-                    subgraph.base_sram_used = cascaded_pass.sram_used
+                    subgraph.base_sram_used = sram_used
 
         # all passes should have a cascaded pass now
         if len(pass_to_cascaded_pass) != len(self.sg.passes):
-- 
cgit v1.2.1
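
Note on the change above: the rename from avoid_for_spilling to avoid_for_cascading reflects that the concat-predecessor check now only applies when feature maps spill outside fast storage, while the multiple-consumer check applies for all memory configurations; the base SRAM figure for a referenced subgraph now comes from the scheduler's non-local memory usage at the pass's time index instead of the cascaded pass total. The following is a minimal, self-contained Python sketch of that predicate, not the actual Vela code; FakeArch, FakeTensor, FakeOp and FakePass are illustrative stand-ins for Vela's real classes.

    # Sketch of the predecessor check, assuming simplified stand-in types.
    from collections import namedtuple

    FakeArch = namedtuple("FakeArch", ["feature_map_storage_mem_area", "fast_storage_mem_area"])
    FakeTensor = namedtuple("FakeTensor", ["consumer_list"])
    FakeOp = namedtuple("FakeOp", ["type", "outputs"])
    FakePass = namedtuple("FakePass", ["ops"])


    def avoid_for_cascading(arch, pred_candidate):
        # Spilling means feature maps are not stored in the fast (SRAM) storage area.
        spilling = arch.feature_map_storage_mem_area != arch.fast_storage_mem_area
        for op in pred_candidate.ops:
            if op.type == "ConcatSliceWrite" and spilling:
                # For SRAM spilling, a concat op is avoided as predecessor.
                return True
            if len(op.outputs) > 1 or len(op.outputs[0].consumer_list) > 1:
                # The op has consumers in other subgraphs; do not cascade over it.
                return True
        return False


    if __name__ == "__main__":
        concat_pass = FakePass(ops=[FakeOp("ConcatSliceWrite", [FakeTensor(consumer_list=["one"])])])
        no_spill = FakeArch(feature_map_storage_mem_area="Sram", fast_storage_mem_area="Sram")
        spill = FakeArch(feature_map_storage_mem_area="Dram", fast_storage_mem_area="Sram")
        print(avoid_for_cascading(no_spill, concat_pass))  # False: no spilling, concat may cascade
        print(avoid_for_cascading(spill, concat_pass))     # True: spilling, avoid the concat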