author    Diqing Zhong <diqing.zhong@arm.com>    2020-09-28 18:46:22 +0200
committer tim.hall <tim.hall@arm.com>            2020-11-11 11:14:53 +0000
commit    09387e207aa736c464cf95c8a57609aa21b65d44 (patch)
tree      d9aed24bb0537473b08611622f32401d24daa786
parent    897cc14968e017b1f48f376f7f7cefc515c5fe88 (diff)
download  ethos-u-vela-09387e207aa736c464cf95c8a57609aa21b65d44.tar.gz
MLBEDSW-3146: Cycle estimation for conv/pooling ops
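
Replace the flat MACs-per-cycle DPU estimate for convolution, depthwise
convolution, pooling, vector product and reduce-sum passes with a
per-OFM-block model that accounts for sub-kernel decomposition, block
traversal order, IFM block depth and the overlap between the DPU and the
output/activation stage.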
Signed-off-by: Diqing Zhong <diqing.zhong@arm.com>
Change-Id: Ic6ae795a1626d1cdf63a69d2ff86f7cd898f3134
-rw-r--r--  ethosu/vela/npu_performance.py           | 174
-rw-r--r--  ethosu/vela/shared_buffer_allocation.py  |  12
2 files changed, 150 insertions(+), 36 deletions(-)
diff --git a/ethosu/vela/npu_performance.py b/ethosu/vela/npu_performance.py
index 24b4c68..4d221be 100644
--- a/ethosu/vela/npu_performance.py
+++ b/ethosu/vela/npu_performance.py
@@ -24,13 +24,14 @@ import enum
import numpy as np
from . import numeric_util
+from .architecture_features import Accelerator
from .architecture_features import Block
-from .architecture_features import SHRAMElements
from .data_type import DataType
from .nn_graph import PassPlacement
from .nn_graph import SchedulerRewrite
from .operation import NpuBlockType
from .operation import Op
+from .shared_buffer_allocation import is_acc_40bits_used
from .tensor import MemArea
from .tensor import shape_num_elements
from .tensor import TensorBlockTraversal
@@ -212,22 +213,20 @@ def get_n_blocks_and_area(
return total_blocks, total_area, block_setup
-def get_output_cycle_estimate(arch, ps):
- primary_op = ps.primary_op
- assert primary_op
- npu_block_type = primary_op.type.npu_block_type
+def get_output_cycle_estimate(
+ arch, npu_block_type, primary_op, num_elems, ifm_tensor, ofm_tensor, ifm2_tensor, use_acc_40bits=False
+):
faf = primary_op.activation
-
- if npu_block_type == NpuBlockType.ElementWise and ps.ifm_tensor.dtype == DataType.int32:
- if ps.ifm2_tensor is None:
+ if npu_block_type == NpuBlockType.ElementWise and ifm_tensor.dtype == DataType.int32:
+ if ifm2_tensor is None:
# Unary op
output_perf_index = 0
else:
# Binary op
output_perf_index = 1
- elif ps.primary_op.type == Op.Mul and ps.ofm_tensor.dtype == DataType.int32:
+ elif primary_op.type == Op.Mul and ofm_tensor.dtype == DataType.int32:
output_perf_index = 2
- elif ps.primary_op.type == Op.Mul or (
+ elif primary_op.type == Op.Mul or (
npu_block_type
in (
NpuBlockType.ConvolutionMxN,
@@ -236,13 +235,13 @@ def get_output_cycle_estimate(arch, ps):
NpuBlockType.ReduceSum,
NpuBlockType.VectorProduct,
)
- and ps.shared_buffer.use_accumulator_element == SHRAMElements.Acc40
+ and use_acc_40bits
):
output_perf_index = 3
- elif ps.primary_op.type in (Op.Add, Op.Sub):
- input_scale = ps.ifm_tensor.quantization.scale_f32
- input2_scale = ps.ifm2_tensor.quantization.scale_f32
- output_scale = ps.ofm_tensor.quantization.scale_f32
+ elif primary_op.type in (Op.Add, Op.Sub):
+ input_scale = ifm_tensor.quantization.scale_f32
+ input2_scale = ifm2_tensor.quantization.scale_f32
+ output_scale = ofm_tensor.quantization.scale_f32
if "resizebilinear" in primary_op.attrs:
output_scale = input2_scale
@@ -253,7 +252,7 @@ def get_output_cycle_estimate(arch, ps):
else:
# Advanced Add/Sub
output_perf_index = 5
- elif ps.primary_op.type.is_maxpool_op():
+ elif primary_op.type.is_maxpool_op():
output_perf_index = 6
else:
output_perf_index = 7
@@ -265,13 +264,95 @@ def get_output_cycle_estimate(arch, ps):
else:
activation_perf_index = 2
- num_elems = ps.outputs[0].elements()
cycle_per_elem = max(
arch.output_cycles_per_elem[output_perf_index], arch.activation_cycles_per_elem[activation_perf_index]
)
return num_elems * cycle_per_elem
+def get_conv_pooling_cycle_estimate(
+ arch, npu_block_type, primary_op, block_config: Block, block_traversal, kernel_dims, ifm_tensor, ofm_tensor
+):
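+    # Number of OFM micro-blocks (ublocks) that tile one block config;
+    # the per-sub-kernel cycle formulas below scale linearly with it.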
+ num_ublk = (
+ (block_config.width // arch.config.ofm_ublock.width)
+ * (block_config.height // arch.config.ofm_ublock.height)
+ * (block_config.depth // arch.config.ofm_ublock.depth)
+ )
+ num_ofm_blk = 0
+ total_cycles = 0
+ num_elems_blk = block_config.width * block_config.height * block_config.depth
+ ifm_tens_shape = numeric_util.full_shape(4, ifm_tensor.shape, 1)
+ ofm_tens_shape = numeric_util.full_shape(4, ofm_tensor.shape, 1)
+ use_acc_40bits = is_acc_40bits_used(npu_block_type, ifm_tensor, ofm_tensor)
+
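+    # Decompose the kernel into sub-kernels bounded by the architecture's
+    # per-block-type sub-kernel limits (height, width); cycles are summed
+    # over all sub-kernels.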
+ sub_kernel_limits = arch.sub_kernel_limits[npu_block_type]
+ n_sub_kernels_y = numeric_util.round_up_divide(kernel_dims[0], sub_kernel_limits[0])
+ n_sub_kernels_x = numeric_util.round_up_divide(kernel_dims[1], sub_kernel_limits[1])
+ sub_kernel_x = [
+ min((kernel_dims[1] - i * sub_kernel_limits[1]), sub_kernel_limits[1]) for i in range(n_sub_kernels_x)
+ ]
+ sub_kernel_y = [
+ min((kernel_dims[0] - i * sub_kernel_limits[0]), sub_kernel_limits[0]) for i in range(n_sub_kernels_y)
+ ]
+ sub_kernel_size = (x * y for y in sub_kernel_y for x in sub_kernel_x)
+
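+    # IFM depth consumed per DPU pass, set by data type and traversal order
+    # (16-bit or part-kernel-first: 16 channels; 8-bit depth-first: 32;
+    # otherwise 8). Pooling leaves this at 0 as its formula does not use it.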
+ ifm_blk_depth = 0
+ if npu_block_type != NpuBlockType.Pooling:
+ if ifm_tensor.dtype.size_in_bits() == 16 or block_traversal == TensorBlockTraversal.PartKernelFirst:
+ ifm_blk_depth = 16
+ elif ifm_tensor.dtype.size_in_bits() == 8:
+ ifm_blk_depth = 32
+ else:
+ ifm_blk_depth = 8
+
+ cycles_dpu_blk = 0
+
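+    # Cycle cost of one OFM block: sum the block-type-specific DPU cost
+    # over every sub-kernel.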
+ for num_kernel_elems in sub_kernel_size:
+ if npu_block_type == NpuBlockType.Pooling:
+ cycles = max(4, num_kernel_elems) * num_ublk
+ if ifm_tensor.dtype.size_in_bits() == 16 and arch.accelerator_config != Accelerator.Ethos_U55_32:
+ cycles *= 2
+ elif npu_block_type == NpuBlockType.ConvolutionDepthWise:
+ cycles = 4 * numeric_util.round_up_divide(num_kernel_elems, 4) * num_ublk
+ if ifm_tensor.dtype.size_in_bits() == 16:
+ cycles *= 2
+ elif (
+ (npu_block_type == NpuBlockType.ConvolutionMxN and block_traversal != TensorBlockTraversal.PartKernelFirst)
+ or npu_block_type == NpuBlockType.VectorProduct
+ or npu_block_type == NpuBlockType.ReduceSum
+ ):
+ cycles = 4 * num_kernel_elems * num_ublk * numeric_util.round_up_divide(ifm_tens_shape[3], ifm_blk_depth)
+ else:
+ assert block_traversal == TensorBlockTraversal.PartKernelFirst
+ divider = 2 if ifm_tensor.dtype.size_in_bits() == 16 else 4
+ cycles = 4 * (
+ numeric_util.round_up_divide(num_kernel_elems, divider)
+ * numeric_util.round_up_divide(ifm_blk_depth, 8)
+ * num_ublk
+ * numeric_util.round_up_divide(ifm_tens_shape[3], ifm_blk_depth)
+ )
+ cycles_dpu_blk += cycles
+
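+    # Assume the block workload parallelizes evenly across the NPU cores.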
+ cycles_dpu_blk /= arch.ncores
+
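+    # Number of OFM blocks required to cover the whole output tensor.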
+ num_ofm_blk = (
+ numeric_util.round_up_divide(ofm_tens_shape[1], block_config.height)
+ * numeric_util.round_up_divide(ofm_tens_shape[2], block_config.width)
+ * numeric_util.round_up_divide(ofm_tens_shape[3], block_config.depth)
+ )
+
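+    # Cost of the output/activation stage for one OFM block.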
+ cycles_output_blk = get_output_cycle_estimate(
+ arch, npu_block_type, primary_op, num_elems_blk, ifm_tensor, ofm_tensor, None, use_acc_40bits
+ )
+
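+    # The DPU and output stages overlap: the slower stage is paid once per
+    # block, the faster one only once to fill/drain the pipeline.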
+ if cycles_dpu_blk > cycles_output_blk:
+ total_cycles = cycles_dpu_blk * num_ofm_blk + cycles_output_blk
+ else:
+ total_cycles = cycles_output_blk * num_ofm_blk + cycles_dpu_blk
+
+ return total_cycles
+
+
def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=[], force_outputs_to_fast_storage=False):
if block_config is None:
block_config = ps.block_config
@@ -302,7 +383,12 @@ def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=[], force_outputs_to_fast_storage=False):
ifm_tensor, _, weight_tensor, ofm_tensor = ps.get_primary_op_ifm_ifm2_weights_ofm()
if npu_block_type in set(
- (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise, NpuBlockType.Pooling)
+ (
+ NpuBlockType.ConvolutionMxN,
+ NpuBlockType.ConvolutionDepthWise,
+ NpuBlockType.Pooling,
+ NpuBlockType.ReduceSum,
+ )
):
# extend the ifm to full dimension
ifm_tensor_brick_size = tuple(numeric_util.full_shape(4, list(ifm_tensor.brick_size), 1))
@@ -316,12 +402,22 @@ def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=[], force_outputs_to_fast_storage=False):
ifm_tensor_shape[1] += explicit_padding[0] + explicit_padding[2] # height += top and bottom
ifm_tensor_shape[2] += explicit_padding[1] + explicit_padding[3] # width += left and right
+ block_traversal = TensorBlockTraversal.Default
+
strides = primary_op.attrs["strides"]
if npu_block_type != NpuBlockType.Pooling:
- weight_tensor_shape = weight_tensor.shape
- weight_tensor_bandwidth_shape = weight_tensor.bandwidth_shape
- weight_tensor_element_size = weight_tensor.element_size()
- weight_tensor_bandwidth_compression_scale = weight_tensor.bandwidth_compression_scale
+ if npu_block_type == NpuBlockType.ReduceSum:
+ block_traversal = TensorBlockTraversal.DepthFirst
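+            # ReduceSum has no real weight tensor: model it as a 1x1 kernel
+            # over the full depth, with zero weight bandwidth.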
+ weight_tensor_shape = [1, 1, ifm_tensor.shape[3], ofm_tensor.shape[3]]
+ weight_tensor_bandwidth_shape = [0] * 4
+ weight_tensor_element_size = 0
+ weight_tensor_bandwidth_compression_scale = 0.0
+ else:
+ block_traversal = weight_tensor.block_traversal
+ weight_tensor_shape = weight_tensor.shape
+ weight_tensor_bandwidth_shape = weight_tensor.bandwidth_shape
+ weight_tensor_element_size = weight_tensor.element_size()
+ weight_tensor_bandwidth_compression_scale = weight_tensor.bandwidth_compression_scale
nn_ops = (
int(ofm_tensor.shape[0])
* int(ofm_tensor.shape[1])
@@ -394,7 +490,7 @@ def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=[], force_outputs_to_fast_storage=False):
n_kernel_xy = kernel_dims[0] * kernel_dims[1]
n_input_channels_at_a_time = block_config[2]
- if npu_block_type == NpuBlockType.Pooling or weight_tensor.block_traversal in set(
+ if npu_block_type == NpuBlockType.Pooling or block_traversal in set(
(TensorBlockTraversal.PartKernelFirst, TensorBlockTraversal.DepthWise)
):
n_input_channels_at_a_time = numeric_util.round_up_divide(n_input_channels_at_a_time, 4)
@@ -416,14 +512,18 @@ def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=[], force_outputs_to_fast_storage=False):
* n_kernel_xy
)
- if npu_block_type == NpuBlockType.Pooling:
- # TODO: improve pooling estimation
- cycles[PassCycles.Dpu] = num_mac_ops / arch.num_macs_per_cycle / 2
- else:
- cycles[PassCycles.Dpu] = num_mac_ops / arch.num_macs_per_cycle
macs[MacCount.NeuralNetworkMacs] += nn_ops
macs[MacCount.HardwareMacs] += num_mac_ops
-
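+ # Estimate DPU cycles with the per-block conv/pooling model above.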
+ cycles[PassCycles.Dpu] = get_conv_pooling_cycle_estimate(
+ arch,
+ npu_block_type,
+ primary_op,
+ Block(block_config[1], block_config[0], block_config[3]),
+ block_traversal,
+ kernel_dims,
+ ifm_tensor,
+ ofm_tensor,
+ )
elif npu_block_type == NpuBlockType.VectorProduct:
nn_macs = (
ifm_tensor.shape[0]
@@ -432,7 +532,16 @@ def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=[], force_outputs_to_fast_storage=False):
)
num_mac_ops = nn_macs
- cycles[PassCycles.Dpu] = num_mac_ops / arch.num_macs_per_cycle
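+ # A vector product is modelled as a 1x1 kernel through the same estimator.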
+ cycles[PassCycles.Dpu] = get_conv_pooling_cycle_estimate(
+ arch,
+ npu_block_type,
+ primary_op,
+ Block(block_config[1], block_config[0], block_config[3]),
+ weight_tensor.block_traversal,
+ [1, 1],
+ ifm_tensor,
+ ofm_tensor,
+ )
macs[MacCount.NeuralNetworkMacs] += nn_macs
macs[MacCount.HardwareMacs] += num_mac_ops
@@ -449,8 +558,9 @@ def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=[], force_outputs_to_fast_storage=False):
weight_read_multiple = non_zero_fraction
elif npu_block_type == NpuBlockType.ElementWise:
# Work out how many elements we have and calculate performance.
- cycles[PassCycles.ElementWise] = get_output_cycle_estimate(arch, ps)
-
+ cycles[PassCycles.ElementWise] = get_output_cycle_estimate(
+ arch, npu_block_type, primary_op, ofm_tensor.elements(), ps.ifm_tensor, ps.ofm_tensor, ps.ifm2_tensor
+ )
# apply the desired rewrites
for rewrite_op, tens, _, _, _, ps_to_rewrite in rewrite_list:
if ps != ps_to_rewrite:
diff --git a/ethosu/vela/shared_buffer_allocation.py b/ethosu/vela/shared_buffer_allocation.py
index 484c34b..51fb168 100644
--- a/ethosu/vela/shared_buffer_allocation.py
+++ b/ethosu/vela/shared_buffer_allocation.py
@@ -37,9 +37,6 @@ class SharedBufferAllocation:
self.banks_required = np.zeros(SharedBufferArea.Size)
ifm_tensor, ifm2_tensor, weight_tensor, ofm_tensor = ps.get_primary_op_ifm_ifm2_weights_ofm()
- tensors = [t for t in (ifm_tensor, ifm2_tensor, ofm_tensor) if t is not None]
- scales = [t.quantization.scale_f32 for t in tensors if t.quantization is not None]
- has_scale = len(tensors) == len(scales) and None not in scales
self.kernel = Kernel(1, 1)
self.is_elementwise = ps.npu_block_type == NpuBlockType.ElementWise
@@ -81,7 +78,7 @@ class SharedBufferAllocation:
self.ifm_count = 1
if self.ifm_bits == 16:
- if ps.npu_block_type != NpuBlockType.Pooling and has_scale:
+ if is_acc_40bits_used(ps.npu_block_type, ifm_tensor, ofm_tensor, ifm2_tensor):
self.use_accumulator_element = SHRAMElements.Acc40
self.use_ifm_element = self.use_ifm_element + 1
assert (self.use_ifm_element == SHRAMElements.IFM16) or (
@@ -171,6 +168,13 @@ class SharedBufferAllocation:
)
+def is_acc_40bits_used(npu_block_type, ifm_tensor, ofm_tensor, ifm2_tensor=None):
+ tensors = [t for t in (ifm_tensor, ifm2_tensor, ofm_tensor) if t is not None]
+ scales = [t.quantization.scale_f32 for t in tensors if t.quantization is not None]
+ has_scale = len(tensors) == len(scales) and None not in scales
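+    # 40-bit accumulators are used for all non-pooling block types, provided
+    # every connected feature map tensor carries a quantization scale.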
+ return npu_block_type != NpuBlockType.Pooling and has_scale
+
+
def shared_buffer_allocation_for_pass_and_block_config(arch, ps, block_config):
alloc = SharedBufferAllocation(arch, ps)
assert (alloc.ifm_block_depth == block_config[2]) or alloc.is_equal_depth_op