aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Xu <charles.xu@arm.com>2020-06-10 10:48:33 +0200
committerTim Hall <tim.hall@arm.com>2020-06-18 17:53:52 +0100
commit19515e8fbb1e23c63d6bb963054deb09dae66e88 (patch)
tree71270a80ee6ded31d1cde0ae2fe778997b7990ce
parentc65cb6c0eb925d4acd66febd71c8e19f80737a38 (diff)
downloadethos-u-vela-19515e8fbb1e23c63d6bb963054deb09dae66e88.tar.gz
MLBEDSW-2432: Retain pass order for CPU subgraph
Signed-off-by: Charles Xu <charles.xu@arm.com> Change-Id: I92b18262608415e84266d2903e17fc5112793a38
-rw-r--r--ethosu/vela/scheduler.py38
1 files changed, 21 insertions, 17 deletions
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index d8c641a9..e45e3e55 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -885,30 +885,34 @@ class DynamicProgrammingScheduler:
print("%3d pass missing cascaded pass %s" % (ps.time, ps))
assert len(pass_to_cascaded_pass) == len(self.sg.passes)
- # we have all the passes, but we need to put them in order and build predecessor/successor links.
- visit_pass_set = set()
cascaded_passes = []
+ if self.sg.placement == PassPlacement.Cpu:
+ # Retain the pass order for CPU subgraph
+ cascaded_passes = [ps.cascade for ps in self.sg.passes]
+ else:
+ # we have all the passes, but we need to put them in order and build predecessor/successor links.
+ visit_pass_set = set()
- def visit_pass(ps):
- if ps in visit_pass_set:
- return
- visit_pass_set.add(ps)
+ def visit_pass(ps):
+ if ps in visit_pass_set:
+ return
+ visit_pass_set.add(ps)
- cps = ps.cascade
- dont_traverse = set(cps.passes)
+ cps = ps.cascade
+ dont_traverse = set(cps.passes)
- for ps in cps.passes:
- for pred in ps.predecessors:
- if pred in dont_traverse:
- continue
- visit_pass(pred)
+ for ps in cps.passes:
+ for pred in ps.predecessors:
+ if pred in dont_traverse:
+ continue
+ visit_pass(pred)
- cascaded_passes.append(cps)
+ cascaded_passes.append(cps)
- starting_passes = [ps for ps in self.sg.passes if not ps.successors]
- for ps in starting_passes:
- visit_pass(ps)
+ starting_passes = [ps for ps in self.sg.passes if not ps.successors]
+ for ps in starting_passes:
+ visit_pass(ps)
# reorder so startup init cascaded passes come first
def is_startup_cascaded_pass(cps):