aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ethosu/vela/scheduler.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index 41902d67..e9a93c19 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -954,6 +954,7 @@ class DynamicProgrammingScheduler:
# Check if NHCWB16 can be used in between cascaded passes
# (NHCWB16 within cascaded passes has been handled earlier in this function)
if self.sg.placement == PassPlacement.Npu:
+ last_op_in_subgraph = self.sg.cascaded_passes[-1].passes[-1].primary_op
for ps in self.sg.cascaded_passes:
if ps.placement != PassPlacement.Npu:
continue
@@ -975,8 +976,8 @@ class DynamicProgrammingScheduler:
# be processed by CPU operations. No-op reshape consumers with empty lists
# (those that have no consumers, or null-consumers used as list terminators)
# must use normal NHWC output.
- incompatible_consumers = [ (not consumer.run_on_npu or consumer.type == "Reshape") for consumer in op.outputs[0].consumer_list
- if consumer is not None ]
+ incompatible_consumers = [ (not consumer.run_on_npu or consumer.type == "Reshape" or (consumer is last_op_in_subgraph))
+ for consumer in op.outputs[0].consumer_list if consumer is not None ]
if (outshape == inshape) and incompatible_consumers and not any(incompatible_consumers):
rewrites.append(op)
else: