diff options
Diffstat (limited to 'ethosu/vela/npu_performance.py')
-rw-r--r-- | ethosu/vela/npu_performance.py | 43 |
1 files changed, 22 insertions, 21 deletions
diff --git a/ethosu/vela/npu_performance.py b/ethosu/vela/npu_performance.py index 9a6e8cd2..d28df97d 100644 --- a/ethosu/vela/npu_performance.py +++ b/ethosu/vela/npu_performance.py @@ -489,7 +489,7 @@ def estimate_memory_bandwidth(arch, mem_area, direction, tensor, block_size: Blo return bw * (max_burst_len / burst_len) -def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=[], force_outputs_to_fast_storage=False): +def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=None, force_outputs_to_fast_storage=False): if block_config is None: block_config = ps.block_config bws = make_bandwidth_array() @@ -723,26 +723,27 @@ def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=[], f for tens in dma_op.inputs: cycles[PassCycles.Npu] += tens.storage_size() / arch.memory_bandwidths_per_cycle[mem_area] - # apply the desired rewrites - for rewrite_op, tens, _, _, _, ps_to_rewrite in rewrite_list: - if ps != ps_to_rewrite: - continue - if rewrite_op == SchedulerRewrite.Nop: - pass # these are fine, no bandwidth changes - elif rewrite_op in (SchedulerRewrite.ChangeTensorSubPurpose,): - if tens.purpose == TensorPurpose.FeatureMap: - bw = estimate_memory_bandwidth( - arch, - arch.fast_storage_mem_area, - BandwidthDirection.Read, - tens, - ifm_block, - replacement_read_bws[tens], - ) - else: - bw = replacement_read_bws[tens] - bws[arch.fast_storage_mem_area][tens.purpose][BandwidthDirection.Read] += bw - replacement_read_bws[tens] = 0 + if rewrite_list is not None: + # apply the desired rewrites + for rewrite_op, tens, _, _, _, ps_to_rewrite in rewrite_list: + if ps != ps_to_rewrite: + continue + if rewrite_op == SchedulerRewrite.Nop: + pass # these are fine, no bandwidth changes + elif rewrite_op in (SchedulerRewrite.ChangeTensorSubPurpose,): + if tens.purpose == TensorPurpose.FeatureMap: + bw = estimate_memory_bandwidth( + arch, + arch.fast_storage_mem_area, + BandwidthDirection.Read, + tens, + ifm_block, + replacement_read_bws[tens], + ) + else: + bw = replacement_read_bws[tens] + bws[arch.fast_storage_mem_area][tens.purpose][BandwidthDirection.Read] += bw + replacement_read_bws[tens] = 0 for tens in ps.outputs: if force_outputs_to_fast_storage: |