aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDiqing Zhong <diqing.zhong@arm.com>2020-12-08 13:08:48 +0100
committerDiqing Zhong <diqing.zhong@arm.com>2020-12-09 15:20:38 +0100
commit69aadd052588eb53a257e8f7431ed858161b3286 (patch)
tree4489ef3ad631a71a057b36fefa83afd83a93a8e7
parent2fa40ae3e0e17c899e847c8ad8decd1ec0d9bfcd (diff)
downloadethos-u-vela-69aadd052588eb53a257e8f7431ed858161b3286.tar.gz
Vela: bandwidth calculation improvements
- Combine conv and vector_product calculation - Remove internal bandwidth - Remove blocks and hw_macs from report - Use scaled_bws for cycle estimation Related to: MLBEDSW-3598 Change-Id: I1927a8311ec563f68115e0f2ed077806b86fd717 Signed-off-by: Diqing Zhong <diqing.zhong@arm.com>
-rw-r--r--ethosu/vela/npu_performance.py248
-rw-r--r--ethosu/vela/scheduler.py5
-rw-r--r--ethosu/vela/stats_writer.py92
3 files changed, 89 insertions, 256 deletions
diff --git a/ethosu/vela/npu_performance.py b/ethosu/vela/npu_performance.py
index 2d7a1b09..8ada1e23 100644
--- a/ethosu/vela/npu_performance.py
+++ b/ethosu/vela/npu_performance.py
@@ -90,22 +90,6 @@ class PassCycles(IntEnum):
)
-class MacCount(IntEnum):
- NeuralNetworkMacs = 0
- HardwareMacs = auto()
- Size = auto()
-
- def display_name(self):
- return ("Neural Network Macs", "Hardware Macs", "Size")[self.value]
-
- def identifier_name(self):
- return ("nn_macs", "hardware_macs", "size")[self.value]
-
- @staticmethod
- def all():
- return (MacCount.NeuralNetworkMacs, MacCount.HardwareMacs)
-
-
class BandwidthDirection(IntEnum):
Read = 0
Write = auto()
@@ -126,77 +110,18 @@ def make_bandwidth_array():
return np.zeros((MemArea.Size, TensorPurpose.Size, BandwidthDirection.Size))
-def make_macs_array():
- return np.zeros(MacCount.Size, np.int)
-
-
def make_cycles_array():
return np.zeros(PassCycles.Size)
def make_metrics_arrays():
- return (make_bandwidth_array(), make_macs_array(), make_cycles_array())
-
-
-def get_n_blocks_and_area(
- ifm_brick_size, ifm_height_width, orig_skirt, clamped_skirt, block_config, min_block_size, strides
-):
-
- ifm_block_config = (block_config[0] * strides[1], block_config[1] * strides[2])
-
- n_normal_blocks = []
- remainder_size = []
- for i in range(2):
- non_skirt_dim = ifm_height_width[i] - orig_skirt[i] - orig_skirt[2 + i]
- n_blocks = non_skirt_dim // ifm_block_config[i]
- n_normal_blocks.append(n_blocks)
- remainder_dim = numeric_util.round_up(
- ((non_skirt_dim - n_blocks * ifm_block_config[i] - 1) // strides[i + 1]) + 1, min_block_size[i]
- )
- remainder_size.append(remainder_dim)
-
- # this will actually calculate reads into the edge padding.
-
- # there are four cases in total, handling the edges that will not fill a complete block.
-
- # 0000000001
- # 0000000001
- # 0000000001
- # 0000000001
- # 0000000001
- # 0000000001
- # 2222222223
- total_blocks = 0
- total_area = 0
-
- block_setup = (
- (n_normal_blocks[0] * n_normal_blocks[1], block_config),
- (1 * n_normal_blocks[1], (remainder_size[0], block_config[1])),
- (n_normal_blocks[0] * 1, (block_config[0], remainder_size[1])),
- (1 * 1, remainder_size),
- )
-
- for n_blocks, block_size in block_setup:
- if block_size[0] == 0 or block_size[1] == 0:
- continue
- read_dims = [0, 0]
- for i in range(2):
- read_dims[i] = (
- numeric_util.round_up(clamped_skirt[i], ifm_brick_size[i + 1])
- + block_size[i] * strides[i + 1]
- + numeric_util.round_up(clamped_skirt[2 + i], ifm_brick_size[i + 1])
- )
- assert n_blocks >= 0
- total_blocks += n_blocks
- total_area += n_blocks * read_dims[0] * read_dims[1]
- assert total_blocks >= 1
- return total_blocks, total_area, block_setup
+ return (make_bandwidth_array(), 0, make_cycles_array())
def get_ifm_block_depth(npu_block_type, ifm_depth, ifm_elemwidth, block_traversal, ofm_blk_depth):
ifm_blk_depth = ofm_blk_depth
- if npu_block_type == NpuBlockType.ConvolutionMxN or npu_block_type == NpuBlockType.ReduceSum:
+ if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct, NpuBlockType.ReduceSum):
if ifm_elemwidth == 16 or block_traversal == TensorBlockTraversal.PartKernelFirst:
ifm_blk_depth = 16
elif ifm_elemwidth == 8:
@@ -213,11 +138,11 @@ def get_minimal_cmd_cycles(arch, ifm_tensor, ofm_tensor, ifm_blk: Block, ofm_blk
ifm_tens_blk = Tensor((1, ifm_blk.height, ifm_blk.width, ifm_blk.depth), ifm_tensor.dtype, "ifm_blk")
ofm_tens_blk = Tensor((1, ofm_blk.height, ofm_blk.width, ofm_blk.depth), ofm_tensor.dtype, "ofm_blk")
cycles_ifm_blk = (
- estimate_memory_bandwidth(arch, ifm_tensor.mem_area, BandwidthDirection.Read, ifm_tens_blk, ifm_blk)
+ estimate_memory_transfer_efficiency(arch, ifm_tensor.mem_area, BandwidthDirection.Read, ifm_tens_blk, ifm_blk)
/ arch.memory_bandwidths_per_cycle[ifm_tensor.mem_area]
)
cycles_ofm_blk = (
- estimate_memory_bandwidth(arch, ofm_tensor.mem_area, BandwidthDirection.Write, ofm_tens_blk, ofm_blk)
+ estimate_memory_transfer_efficiency(arch, ofm_tensor.mem_area, BandwidthDirection.Write, ofm_tens_blk, ofm_blk)
/ arch.memory_bandwidths_per_cycle[ofm_tensor.mem_area]
)
return (
@@ -449,7 +374,7 @@ def estimate_conv_pooling_cycles(
return total_cycles
-def estimate_memory_bandwidth(arch, mem_area, direction, tensor, block_size: Block, replace_bw=None):
+def estimate_memory_transfer_efficiency(arch, mem_area, direction, tensor, block_size: Block, replace_bw=None):
if tensor.format not in (TensorFormat.NHWC, TensorFormat.NHCWB16):
return tensor.bandwidth() if replace_bw is None else replace_bw
@@ -493,18 +418,15 @@ def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=None,
if block_config is None:
block_config = ps.block_config
bws = make_bandwidth_array()
- macs = make_macs_array()
+ scaled_bws = make_bandwidth_array() # scaled bw with memory transfer efficiency
+ macs = 0
cycles = make_cycles_array()
- blocks = 0
ifm_read_multiple = 1
weight_read_multiple = 0
if ps.placement in (PassPlacement.MemoryOnly, PassPlacement.StartupInit):
- return bws, macs, cycles, blocks, ifm_read_multiple, weight_read_multiple # nothing real happening in this pass
-
- min_block_size = arch.min_block_sizes[ps.npu_block_type]
+ return bws, macs, cycles, ifm_read_multiple, weight_read_multiple # nothing real happening in this pass
- skirt = (0, 0, 0, 0)
explicit_padding = (0, 0, 0, 0)
primary_op = ps.primary_op
replacement_read_bws = {}
@@ -512,13 +434,13 @@ def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=None,
ifm_block = Block(block_config[1], block_config[0], block_config[3])
if ps.placement == PassPlacement.Npu and primary_op:
- skirt = primary_op.attrs.get("skirt", skirt)
explicit_padding = primary_op.attrs.get("explicit_padding", explicit_padding)
assert primary_op.type.npu_block_type == ps.npu_block_type
npu_block_type = primary_op.type.npu_block_type
ifm_tensor, _, weight_tensor, ofm_tensor = ps.get_primary_op_ifm_ifm2_weights_ofm()
ifm_tensor_shape = numeric_util.full_shape(4, ifm_tensor.shape, 1)
+ ofm_tensor_shape = numeric_util.full_shape(4, ofm_tensor.shape, 1)
if npu_block_type == NpuBlockType.ReduceSum:
block_traversal = TensorBlockTraversal.DepthFirst
@@ -540,21 +462,17 @@ def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=None,
if npu_block_type in (
NpuBlockType.ConvolutionMxN,
NpuBlockType.ConvolutionDepthWise,
+ NpuBlockType.VectorProduct,
NpuBlockType.Pooling,
NpuBlockType.ReduceSum,
):
# extent the ifm to full dimension
- ifm_tensor_brick_size = tuple(numeric_util.full_shape(4, list(ifm_tensor.brick_size), 1))
- ifm_tensor_bandwidth_shape = numeric_util.full_shape(4, ifm_tensor.bandwidth_shape, 1)
-
batch_size = ifm_tensor_shape[0]
- ifm_depth = ifm_tensor_bandwidth_shape[3]
# add in padding
ifm_tensor_shape[1] += explicit_padding[0] + explicit_padding[2] # height += top and bottom
ifm_tensor_shape[2] += explicit_padding[1] + explicit_padding[3] # width += left and right
- strides = primary_op.attrs["strides"]
if npu_block_type != NpuBlockType.Pooling:
if npu_block_type == NpuBlockType.ReduceSum:
weight_tensor_shape = [1, 1, ifm_tensor.shape[3], ofm_tensor.shape[3]]
@@ -562,14 +480,16 @@ def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=None,
weight_tensor_element_size = 0
weight_tensor_bandwidth_compression_scale = 0.0
else:
- weight_tensor_shape = weight_tensor.shape
- weight_tensor_bandwidth_shape = weight_tensor.bandwidth_shape
+ # For Vector product, weight format of IO is extended to HWIO, with H=W=1
+ weight_tensor_shape = numeric_util.full_shape(4, weight_tensor.shape, 1)
+ weight_tensor_bandwidth_shape = numeric_util.full_shape(4, weight_tensor.bandwidth_shape, 1)
weight_tensor_element_size = weight_tensor.element_size()
weight_tensor_bandwidth_compression_scale = weight_tensor.bandwidth_compression_scale
+
nn_ops = (
- int(ofm_tensor.shape[0])
- * int(ofm_tensor.shape[1])
- * int(ofm_tensor.shape[2])
+ int(ofm_tensor_shape[0])
+ * int(ofm_tensor_shape[1])
+ * int(ofm_tensor_shape[2])
* int(weight_tensor_shape[0])
* int(weight_tensor_shape[1])
* int(weight_tensor_shape[2])
@@ -595,72 +515,25 @@ def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=None,
n_sub_kernels_x = numeric_util.round_up_divide(kernel_dims[1], sub_kernel_limits[1])
n_sub_kernels = n_sub_kernels_y * n_sub_kernels_x
- clamped_skirt = list(skirt)
- clamped_skirt[2] = min(clamped_skirt[2], sub_kernel_limits[0] - 1 - clamped_skirt[0])
- clamped_skirt[3] = min(clamped_skirt[3], sub_kernel_limits[1] - 1 - clamped_skirt[1])
- n_blocks, area, block_setup = get_n_blocks_and_area(
- ifm_tensor_brick_size,
- ifm_tensor_shape[1:3],
- skirt,
- clamped_skirt,
- block_config,
- min_block_size,
- strides,
- )
-
- blocks = n_blocks * numeric_util.round_up_divide(weight_tensor_shape[3], ofm_block.depth)
-
- n_weight_stages = numeric_util.round_up_divide(weight_tensor_bandwidth_shape[3], ofm_block.depth)
- if npu_block_type == NpuBlockType.ConvolutionDepthWise or npu_block_type == NpuBlockType.Pooling:
- n_weight_stages = 1 # force to no reread
+ n_full_depth_stages = numeric_util.round_up_divide(weight_tensor_bandwidth_shape[3], ofm_block.depth)
+ if npu_block_type in (NpuBlockType.ConvolutionDepthWise, NpuBlockType.Pooling):
+ n_full_depth_stages = 1 # force to no reread
- ifm_tensor_bw = (
- n_sub_kernels
- * batch_size
- * area
- * ifm_depth
- * n_weight_stages
- * ifm_tensor.element_size()
- * ifm_tensor.bandwidth_compression_scale
- )
- replacement_read_bws[ifm_tensor] = ifm_tensor_bw
- ifm_read_multiple = n_weight_stages
+ ifm_read_multiple = n_sub_kernels * n_full_depth_stages
+ replacement_read_bws[ifm_tensor] = ifm_tensor.bandwidth() * ifm_read_multiple
+ weight_read_multiple = numeric_util.round_up_divide(
+ ofm_tensor_shape[1], ofm_block.height
+ ) * numeric_util.round_up_divide(ofm_tensor_shape[2], ofm_block.width)
replacement_read_bws[weight_tensor] = (
batch_size
* shape_num_elements(weight_tensor_bandwidth_shape)
* weight_tensor_element_size
* weight_tensor_bandwidth_compression_scale
- * n_blocks
- ) # read once per block and batch
- weight_read_multiple = n_blocks
-
- n_kernel_xy = kernel_dims[0] * kernel_dims[1]
- n_input_channels_at_a_time = block_config[2]
-
- if (npu_block_type == NpuBlockType.Pooling) or (
- block_traversal in (TensorBlockTraversal.PartKernelFirst, TensorBlockTraversal.DepthWise)
- ):
- n_input_channels_at_a_time = numeric_util.round_up_divide(n_input_channels_at_a_time, 4)
- n_kernel_xy = max(
- n_kernel_xy, 4
- ) # need at least 4, as this is the minimum duty cycle for secondary accumulator writes
- if weight_tensor is not None:
- n_kernel_xy = numeric_util.round_up(n_kernel_xy, 4) # weights need to be read in blocks of 4
-
- num_mac_ops = 0
- for n_blocks_for_size, block_size in block_setup:
- num_mac_ops += (
- batch_size
- * n_blocks_for_size
- * block_size[0]
- * block_size[1]
- * numeric_util.round_up(weight_tensor_shape[2], n_input_channels_at_a_time)
- * numeric_util.round_up(weight_tensor_shape[3], ofm_block.depth)
- * n_kernel_xy
- )
- macs[MacCount.NeuralNetworkMacs] += nn_ops
- macs[MacCount.HardwareMacs] += num_mac_ops
+ * weight_read_multiple
+ )
+
+ macs += nn_ops
cycles[PassCycles.Npu] = estimate_conv_pooling_cycles(
arch,
npu_block_type,
@@ -673,31 +546,6 @@ def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=None,
ofm_tensor,
ps.scale_tensor,
)
- elif npu_block_type == NpuBlockType.VectorProduct:
- nn_macs = (
- ifm_tensor.shape[0]
- * numeric_util.round_up(weight_tensor.shape[-2], block_config[2])
- * numeric_util.round_up(weight_tensor.shape[-1], block_config[3])
- )
- num_mac_ops = nn_macs
-
- cycles[PassCycles.Npu] = estimate_conv_pooling_cycles(
- arch, npu_block_type, primary_op, ifm_block, ofm_block, block_traversal, [1, 1], ifm_tensor, ofm_tensor,
- )
- macs[MacCount.NeuralNetworkMacs] += nn_macs
- macs[MacCount.HardwareMacs] += num_mac_ops
-
- blocks = 1 * numeric_util.round_up_divide(weight_tensor.shape[-1], ofm_block.depth)
-
- non_zero_fraction = 1.0
- if ifm_tensor.values is not None:
- nz_vector = np.amax(ifm_tensor.values != 0, axis=0) # max across batch axis
- non_zero_fraction = np.average(nz_vector)
-
- replacement_read_bws[ifm_tensor] = ifm_tensor.bandwidth()
- replacement_read_bws[weight_tensor] = weight_tensor.bandwidth() * non_zero_fraction
- ifm_read_multiple = 1
- weight_read_multiple = non_zero_fraction
elif npu_block_type == NpuBlockType.ElementWise:
# Work out how many elements we have and calculate performance.
cycles[PassCycles.Npu] = estimate_output_cycles(
@@ -729,8 +577,9 @@ def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=None,
if rewrite_op == SchedulerRewrite.Nop:
pass # these are fine, no bandwidth changes
elif rewrite_op in (SchedulerRewrite.ChangeTensorSubPurpose,):
+ bws[arch.fast_storage_mem_area][tens.purpose][BandwidthDirection.Read] += replacement_read_bws[tens]
if tens.purpose == TensorPurpose.FeatureMap:
- bw = estimate_memory_bandwidth(
+ scaled_bw = estimate_memory_transfer_efficiency(
arch,
arch.fast_storage_mem_area,
BandwidthDirection.Read,
@@ -739,22 +588,27 @@ def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=None,
replacement_read_bws[tens],
)
else:
- bw = replacement_read_bws[tens]
- bws[arch.fast_storage_mem_area][tens.purpose][BandwidthDirection.Read] += bw
+ scaled_bw = replacement_read_bws[tens]
+ scaled_bws[arch.fast_storage_mem_area][tens.purpose][BandwidthDirection.Read] += scaled_bw
replacement_read_bws[tens] = 0
for tens in ps.outputs:
if force_outputs_to_fast_storage:
- bws[arch.fast_storage_mem_area][tens.purpose][BandwidthDirection.Write] += estimate_memory_bandwidth(
+ bws[arch.fast_storage_mem_area][tens.purpose][BandwidthDirection.Write] += tens.bandwidth()
+ scaled_bws[arch.fast_storage_mem_area][tens.purpose][
+ BandwidthDirection.Write
+ ] += estimate_memory_transfer_efficiency(
arch, arch.fast_storage_mem_area, BandwidthDirection.Write, tens, ofm_block
)
else:
- bws[tens.mem_area][tens.purpose][BandwidthDirection.Write] += estimate_memory_bandwidth(
+ bws[tens.mem_area][tens.purpose][BandwidthDirection.Write] += tens.bandwidth()
+ scaled_bws[tens.mem_area][tens.purpose][BandwidthDirection.Write] += estimate_memory_transfer_efficiency(
arch, tens.mem_area, BandwidthDirection.Write, tens, ofm_block
)
for tens in ps.intermediates:
bws[tens.mem_area][tens.purpose][BandwidthDirection.Write] += tens.bandwidth()
+ scaled_bws[tens.mem_area][tens.purpose][BandwidthDirection.Write] += tens.bandwidth()
if tens in replacement_read_bws:
bw = replacement_read_bws[tens]
@@ -762,16 +616,23 @@ def performance_metrics_for_pass(arch, ps, block_config=None, rewrite_list=None,
bw = tens.bandwidth()
bws[tens.mem_area][tens.purpose][BandwidthDirection.Read] += bw
+ scaled_bws[tens.mem_area][tens.purpose][BandwidthDirection.Read] += bw
for tens in ps.inputs:
- bws[tens.mem_area][tens.purpose][BandwidthDirection.Read] += estimate_memory_bandwidth(
- arch, tens.mem_area, BandwidthDirection.Read, tens, ifm_block, replacement_read_bws.get(tens)
+ if tens in replacement_read_bws:
+ bw = replacement_read_bws[tens]
+ else:
+ bw = tens.bandwidth()
+
+ bws[tens.mem_area][tens.purpose][BandwidthDirection.Read] += bw
+ scaled_bws[tens.mem_area][tens.purpose][BandwidthDirection.Read] += estimate_memory_transfer_efficiency(
+ arch, tens.mem_area, BandwidthDirection.Read, tens, ifm_block, bw
)
# quick build access counts for only current pass, even though these aren't the final numbers
- update_summary_cycles(arch, bws, cycles)
+ update_summary_cycles(arch, scaled_bws, cycles)
- return bws, macs, cycles, blocks, ifm_read_multiple, weight_read_multiple
+ return bws, macs, cycles, ifm_read_multiple, weight_read_multiple
def update_summary_cycles(arch, bws, cycles):
@@ -794,15 +655,14 @@ def collate_stats_for_cascaded_pass(arch, bws, macs, cycles):
def performance_for_cascaded_pass(arch, cps):
total_bws = make_bandwidth_array()
- total_macs = make_macs_array()
+ total_macs = 0
total_cycles = make_cycles_array()
for ps in cps.passes:
- bws, macs, cycles, blocks, _, _ = performance_metrics_for_pass(arch, ps)
+ bws, macs, cycles, _, _ = performance_metrics_for_pass(arch, ps)
ps.bandwidths = bws
ps.macs = macs
ps.cycles = cycles
- ps.n_blocks = blocks
total_bws += bws
total_macs += macs
total_cycles += cycles
@@ -816,7 +676,7 @@ def performance_for_cascaded_pass(arch, cps):
def calc_performance_for_network(nng, arch):
total_bws = make_bandwidth_array()
- total_macs = np.zeros(MacCount.Size)
+ total_macs = 0
total_cycles = np.zeros(PassCycles.Size)
for sg in nng.subgraphs:
diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index 977eb58e..2c10640b 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -32,7 +32,6 @@ from .nn_graph import SchedulerRewrite
from .nn_graph import SchedulingStrategy
from .npu_performance import make_bandwidth_array
from .npu_performance import make_cycles_array
-from .npu_performance import make_macs_array
from .npu_performance import make_metrics_arrays
from .npu_performance import PassCycles
from .numeric_util import full_shape
@@ -108,7 +107,7 @@ class Strategy:
return False
if (self.bws != other.bws).any():
return False
- if (self.macs != other.macs).any():
+ if self.macs != other.macs:
return False
if (self.cycles != other.cycles).any():
return False
@@ -211,7 +210,7 @@ class StrategySet:
empty_strategy = Strategy(
- SchedulingStrategy.Unknown, None, [], [], [], make_bandwidth_array(), make_macs_array(), make_cycles_array(), 0
+ SchedulingStrategy.Unknown, None, [], [], [], make_bandwidth_array(), 0, make_cycles_array(), 0
)
INFINITY = 1e30
diff --git a/ethosu/vela/stats_writer.py b/ethosu/vela/stats_writer.py
index 494b25e7..02d95d81 100644
--- a/ethosu/vela/stats_writer.py
+++ b/ethosu/vela/stats_writer.py
@@ -22,7 +22,6 @@ import numpy as np
from .nn_graph import PassPlacement
from .npu_performance import BandwidthDirection
-from .npu_performance import MacCount
from .npu_performance import PassCycles
from .numeric_util import round_up_to_int
from .operation import Op
@@ -70,7 +69,7 @@ def write_summary_metrics_csv(nng, summary_filename, arch):
mem_area.identifier_name() + "_total_bytes",
]
- labels += ["nn_macs", "hardware_macs", "nn_tops", "hardware_tops"]
+ labels += ["nn_macs", "nn_tops"]
labels += ["cycles_" + kind.identifier_name() for kind in PassCycles.all()]
@@ -128,10 +127,8 @@ def write_summary_metrics_csv(nng, summary_filename, arch):
]
data_items += [
- nng.macs[MacCount.NeuralNetworkMacs],
- nng.macs[MacCount.HardwareMacs],
- nng.macs[MacCount.NeuralNetworkMacs] * 2 * midpoint_fps / 1e12,
- nng.macs[MacCount.HardwareMacs] * 2 * midpoint_fps / 1e12,
+ nng.macs,
+ nng.macs * 2 * midpoint_fps / 1e12,
]
data_items += [nng.cycles[kind] for kind in PassCycles.all()]
@@ -164,7 +161,6 @@ def write_pass_metrics_csv(nng, pass_filename):
bandwidth_names.append(label)
bandwidth_indices.append((mem_area, purpose_candidates, direction_candidates))
- all_macs = MacCount.all()
all_cycles = (
PassCycles.Total,
PassCycles.Npu,
@@ -183,10 +179,9 @@ def write_pass_metrics_csv(nng, pass_filename):
"block_config_width",
"block_config_input_channels",
"block_config_output_channels",
- "n_blocks_in_pass",
]
+ ["cycles_" + v.identifier_name() for v in all_cycles]
- + [v.identifier_name() for v in all_macs]
+ + ["nn_macs"]
+ bandwidth_names
+ ["sram_used"]
)
@@ -205,9 +200,8 @@ def write_pass_metrics_csv(nng, pass_filename):
stats += [ps.placement.name]
stats += [cps.strategy.name]
stats += list(ps.block_config)
- stats += [ps.n_blocks]
stats += [round_up_to_int(ps.cycles[v]) for v in all_cycles]
- stats += [round_up_to_int(ps.macs[v]) for v in all_macs]
+ stats += [round_up_to_int(ps.macs)]
for indices in bandwidth_indices:
res = 0
i = indices[0]
@@ -256,17 +250,16 @@ def print_performance_metrics_for_strat(
if name:
print("", file=f)
- print("Network summary for", name, file=f)
- print("Accelerator configuration {:>20}".format(arch.accelerator_config.name), file=f)
- print("System configuration {:>20}".format(arch.system_config), file=f)
- print("Memory mode {:>20}".format(arch.memory_mode), file=f)
- print("Accelerator clock {:12d} MHz".format(int(arch.core_clock / 1e6)), file=f)
+ print(f"Network summary for {name}", file=f)
+ print(f"Accelerator configuration {arch.accelerator_config.name:>20}", file=f)
+ print(f"System configuration {arch.system_config:>20}", file=f)
+ print(f"Memory mode {arch.memory_mode:>20}", file=f)
+ print(f"Accelerator clock {int(arch.core_clock / 1e6):12d} MHz", file=f)
for mem_area, label in mem_area_labels:
+ label += " bandwidth"
+ bandwidth = arch.memory_bandwidths_per_second[mem_area] / 1000.0 / 1000 / 1000
print(
- "Design peak {:25} {:12.2f} GB/s".format(
- label + " bandwidth", arch.memory_bandwidths_per_second[mem_area] / 1000.0 / 1000 / 1000
- ),
- file=f,
+ f"Design peak {label:25} {bandwidth:12.2f} GB/s", file=f,
)
print(file=f)
for mem_area, label in mem_area_labels:
@@ -277,12 +270,12 @@ def print_performance_metrics_for_strat(
extra = ""
if (mem_area == MemArea.OnChipFlash or mem_area == MemArea.OffChipFlash) and bits_per_element is not None:
- extra = " ({:.2f} bits per element)".format(bits_per_element[mem_area])
+ extra = f" ({bits_per_element[mem_area]:.2f} bits per element)"
- print("Total {:25} {:12.2f} KiB{}".format(aug_label, memory_used[mem_area] / 1024.0, extra), file=f)
+ print(f"Total {aug_label:25} {memory_used[mem_area] / 1024.0:12.2f} KiB{extra}", file=f)
print(file=f)
- print("{:d} passes fused into {:d}".format(num_passes, num_cascaded_passes), file=f)
+ print(f"{num_passes:d} passes fused into {num_cascaded_passes:d}", file=f)
if cpu_operations is None:
cpu_operations = []
@@ -290,9 +283,8 @@ def print_performance_metrics_for_strat(
n_cpu_operations = len(cpu_operations)
if n_operations > 0:
print(
- "{:d}/{:d} ({:4.1%}) operations falling back to the CPU".format(
- n_cpu_operations, n_operations, n_cpu_operations / n_operations * 100
- ),
+ f"{n_cpu_operations:d}/{n_operations:d}"
+ f" ({n_cpu_operations / n_operations * 100:4.1%}) operations falling back to the CPU",
file=f,
)
@@ -303,9 +295,8 @@ def print_performance_metrics_for_strat(
return " ".join(str(list(tens.shape)) for tens in lst)
print(
- "CPU operation: {} inputs {}, outputs {}".format(
- op.type, format_tens_list(op.inputs), format_tens_list(op.outputs)
- ),
+ f"CPU operation: {op.type}"
+ f" inputs {format_tens_list(op.inputs)}, outputs {format_tens_list(op.outputs)}",
file=f,
)
@@ -318,60 +309,43 @@ def print_performance_metrics_for_strat(
fm_bws = bws[TensorPurpose.FeatureMap]
aug_label = label + " bandwidth"
print(
- "Average {:25} {:12.2f} GB/s".format(aug_label, total_bw * midpoint_fps / 1000.0 / 1000.0 / 1000.0),
- file=f,
+ f"Average {aug_label:25} {total_bw * midpoint_fps / 1000.0 / 1000.0 / 1000.0:12.2f} GB/s", file=f,
)
print(
- "Input {:25} {:12.2f} MB/batch".format(
- aug_label, np.sum(fm_bws[BandwidthDirection.Read]) / 1000.0 / 1000.0
- ),
+ f"Input {aug_label:25} {np.sum(fm_bws[BandwidthDirection.Read]) / 1000.0 / 1000.0:12.2f} MB/batch",
file=f,
)
- print("Weight {:25} {:12.2f} MB/batch".format(aug_label, np.sum(weight_bws) / 1000.0 / 1000.0), file=f)
+ print(f"Weight {aug_label:25} {np.sum(weight_bws) / 1000.0 / 1000.0:12.2f} MB/batch", file=f)
print(
- "Output {:25} {:12.2f} MB/batch".format(
- aug_label, np.sum(fm_bws[BandwidthDirection.Write]) / 1000.0 / 1000.0
- ),
+ f"Output {aug_label:25} "
+ f"{np.sum(fm_bws[BandwidthDirection.Write]) / 1000.0 / 1000.0:12.2f} MB/batch",
file=f,
)
- print("Total {:25} {:12.2f} MB/batch".format(aug_label, total_bw / 1000.0 / 1000.0), file=f)
+ print(f"Total {aug_label:25} {total_bw / 1000.0 / 1000.0:12.2f} MB/batch", file=f)
print(
- "Total {:25} per input {:9.2f} MB/inference (batch size {:d})".format(
- aug_label, total_bw / 1000.0 / 1000.0 / batch_size, batch_size
- ),
+ f"Total {aug_label:25} per input "
+ f"{total_bw / 1000.0 / 1000.0 / batch_size:9.2f} MB/inference (batch size {batch_size:d})",
file=f,
)
print(file=f)
print(
- "Neural network macs {:12d} MACs/batch".format(int(macs[MacCount.NeuralNetworkMacs])),
- file=f,
+ f"Neural network macs {int(macs):12d} MACs/batch", file=f,
)
- print("Hardware macs {:12d} MACs/batch".format(int(macs[MacCount.HardwareMacs])), file=f)
print(
- "Network Tops/s {:12.2f} Tops/s".format(
- macs[MacCount.NeuralNetworkMacs] * 2 * midpoint_fps / 1e12
- ),
- file=f,
- )
- print(
- "Hardware Tops/s {:12.2f} Tops/s".format(
- macs[MacCount.HardwareMacs] * 2 * midpoint_fps / 1e12
- ),
- file=f,
+ f"Network Tops/s {macs * 2 * midpoint_fps / 1e12:12.2f} Tops/s", file=f,
)
print(file=f)
for kind in PassCycles.all():
aug_label = kind.display_name() + " cycles"
cyc = cycles[kind]
- print("{:30} {:12d} cycles/batch".format(aug_label, int(cyc)), file=f)
+ print(f"{aug_label:30} {int(cyc):12d} cycles/batch", file=f)
print(file=f)
print(
- "Batch Inference time {:7.2f} ms, {:7.2f} inferences/s (batch size {:d})".format(
- midpoint_inference_time * 1000, midpoint_fps, batch_size
- ),
+ f"Batch Inference time {midpoint_inference_time * 1000:7.2f} ms,"
+ f" {midpoint_fps:7.2f} inferences/s (batch size {batch_size:d})",
file=f,
)
print(file=f)