From e168b969dc75fc3057413a80fdf0e164ab936910 Mon Sep 17 00:00:00 2001 From: Diqing Zhong Date: Thu, 5 Nov 2020 17:18:47 +0100 Subject: Vela: estimate memory transfer efficiency Change-Id: I9e00afe0eef0e13fe990e021bcbe3dd0eda4c471 Signed-off-by: Diqing Zhong --- ethosu/vela/npu_performance.py | 150 ++++++++++++++++++++++++++--------------- 1 file changed, 97 insertions(+), 53 deletions(-) diff --git a/ethosu/vela/npu_performance.py b/ethosu/vela/npu_performance.py index 41d75f45..d8995a1a 100644 --- a/ethosu/vela/npu_performance.py +++ b/ethosu/vela/npu_performance.py @@ -19,7 +19,8 @@ # # Called during scheduling to evaluate different proposals, as well as post-scheduling to provide a final performance # estimate. -import enum +from enum import auto +from enum import IntEnum import numpy as np @@ -35,6 +36,7 @@ from .shared_buffer_allocation import is_acc_40bits_used from .tensor import MemArea from .tensor import shape_num_elements from .tensor import TensorBlockTraversal +from .tensor import TensorFormat from .tensor import TensorPurpose @@ -56,23 +58,21 @@ def rolling_buffer_dims_from_passes(arch, ps1, block_config_ps1, ps2, block_conf return [height, width] -class PassCycles(enum.IntEnum): +class PassCycles(IntEnum): Npu = 0 - Cpu = 1 - SramAccess = 2 - TotalPerPass = 3 - DramAccess = 4 - OnChipFlashAccess = 5 - OffChipFlashAccess = 6 - Total = 7 - Size = 8 + Cpu = auto() + SramAccess = auto() + DramAccess = auto() + OnChipFlashAccess = auto() + OffChipFlashAccess = auto() + Total = auto() + Size = auto() def display_name(self): return ( "NPU", "CPU", "SRAM Access", - "Total per Pass", "DRAM Access", "On-chip Flash Access", "Off-chip Flash Access", @@ -85,7 +85,6 @@ class PassCycles(enum.IntEnum): "npu", "cpu", "sram_access", - "total_per_pass", "dram_access", "on_chip_flash_access", "off_chip_flash_access", @@ -106,10 +105,10 @@ class PassCycles(enum.IntEnum): ) -class MacCount(enum.IntEnum): +class MacCount(IntEnum): NeuralNetworkMacs = 0 - HardwareMacs = 1 - Size = 2 + HardwareMacs = auto() + Size = auto() def display_name(self): return ("Neural Network Macs", "Hardware Macs", "Size")[self.value] @@ -122,10 +121,10 @@ class MacCount(enum.IntEnum): return (MacCount.NeuralNetworkMacs, MacCount.HardwareMacs) -class BandwidthDirection(enum.IntEnum): +class BandwidthDirection(IntEnum): Read = 0 - Write = 1 - Size = 2 + Write = auto() + Size = auto() def display_name(self): return self.name @@ -373,6 +372,46 @@ def estimate_conv_pooling_cycles( return total_cycles +def estimate_memory_bandwidth(arch, mem_area, direction, tensor, block_size: Block, replace_bw=None): + if tensor.format not in (TensorFormat.NHWC, TensorFormat.NHCWB16): + return tensor.bandwidth() if replace_bw is None else replace_bw + + # Estimate memory transfer efficiency by calculating the burst length + # this is related to data format, block shape, and tensor shape, etc. + max_burst_len = 32 if mem_area == MemArea.Sram else 128 + burst_len = 0 + elem_size = tensor.dtype.size_in_bytes() + is_ifm = direction == BandwidthDirection.Read + tens = tensor.clone() + if not tens.avoid_NHCWB16: + tens.set_format(TensorFormat.NHCWB16, arch) + + if tens.format == TensorFormat.NHCWB16: + if tens.get_strides()[1] == block_size.depth: + burst_len = elem_size * block_size.depth * block_size.width + elif is_ifm: + burst_len = 16 * elem_size * block_size.width + else: + burst_len = 16 * elem_size * block_size.width * arch.ncores + else: + assert tens.format == TensorFormat.NHWC + if is_ifm: + if tens.get_strides()[3] == block_size.depth: + burst_len = elem_size * block_size.depth * block_size.width + else: + burst_len = elem_size * block_size.depth + else: + if block_size.depth <= 16 and tens.get_strides()[3] == block_size.depth: + burst_len = elem_size * block_size.depth * block_size.width + else: + burst_len = min(64, 16 * elem_size * arch.ncores, block_size.depth * elem_size) + + burst_len = min(max_burst_len, burst_len) + bw = tens.bandwidth() if replace_bw is None else replace_bw + + return bw * (max_burst_len / burst_len) + + def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=[], force_outputs_to_fast_storage=False): if block_config is None: block_config = ps.block_config @@ -392,6 +431,9 @@ def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=[], f explicit_padding = (0, 0, 0, 0) primary_op = ps.primary_op replacement_read_bws = {} + ofm_block = Block(block_config[1], block_config[0], block_config[3]) + ifm_block = Block(block_config[1], block_config[0], block_config[3]) + if ps.placement == PassPlacement.Cpu: cycles[PassCycles.Cpu] = arch.cpu_cycle_estimate(ps.ops[0]) elif primary_op: @@ -402,6 +444,7 @@ def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=[], f block_traversal = TensorBlockTraversal.Default ifm_tensor, _, weight_tensor, ofm_tensor = ps.get_primary_op_ifm_ifm2_weights_ofm() + ifm_tensor_shape = numeric_util.full_shape(4, ifm_tensor.shape, 1) if npu_block_type in set( ( @@ -413,7 +456,6 @@ def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=[], f ): # extent the ifm to full dimension ifm_tensor_brick_size = tuple(numeric_util.full_shape(4, list(ifm_tensor.brick_size), 1)) - ifm_tensor_shape = numeric_util.full_shape(4, ifm_tensor.shape, 1) ifm_tensor_bandwidth_shape = numeric_util.full_shape(4, ifm_tensor.bandwidth_shape, 1) batch_size = ifm_tensor_shape[0] @@ -479,9 +521,9 @@ def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=[], f strides, ) - blocks = n_blocks * numeric_util.round_up_divide(weight_tensor_shape[3], block_config[3]) + blocks = n_blocks * numeric_util.round_up_divide(weight_tensor_shape[3], ofm_block.depth) - n_weight_stages = numeric_util.round_up_divide(weight_tensor_bandwidth_shape[3], block_config[3]) + n_weight_stages = numeric_util.round_up_divide(weight_tensor_bandwidth_shape[3], ofm_block.depth) if npu_block_type == NpuBlockType.ConvolutionDepthWise or npu_block_type == NpuBlockType.Pooling: n_weight_stages = 1 # force to no reread @@ -527,21 +569,14 @@ def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=[], f * block_size[0] * block_size[1] * numeric_util.round_up(weight_tensor_shape[2], n_input_channels_at_a_time) - * numeric_util.round_up(weight_tensor_shape[3], block_config[3]) + * numeric_util.round_up(weight_tensor_shape[3], ofm_block.depth) * n_kernel_xy ) macs[MacCount.NeuralNetworkMacs] += nn_ops macs[MacCount.HardwareMacs] += num_mac_ops cycles[PassCycles.Npu] = estimate_conv_pooling_cycles( - arch, - npu_block_type, - primary_op, - Block(block_config[1], block_config[0], block_config[3]), - block_traversal, - kernel_dims, - ifm_tensor, - ofm_tensor, + arch, npu_block_type, primary_op, ofm_block, block_traversal, kernel_dims, ifm_tensor, ofm_tensor, ) elif npu_block_type == NpuBlockType.VectorProduct: nn_macs = ( @@ -550,21 +585,15 @@ def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=[], f * numeric_util.round_up(weight_tensor.shape[-1], block_config[3]) ) num_mac_ops = nn_macs + block_traversal = weight_tensor.block_traversal cycles[PassCycles.Npu] = estimate_conv_pooling_cycles( - arch, - npu_block_type, - primary_op, - Block(block_config[1], block_config[0], block_config[3]), - weight_tensor.block_traversal, - [1, 1], - ifm_tensor, - ofm_tensor, + arch, npu_block_type, primary_op, ofm_block, block_traversal, [1, 1], ifm_tensor, ofm_tensor, ) macs[MacCount.NeuralNetworkMacs] += nn_macs macs[MacCount.HardwareMacs] += num_mac_ops - blocks = 1 * numeric_util.round_up_divide(weight_tensor.shape[-1], block_config[3]) + blocks = 1 * numeric_util.round_up_divide(weight_tensor.shape[-1], ofm_block.depth) non_zero_fraction = 1.0 if ifm_tensor.values is not None: @@ -581,6 +610,11 @@ def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=[], f arch, npu_block_type, primary_op, ofm_tensor.elements(), ps.ifm_tensor, ps.ofm_tensor, ps.ifm2_tensor ) + ifm_block_depth = get_ifm_block_depth( + npu_block_type, ifm_tensor_shape[3], ifm_tensor.dtype.size_in_bits(), block_traversal, ofm_block.depth + ) + ifm_block = arch.get_ifm_block_size(ifm_block_depth, ofm_block, primary_op.kernel) + prev_npu_pass = next((npu_ps for npu_ps in ps.dag_predecessors if npu_ps.placement is PassPlacement.Npu), None) if prev_npu_pass is None: # cycles for DMA ops in first pass @@ -597,14 +631,29 @@ def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=[], f if rewrite_op == SchedulerRewrite.Nop: pass # these are fine, no bandwidth changes elif rewrite_op in (SchedulerRewrite.ChangeTensorSubPurpose,): - bws[arch.fast_storage_mem_area][tens.purpose][BandwidthDirection.Read] += replacement_read_bws[tens] + if tens.purpose == TensorPurpose.FeatureMap: + bw = estimate_memory_bandwidth( + arch, + arch.fast_storage_mem_area, + BandwidthDirection.Read, + tens, + ifm_block, + replacement_read_bws[tens], + ) + else: + bw = replacement_read_bws[tens] + bws[arch.fast_storage_mem_area][tens.purpose][BandwidthDirection.Read] += bw replacement_read_bws[tens] = 0 for tens in ps.outputs: if force_outputs_to_fast_storage: - bws[arch.fast_storage_mem_area][tens.purpose][BandwidthDirection.Write] += tens.bandwidth() + bws[arch.fast_storage_mem_area][tens.purpose][BandwidthDirection.Write] += estimate_memory_bandwidth( + arch, arch.fast_storage_mem_area, BandwidthDirection.Write, tens, ofm_block + ) else: - bws[tens.mem_area][tens.purpose][BandwidthDirection.Write] += tens.bandwidth() + bws[tens.mem_area][tens.purpose][BandwidthDirection.Write] += estimate_memory_bandwidth( + arch, tens.mem_area, BandwidthDirection.Write, tens, ofm_block + ) for tens in ps.intermediates: bws[tens.mem_area][tens.purpose][BandwidthDirection.Write] += tens.bandwidth() @@ -617,23 +666,18 @@ def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=[], f bws[tens.mem_area][tens.purpose][BandwidthDirection.Read] += bw for tens in ps.inputs: - if tens in replacement_read_bws: - bw = replacement_read_bws[tens] - else: - bw = tens.bandwidth() - - bws[tens.mem_area][tens.purpose][BandwidthDirection.Read] += bw - - cycles[PassCycles.SramAccess] = np.sum(bws[MemArea.Sram]) / arch.memory_bandwidths_per_cycle[MemArea.Sram] - cycles[PassCycles.TotalPerPass] = np.max(cycles[: PassCycles.TotalPerPass]) + bws[tens.mem_area][tens.purpose][BandwidthDirection.Read] += estimate_memory_bandwidth( + arch, tens.mem_area, BandwidthDirection.Read, tens, ifm_block, replacement_read_bws.get(tens) + ) # quick build access counts for only current pass, even though these aren't the final numbers - update_summary_cycles(arch, bws, macs, cycles) + update_summary_cycles(arch, bws, cycles) return bws, macs, cycles, blocks, ifm_read_multiple, weight_read_multiple -def update_summary_cycles(arch, bws, macs, cycles): +def update_summary_cycles(arch, bws, cycles): + cycles[PassCycles.SramAccess] = np.sum(bws[MemArea.Sram]) / arch.memory_bandwidths_per_cycle[MemArea.Sram] cycles[PassCycles.DramAccess] = np.sum(bws[MemArea.Dram]) / arch.memory_bandwidths_per_cycle[MemArea.Dram] cycles[PassCycles.OnChipFlashAccess] = ( np.sum(bws[MemArea.OnChipFlash]) / arch.memory_bandwidths_per_cycle[MemArea.OnChipFlash] -- cgit v1.2.1