From 19515e8fbb1e23c63d6bb963054deb09dae66e88 Mon Sep 17 00:00:00 2001 From: Charles Xu Date: Wed, 10 Jun 2020 10:48:33 +0200 Subject: MLBEDSW-2432: Retain pass order for CPU subgraph Signed-off-by: Charles Xu Change-Id: I92b18262608415e84266d2903e17fc5112793a38 --- ethosu/vela/scheduler.py | 38 +++++++++++++++++++++----------------- 1 file changed, 21 insertions(+), 17 deletions(-) (limited to 'ethosu/vela/scheduler.py') diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py index d8c641a9..e45e3e55 100644 --- a/ethosu/vela/scheduler.py +++ b/ethosu/vela/scheduler.py @@ -885,30 +885,34 @@ class DynamicProgrammingScheduler: print("%3d pass missing cascaded pass %s" % (ps.time, ps)) assert len(pass_to_cascaded_pass) == len(self.sg.passes) - # we have all the passes, but we need to put them in order and build predecessor/successor links. - visit_pass_set = set() cascaded_passes = [] + if self.sg.placement == PassPlacement.Cpu: + # Retain the pass order for CPU subgraph + cascaded_passes = [ps.cascade for ps in self.sg.passes] + else: + # we have all the passes, but we need to put them in order and build predecessor/successor links. + visit_pass_set = set() - def visit_pass(ps): - if ps in visit_pass_set: - return - visit_pass_set.add(ps) + def visit_pass(ps): + if ps in visit_pass_set: + return + visit_pass_set.add(ps) - cps = ps.cascade - dont_traverse = set(cps.passes) + cps = ps.cascade + dont_traverse = set(cps.passes) - for ps in cps.passes: - for pred in ps.predecessors: - if pred in dont_traverse: - continue - visit_pass(pred) + for ps in cps.passes: + for pred in ps.predecessors: + if pred in dont_traverse: + continue + visit_pass(pred) - cascaded_passes.append(cps) + cascaded_passes.append(cps) - starting_passes = [ps for ps in self.sg.passes if not ps.successors] - for ps in starting_passes: - visit_pass(ps) + starting_passes = [ps for ps in self.sg.passes if not ps.successors] + for ps in starting_passes: + visit_pass(ps) # reorder so startup init cascaded passes come first def is_startup_cascaded_pass(cps): -- cgit v1.2.1