diff options
-rw-r--r-- | ethosu/vela/scheduler.py | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index 41902d67..e9a93c19 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -954,6 +954,7 @@ class DynamicProgrammingScheduler:
         # Check if NHCWB16 can be used in between cascaded passes
         # (NHCWB16 within cascaded passes has been handled earlier in this function)
         if self.sg.placement == PassPlacement.Npu:
+            last_op_in_subgraph = self.sg.cascaded_passes[-1].passes[-1].primary_op
             for ps in self.sg.cascaded_passes:
                 if ps.placement != PassPlacement.Npu:
                     continue
@@ -975,8 +976,8 @@ class DynamicProgrammingScheduler:
                             # be processed by CPU operations. No-op reshape consumers with empty lists
                             # (those that have no consumers, or null-consumers used as list terminators)
                             # must use normal NHWC output.
-                            incompatible_consumers = [ (not consumer.run_on_npu or consumer.type == "Reshape") for consumer in op.outputs[0].consumer_list
-                                                       if consumer is not None ]
+                            incompatible_consumers = [ (not consumer.run_on_npu or consumer.type == "Reshape" or (consumer is last_op_in_subgraph))
+                                                       for consumer in op.outputs[0].consumer_list if consumer is not None ]
                             if (outshape == inshape) and incompatible_consumers and not any(incompatible_consumers):
                                 rewrites.append(op)
                             else: