diff options
author | Charles Xu <charles.xu@arm.com> | 2020-06-10 10:48:33 +0200 |
---|---|---|
committer | Tim Hall <tim.hall@arm.com> | 2020-06-18 17:53:52 +0100 |
commit | 19515e8fbb1e23c63d6bb963054deb09dae66e88 (patch) | |
tree | 71270a80ee6ded31d1cde0ae2fe778997b7990ce /ethosu/vela | |
parent | c65cb6c0eb925d4acd66febd71c8e19f80737a38 (diff) | |
download | ethos-u-vela-19515e8fbb1e23c63d6bb963054deb09dae66e88.tar.gz |
MLBEDSW-2432: Retain pass order for CPU subgraph
Signed-off-by: Charles Xu <charles.xu@arm.com>
Change-Id: I92b18262608415e84266d2903e17fc5112793a38
Diffstat (limited to 'ethosu/vela')
-rw-r--r-- | ethosu/vela/scheduler.py | 38 |
1 files changed, 21 insertions, 17 deletions
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py index d8c641a9..e45e3e55 100644 --- a/ethosu/vela/scheduler.py +++ b/ethosu/vela/scheduler.py @@ -885,30 +885,34 @@ class DynamicProgrammingScheduler: print("%3d pass missing cascaded pass %s" % (ps.time, ps)) assert len(pass_to_cascaded_pass) == len(self.sg.passes) - # we have all the passes, but we need to put them in order and build predecessor/successor links. - visit_pass_set = set() cascaded_passes = [] + if self.sg.placement == PassPlacement.Cpu: + # Retain the pass order for CPU subgraph + cascaded_passes = [ps.cascade for ps in self.sg.passes] + else: + # we have all the passes, but we need to put them in order and build predecessor/successor links. + visit_pass_set = set() - def visit_pass(ps): - if ps in visit_pass_set: - return - visit_pass_set.add(ps) + def visit_pass(ps): + if ps in visit_pass_set: + return + visit_pass_set.add(ps) - cps = ps.cascade - dont_traverse = set(cps.passes) + cps = ps.cascade + dont_traverse = set(cps.passes) - for ps in cps.passes: - for pred in ps.predecessors: - if pred in dont_traverse: - continue - visit_pass(pred) + for ps in cps.passes: + for pred in ps.predecessors: + if pred in dont_traverse: + continue + visit_pass(pred) - cascaded_passes.append(cps) + cascaded_passes.append(cps) - starting_passes = [ps for ps in self.sg.passes if not ps.successors] - for ps in starting_passes: - visit_pass(ps) + starting_passes = [ps for ps in self.sg.passes if not ps.successors] + for ps in starting_passes: + visit_pass(ps) # reorder so startup init cascaded passes come first def is_startup_cascaded_pass(cps): |