diff options
Diffstat (limited to 'ethosu/vela/scheduler.py')
-rw-r--r-- | ethosu/vela/scheduler.py | 15 |
1 files changed, 6 insertions, 9 deletions
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py index 5c2ddabb..41e15294 100644 --- a/ethosu/vela/scheduler.py +++ b/ethosu/vela/scheduler.py @@ -37,6 +37,7 @@ from .npu_performance import make_metrics_arrays from .npu_performance import PassCycles from .numeric_util import full_shape from .operation import NpuBlockType +from .operation import Op from .shared_buffer_allocation import find_block_configs_suitable_for_pass_and_shared_buffer from .shared_buffer_allocation import shared_buffer_allocation_for_pass_and_block_config from .tensor import MemArea @@ -254,11 +255,7 @@ class DynamicProgrammingScheduler: self.pareto_max_candidates = 16 self.ifm_stream_npu_blocks = set( - ( - NpuBlockType.ConvolutionMxN, - NpuBlockType.ConvolutionDepthWise, - NpuBlockType.Pooling, - ) + (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise, NpuBlockType.Pooling,) ) num_pareto_metrics = 4 @@ -652,7 +649,7 @@ class DynamicProgrammingScheduler: def avoid_for_cascading(self, pred_candidate): for op in pred_candidate.ops: if ( - op.type == "ConcatSliceWrite" + op.type == Op.ConcatSliceWrite and self.arch.feature_map_storage_mem_area != self.arch.fast_storage_mem_area ): # For SRAM spilling, concat op is avoided as predecessor @@ -981,9 +978,9 @@ class DynamicProgrammingScheduler: use_NHCWB16 = False use_fast_storage = False continue - if op.type == "ReduceSum" and output.dtype == DataType.int32: + if op.type == Op.ReduceSum and output.dtype == DataType.int32: use_NHCWB16 = False - elif op.type == "Reshape": + elif op.type == Op.Reshape: # Detect no-op reshapes by comparing their full input and output tensor shapes. inshape = full_shape(4, op.inputs[0].shape, 1) outshape = full_shape(4, op.outputs[0].shape, 1) @@ -995,7 +992,7 @@ class DynamicProgrammingScheduler: incompatible_consumers = [ ( not consumer.run_on_npu - or consumer.type == "Reshape" + or consumer.type == Op.Reshape or (consumer is last_op_in_subgraph) ) for consumer in op.outputs[0].consumer_list |