aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/scheduler.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/scheduler.py')
-rw-r--r--ethosu/vela/scheduler.py17
1 files changed, 14 insertions, 3 deletions
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index e9a93c19..47f8a47f 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -35,6 +35,7 @@ from .npu_performance import make_cycles_array
from .npu_performance import make_macs_array
from .npu_performance import make_metrics_arrays
from .npu_performance import PassCycles
+from .numeric_util import full_shape
from .operation import NpuBlockType
from .shared_buffer_allocation import find_block_configs_suitable_for_pass_and_shared_buffer
from .shared_buffer_allocation import shared_buffer_allocation_for_pass_and_block_config
@@ -43,7 +44,7 @@ from .tensor import MemType
from .tensor import TensorFormat
from .tensor import TensorPurpose
from .tensor import TensorSubPurpose
-from .numeric_util import full_shape
+
class ParetoMetric(enum.Enum):
BwCycMem = 1
@@ -652,6 +653,9 @@ class DynamicProgrammingScheduler:
for op in pred_candidate.ops:
if op.type == "ConcatSliceWrite":
return True
+ if len(op.outputs) > 1 or len(op.outputs[0].consumer_list) > 1:
+ # The op has consumers in other subgraphs
+ return True
return False
def search_ifm_streaming_partial(self, ps, block_config):
@@ -976,8 +980,15 @@ class DynamicProgrammingScheduler:
# be processed by CPU operations. No-op reshape consumers with empty lists
# (those that have no consumers, or null-consumers used as list terminators)
# must use normal NHWC output.
- incompatible_consumers = [ (not consumer.run_on_npu or consumer.type == "Reshape" or (consumer is last_op_in_subgraph))
- for consumer in op.outputs[0].consumer_list if consumer is not None ]
+ incompatible_consumers = [
+ (
+ not consumer.run_on_npu
+ or consumer.type == "Reshape"
+ or (consumer is last_op_in_subgraph)
+ )
+ for consumer in op.outputs[0].consumer_list
+ if consumer is not None
+ ]
if (outshape == inshape) and incompatible_consumers and not any(incompatible_consumers):
rewrites.append(op)
else: