aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLouis Verhaard <louis.verhaard@arm.com>2020-11-26 11:42:04 +0100
committerpatrik.gustavsson <patrik.gustavsson@arm.com>2020-12-07 14:51:52 +0000
commit1e17018d1aabff6b2a4cc5e8e3758678347b84c5 (patch)
tree8c06cb5a9f68e45fce96d9f17aac9a86f28ad912
parent32c7f5bbbccae480a0bb0c0e5b74a37dd9412023 (diff)
downloadethos-u-vela-1e17018d1aabff6b2a4cc5e8e3758678347b84c5.tar.gz
MLBEDSW-3643: Refactor blockdep calculation
Moved blockdep calculation and other helper functions for code generation to a separate file. Change-Id: I2f8ccea478654272ebf42217fc5c1800e9ad177a Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
-rw-r--r--ethosu/vela/architecture_features.py138
-rw-r--r--ethosu/vela/compiler_driver.py4
-rw-r--r--ethosu/vela/high_level_command_to_npu_op.py55
-rw-r--r--ethosu/vela/register_command_stream_generator.py451
-rw-r--r--ethosu/vela/register_command_stream_util.py543
-rw-r--r--ethosu/vela/shared_buffer_allocation.py2
-rw-r--r--ethosu/vela/test/extapi/test_extapi_generate_commands.py2
-rw-r--r--ethosu/vela/test/test_register_command_stream_util.py (renamed from ethosu/vela/test/test_register_command_generator.py)2
8 files changed, 620 insertions, 577 deletions
diff --git a/ethosu/vela/architecture_features.py b/ethosu/vela/architecture_features.py
index f7dcc8ce..354ab12c 100644
--- a/ethosu/vela/architecture_features.py
+++ b/ethosu/vela/architecture_features.py
@@ -192,6 +192,7 @@ class ArchitectureFeatures:
SubKernelMax = Block(8, 8, 65536)
DEFAULT_CONFIG = "internal-default"
+ MAX_BLOCKDEP = 3
def __init__(
self,
@@ -442,143 +443,6 @@ class ArchitectureFeatures:
return Block(ifm_block_width, ifm_block_height, ifm_block_depth)
- @staticmethod
- def intersects(start_a, end_a, start_b, end_b):
- start_x = max(start_a[0], start_b[0])
- end_x = min(end_a[0], end_b[0])
- start_y = max(start_a[1], start_b[1])
- end_y = min(end_a[1], end_b[1])
- start_z = max(start_a[2], start_b[2])
- end_z = min(end_a[2], end_b[2])
- return ((end_x - start_x) > 0) and ((end_y - start_y) > 0) and ((end_z - start_z) > 0)
-
- # Block job dependency:
- # Does the VOLUME of IFMs for block job B(0) overlap with VOLUME of OFMs block jobs A(8,9,10)
- #
- # A | B
- # ----------------------+------------------
- # .... 3,4,5,6,7,8,9,10 | 0,1,2,3,4,5,6,8 10 < JOB NUMBER
- # |<------->| dependency offset
- #
- MAX_BLOCKDEP = 3
-
- # Get the coordinates of a block offset from either the end (negative)
- # or the start (zero or positive) of the given 3d area
- def get_offset_block_coords(self, area: Rect, block: Block, offset):
- size = area.size()
- # Dimensions of the region, in blocks
- width_blocks = round_up_divide(size.width, block.width)
- height_blocks = round_up_divide(size.height, block.height)
- depth_blocks = round_up_divide(size.depth, block.depth)
- total_blocks = width_blocks * height_blocks * depth_blocks
- if offset < 0:
- index = total_blocks + offset
- else:
- index = offset
-
- if index >= total_blocks:
- return None
-
- # Coordinates of the indexed block
- coord_z = block.depth * (index % depth_blocks)
- coord_y = block.height * (index // (depth_blocks * width_blocks))
- coord_x = block.width * ((index // depth_blocks) % width_blocks)
-
- return (coord_x + area.x, coord_y + area.y, coord_z + area.z)
-
- def get_first_job_input_volume(
- self, ifm: Rect, ofm: Rect, ifm_block_depth, ofm_block: Block, kernel: Kernel, padLT, block_offset
- ):
- # Get ifm block size (jobs are invisibly decomposed into subkernels)
- ifm_block = self.get_ifm_block_size(ifm_block_depth, ofm_block, kernel, self.ofm_block_max)
- ifm_depth_blocks = round_up_divide(ifm.size().depth, ifm_block_depth)
-
- # Which OFM block are we calculating
- ofm_coord = self.get_offset_block_coords(ofm, ofm_block, block_offset // ifm_depth_blocks)
- if ofm_coord is None:
- return None
-
- # Coordinate of the source IFM block
- ifm_coord_x = max(0, ofm_coord[0] * kernel.stride.x - padLT[0])
- ifm_coord_y = max(0, ofm_coord[1] * kernel.stride.y - padLT[1])
- ifm_coord_z = ifm.z + (block_offset % ifm_depth_blocks) * ifm_block.depth
-
- # IFM block that will be sampled for the FIRST+block_offset job in the next operator's OFM
- start_coord = (ifm_coord_x, ifm_coord_y, ifm_coord_z)
- end_coord = (
- start_coord[0] + ifm_block.width,
- start_coord[1] + ifm_block.height,
- start_coord[2] + ifm_block.depth,
- )
- return (start_coord, end_coord, 1) # start, end, total jobs
-
- def get_prev_job_output_volume(self, ofm: Rect, ofm_block: Block, block_offset):
- assert block_offset >= 0
-
- # Get OFM block's volume coordinates
- start_coord = self.get_offset_block_coords(ofm, ofm_block, -1 - block_offset)
- if start_coord is None:
- return None
- end_coord = (
- start_coord[0] + ofm_block.width,
- start_coord[1] + ofm_block.height,
- start_coord[2] + ofm_block.depth,
- )
- return (start_coord, end_coord, 1) # start, end, total jobs for this OFM block
-
- def calc_block_dep(
- self,
- prev_ofm: Rect,
- prev_ofm_block: Block,
- ifm: Rect,
- ofm: Rect,
- ifm_block_depth,
- ofm_block: Block,
- kernel: Kernel,
- padLT,
- intersects,
- ):
- blockdep = ArchitectureFeatures.MAX_BLOCKDEP
-
- # Iterate over the next BLOCKDEP inputs, checking to see if a sliding window
- # of IFM area overlaps with any previous OFM block generation.
- elapsed_jobs = 0
- for forward_offset in range(ArchitectureFeatures.MAX_BLOCKDEP):
- # This is the IFM block we want to sample from
- in_area = self.get_first_job_input_volume(
- ifm, ofm, ifm_block_depth, ofm_block, kernel, padLT, forward_offset
- )
- if in_area is None:
- break
-
- # Try several previous-OFM blocks in the past (they still might comprise multiple IFM jobs)
- outstanding_jobs = 0
- for block_offset in range(ArchitectureFeatures.MAX_BLOCKDEP):
- # This is the OFM block being generated by the previous op
- out_area = self.get_prev_job_output_volume(prev_ofm, prev_ofm_block, block_offset)
- if out_area is None:
- break
-
- # Block dependency is the max number of allowed outstanding jobs
- # in the pipeline. Selected by determining how many jobs occur
- # in between two operators' overlapping OFM->IFM block volumes
- if intersects(in_area[0], in_area[1], out_area[0], out_area[1]):
- break
- # Early exit if no intersections and we've seen enough jobs in the pipeline
- elif outstanding_jobs > ArchitectureFeatures.MAX_BLOCKDEP:
- break
-
- # This OFM had this many jobs (accumulate over multiple OFM blocks)
- outstanding_jobs += out_area[2]
-
- blockdep = min(blockdep, elapsed_jobs + outstanding_jobs)
- elapsed_jobs += in_area[2]
- # Early exit if no intersections and we've seen enough jobs in the pipeline
- if elapsed_jobs > ArchitectureFeatures.MAX_BLOCKDEP:
- break
-
- return blockdep
-
def is_spilling_enabled(self):
"""
Spilling is a feature that allows the Ethos-U to use a dedicated SRAM as a cache for various types of data
diff --git a/ethosu/vela/compiler_driver.py b/ethosu/vela/compiler_driver.py
index d17f1e5b..6c7fdc1a 100644
--- a/ethosu/vela/compiler_driver.py
+++ b/ethosu/vela/compiler_driver.py
@@ -20,6 +20,7 @@ import time
from . import extract_npu_subgraphs
from . import graph_optimiser
from . import high_level_command_stream_generator
+from . import high_level_command_to_npu_op
from . import insert_dma
from . import live_range
from . import lut
@@ -27,7 +28,6 @@ from . import mark_tensors
from . import npu_performance
from . import npu_serialisation
from . import pass_packing
-from . import register_command_stream_generator
from . import scheduler
from . import tensor_allocation
from . import weight_compressor
@@ -289,7 +289,7 @@ def compiler_driver(nng, arch, options, scheduler_options):
nng, sg, arch, options.verbose_high_level_command_stream
)
lut.optimize_high_level_cmd_stream(sg, arch)
- register_command_stream_generator.generate_register_command_stream_for_sg(
+ high_level_command_to_npu_op.generate_register_command_stream_for_sg(
nng, sg, arch, options.verbose_register_command_stream
)
scratch_tens, scratch_fast_tens, flash_tens = npu_serialisation.serialise_npu_subgraph_into_tensors(
diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py
index efd8a03d..7db4931d 100644
--- a/ethosu/vela/high_level_command_to_npu_op.py
+++ b/ethosu/vela/high_level_command_to_npu_op.py
@@ -32,7 +32,6 @@ from .api import NpuDmaOperation
from .api import NpuElementWiseOp
from .api import NpuElementWiseOperation
from .api import NpuFeatureMap
-from .api import NpuKernel
from .api import NpuLayout
from .api import NpuOperation
from .api import NpuPadding
@@ -46,15 +45,20 @@ from .api import NpuTileBox
from .architecture_features import ArchitectureFeatures
from .architecture_features import Block
from .data_type import DataType
+from .debug_database import DebugDatabase
from .high_level_command_stream import Box
from .high_level_command_stream import Command
from .high_level_command_stream import CommandType
from .high_level_command_stream import DMA
from .high_level_command_stream import NpuStripe
-from .operation import Kernel
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
+from .register_command_stream_generator import generate_command_stream
+from .register_command_stream_util import BASE_PTR_INDEX_MEM2MEM
+from .register_command_stream_util import is_dma_op
+from .register_command_stream_util import to_npu_kernel
+from .register_command_stream_util import UNARY_ELEMWISE_OPS
from .tensor import MemType
from .tensor import Tensor
from .tensor import TensorBlockTraversal
@@ -62,14 +66,10 @@ from .tensor import TensorFormat
from .tensor import TensorPurpose
-unary_elementwise_ops = set((NpuElementWiseOp.ABS, NpuElementWiseOp.LRELU, NpuElementWiseOp.CLZ,))
-
-
class BasePointerIndex(IntEnum):
WeightTensor = 0 # base address index for the Weight tensor
ScratchTensor = 1 # base address index for the Scratch_tensor in the TensorArena
ScratchFastTensor = 2 # base address for the Scratch_fast_tensor
- Mem2Mem = (1 << 8) | (3 << 0) # base address slot for memory 2 memory transfer
dtype_map = {
@@ -102,20 +102,6 @@ elementwise_op_map = {
}
-def to_npu_kernel(kernel: Kernel) -> NpuKernel:
- """Converts the given internally used kernel object to NpuKernel (of public API)"""
- return NpuKernel(
- kernel.width, kernel.height, kernel.stride.x, kernel.stride.y, kernel.dilation.x, kernel.dilation.y
- )
-
-
-def to_kernel(kernel: Optional[NpuKernel]) -> Kernel:
- """Converts the given public API object to Kernel (used internally)"""
- if kernel is None:
- return Kernel(1, 1)
- return Kernel(kernel.width, kernel.height, kernel.stride_x, kernel.stride_y, kernel.dilation_x, kernel.dilation_y)
-
-
def ifm_ifm2_correct_order(ifm_shape: List[int], ifm2_shape: List[int]) -> bool:
if ifm_shape == []:
# Scalar needs to be in IFM2
@@ -412,7 +398,7 @@ def create_npu_elementwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> Npu
assert op.type in elementwise_op_map, f"Unknown elementwise type {op.type}"
elemwise_op = elementwise_op_map[op.type]
npu_op = NpuElementWiseOperation(elemwise_op)
- if elemwise_op not in unary_elementwise_ops:
+ if elemwise_op not in UNARY_ELEMWISE_OPS:
if not ifm_ifm2_correct_order(cmd.ifm_tensor.shape, cmd.ifm2_tensor.shape):
# The scalar/broadcasted feature map has to be the ifm2 tensor so switch the ifms
cmd.ifm_tensor, cmd.ifm2_tensor = cmd.ifm2_tensor, cmd.ifm_tensor
@@ -452,7 +438,7 @@ def create_dma_op(cmd: DMA, arch: ArchitectureFeatures) -> NpuDmaOperation:
"""Converts the command to NpuDmaOperation"""
src_region = get_region(cmd.in_tensor, arch)
if cmd.out_tensor.purpose == TensorPurpose.LUT:
- dest_region = BasePointerIndex.Mem2Mem
+ dest_region = BASE_PTR_INDEX_MEM2MEM
else:
dest_region = get_region(cmd.out_tensor, arch)
@@ -492,3 +478,28 @@ def convert_command_to_npu_op(cmd: Command, arch: ArchitectureFeatures) -> NpuOp
# add a link to the high level command for debugging purposes
npu_op.cmd = cmd
return npu_op
+
+
+def generate_register_command_stream_for_sg(nng, sg, arch, verbose=False):
+ """Generates command stream for the subgraph, adds it to sg.register_command_stream"""
+ # Convert high level command stream to list of NpuOperation
+ npu_op_list = []
+ npu_op_to_cmd = dict() # map from npu op to high level command
+ for cmd in sg.high_level_command_stream:
+ if cmd.cmdtype == CommandType.NpuStripe and cmd.ps.npu_block_type == NpuBlockType.Default:
+ print("Warning: Skipping register command stream generation for", cmd.ps)
+ else:
+ npu_op = convert_command_to_npu_op(cmd, arch)
+ npu_op_list.append(npu_op)
+ npu_op_to_cmd[npu_op] = cmd
+ # Generate register commands
+ stream_id = DebugDatabase.add_stream(sg)
+ DebugDatabase.set_stream_offset(sg, 0) # Default to zero, can only set during file writing
+
+ def add_to_debug_db(npu_op: NpuOperation, offset: int):
+ """Adds info to the debug database"""
+ if not is_dma_op(npu_op):
+ cmd = npu_op_to_cmd[npu_op]
+ DebugDatabase.add_command(stream_id, offset, cmd.ps.primary_op)
+
+ sg.register_command_stream = generate_command_stream(npu_op_list, arch, verbose, add_to_debug_db)
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index 741b09c1..d4947b1a 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -18,16 +18,13 @@
# all the register settings. Calculates dependencies between commands and inserts wait operations. And generates a bit
# stream suitable for interpretation by the Ethos-U processor.
from collections import defaultdict
-from collections import namedtuple
from enum import Enum
from enum import IntEnum
from typing import List
from typing import Optional
-from typing import Tuple
import numpy as np
-from . import numeric_util
from . import scaling
from .api import NpuAccelerator
from .api import NpuActivation
@@ -57,10 +54,8 @@ from .architecture_features import Accelerator
from .architecture_features import ArchitectureFeatures
from .architecture_features import Block
from .architecture_features import create_default_arch
-from .architecture_features import Rect
from .architecture_features import SharedBufferArea
from .architecture_features import SHRAMElements
-from .debug_database import DebugDatabase
from .ethos_u55_regs.ethos_u55_regs import acc_format
from .ethos_u55_regs.ethos_u55_regs import activation
from .ethos_u55_regs.ethos_u55_regs import cmd0
@@ -69,17 +64,20 @@ from .ethos_u55_regs.ethos_u55_regs import elementwise_mode
from .ethos_u55_regs.ethos_u55_regs import pooling_mode
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .ethos_u55_regs.ethos_u55_regs import rounding
-from .high_level_command_stream import CommandType
-from .high_level_command_to_npu_op import convert_command_to_npu_op
-from .high_level_command_to_npu_op import to_kernel
-from .high_level_command_to_npu_op import unary_elementwise_ops
from .numeric_util import quantise_float32
from .numeric_util import round_away_zero
from .numeric_util import round_up_to_int
from .operation import NpuBlockType
-from .range_set import AccessDirection
-from .range_set import MemoryAccessSet
-from .range_set import MemoryRangeSet
+from .register_command_stream_util import calc_blockdep
+from .register_command_stream_util import get_dma_memory_accesses
+from .register_command_stream_util import get_op_memory_accesses
+from .register_command_stream_util import get_strides
+from .register_command_stream_util import get_wait_dependency
+from .register_command_stream_util import has_ifm2
+from .register_command_stream_util import is_dma_op
+from .register_command_stream_util import to_kernel
+from .register_command_stream_util import UNARY_ELEMWISE_OPS
+from .register_command_stream_util import Watermark
from .shared_buffer_allocation import find_suitable_block_configs
from .shared_buffer_allocation import shared_buffer_allocation_for_npu_op
from .shared_buffer_allocation import SharedBufferAllocation
@@ -203,13 +201,6 @@ class CommandStreamEmitter:
# -------------------------------------------------------------------
-class BasePointerIndex(IntEnum):
- WeightTensor = 0 # base address index for the Weight tensor
- ScratchTensor = 1 # base address index for the Scratch_tensor in the TensorArena
- ScratchFastTensor = 2 # base address for the Scratch_fast_tensor
- Mem2Mem = (1 << 8) | (3 << 0) # base address slot for memory 2 memory transfer
-
-
# TODO: Replace with definitions from ethos_u55_regs
class IFM2Broadcast(IntEnum):
BroadcastHdim = 1 << 0
@@ -275,16 +266,6 @@ def quantise(value: float, quant: Optional[NpuQuantization]) -> int:
return quantise_float32(value, scale, zp)
-def has_ifm2(npu_op: NpuBlockOperation) -> bool:
- """Checks if op has non-scalar IFM2"""
- return npu_op.ifm2 is not None and npu_op.ifm2_scalar is None
-
-
-def is_dma_op(npu_op: NpuOperation) -> bool:
- """Checks if op is a DMA operation"""
- return npu_op.op_type == NpuOperationType.Dma
-
-
def generate_padding(emit: CommandStreamEmitter, padding: NpuPadding):
"""Generates IFM_PAD registers"""
emit.cmd0_with_param(cmd0.NPU_SET_IFM_PAD_TOP, padding.top)
@@ -584,6 +565,15 @@ def create_shared_buffer(npu_op: NpuBlockOperation, arch: ArchitectureFeatures)
return shared_buffer_allocation_for_npu_op(arch, npu_op, block_type, ifm_resampling_mode)
+def generate_cmd_waits(emit: CommandStreamEmitter, cmd_waits: Watermark):
+ """Generates KERNEL_WAIT/DMA_WAIT"""
+ if cmd_waits.npu >= 0:
+ emit.cmd_wait(cmd0.NPU_OP_KERNEL_WAIT, 0, cmd_waits.npu)
+
+ if cmd_waits.dma >= 0:
+ emit.cmd_wait(cmd0.NPU_OP_DMA_WAIT, 0, cmd_waits.dma)
+
+
def generate_common(
emit: CommandStreamEmitter,
npu_op: NpuBlockOperation,
@@ -735,353 +725,6 @@ def generate_scaling_for_elementwise(emit: CommandStreamEmitter, npu_op: NpuElem
# -------------------------------------------------------------------
-# ADDRESSING/STRIDES (helper functions)
-# -------------------------------------------------------------------
-
-
-def ranges_overlap(range1: NpuAddressRange, range2: NpuAddressRange) -> bool:
- """Checks if the ranges overlap"""
- return range1.region == range2.region and numeric_util.overlaps(
- range1.address, range1.address + range1.length, range2.address, range2.address + range2.length
- )
-
-
-def range_lists_overlap(list1: List[Optional[NpuAddressRange]], list2: List[Optional[NpuAddressRange]]) -> bool:
- """Checks if there is any address overlap between list1 and list2"""
- for range1 in list1:
- if range1 is None:
- continue
- for range2 in list2:
- if range2 is not None and ranges_overlap(range1, range2):
- return True
- return False
-
-
-def get_strides(fm: NpuFeatureMap) -> NpuShape3D:
- """Calculates STRIDE_C/Y/X"""
- if fm.strides is not None:
- return fm.strides
- elem_size = fm.data_type.size_in_bytes()
- if fm.layout == NpuLayout.NHWC:
- stride_c = elem_size
- stride_x = fm.shape.depth * stride_c
- stride_y = fm.shape.width * stride_x
- else:
- stride_x = 16 * elem_size
- stride_c = stride_x * fm.shape.width
- stride_y = elem_size * fm.shape.width * numeric_util.round_up(fm.shape.depth, 16)
- return NpuShape3D(depth=stride_c, height=stride_y, width=stride_x)
-
-
-def get_address(fm: NpuFeatureMap, strides: NpuShape3D, y: int, x: int, c: int) -> int:
- """Returns address of given coordinate"""
- t = 0
- BRICK = 16
- stride_c = BRICK * fm.data_type.size_in_bytes() if fm.layout == NpuLayout.NHWC else strides.depth
- stride_x = BRICK * fm.data_type.size_in_bytes() if fm.layout == NpuLayout.NHCWB16 else strides.width
- if x >= fm.tiles.width_0:
- x -= fm.tiles.width_0
- t = 1
- if y >= fm.tiles.height_1:
- y -= fm.tiles.height_1
- t += 2
- elif y >= fm.tiles.height_0:
- y -= fm.tiles.height_0
- t += 2
- elem_size = fm.data_type.size_in_bytes()
- return (
- fm.tiles.addresses[t] + y * strides.height + x * stride_x + (c // BRICK) * stride_c + int(c % BRICK) * elem_size
- )
-
-
-def get_address_range(
- fm: NpuFeatureMap, strides: NpuShape3D, y0: int, x0: int, c0: int, y1: int, x1: int, c1: int
-) -> NpuAddressRange:
- """
- Gets address range for (y0, x0, c0) - (y1, x1, c1) (inclusive, so the second coordinate is within the fm).
- The begin and end coordinates must be within the same tile.
- """
- addr0 = get_address(fm, strides, y0, x0, c0)
- addr1 = get_address(fm, strides, y1, x1, c1)
- return NpuAddressRange(region=fm.region, address=addr0, length=addr1 - addr0 + fm.data_type.size_in_bytes())
-
-
-def get_h_ranges(
- fm: NpuFeatureMap, strides: NpuShape3D, y0: int, x0: int, c0: int, y1: int, x1: int, c1: int
-) -> List[NpuAddressRange]:
- """
- Gets address ranges for (y0, x0, c0) - (y1, x1, c1) (inclusive, so the second coordinate is within the fm);
- the begin and end coordinates must be within the same tile.
- Divides the area in horizontal "stripes" of height 1, and returns the address ranges for these "stripes".
- """
- return [get_address_range(fm, strides, y, x0, c0, y, x1, c1) for y in range(y0, y1 + 1)]
-
-
-def get_address_ranges_for_area(
- fm: NpuFeatureMap, y0: int, x0: int, c0: int, y1: int, x1: int, c1: int
-) -> List[NpuAddressRange]:
- """
- Returns a list of adddress ranges that covers the area (y0, x0, c0) - (y1, x1, c1) (inclusive).
- Divides the area in horizontal "stripes" of height 1, and returns the address ranges for these "stripes".
-
- For example, for the area marked with X (in a feature map with 4 tiles) as input, this function would return
- 6 address ranges: the address ranges for 1-height areas [AAA, BBB, CC, DD, EEE, FF]
-
- .....|.... .....|....
- t0 ..XXX|XX.. t1 t0 ..AAA|CC.. t1
- ..XXX|XX.. ..BBB|DD..
- -----+---- --> -----+----
- t2 ..XXX|XX.. t3 t2 ..EEE|FF.. t3
- .....|.... .....|....
- """
- strides = get_strides(fm)
- height_0, height_1, width_0 = fm.tiles.height_0, fm.tiles.height_1, fm.tiles.width_0
- h, w, c = fm.shape
- y2, x2, c2 = min(y1, h - 1), min(x1, w - 1), min(c1, c - 1)
- ranges = []
- if x0 < width_0 and y0 < height_0:
- # Horizontal ranges for tile 0
- ranges.extend(get_h_ranges(fm, strides, y0, x0, c0, min(y2, height_0 - 1), min(x2, width_0 - 1), c2))
- if x2 >= width_0 and y0 < height_1:
- # Horizontal ranges for tile 1
- ranges.extend(get_h_ranges(fm, strides, y0, max(x0, width_0), c0, min(y2, height_1 - 1), x2, c2))
- if x0 < width_0 and y2 >= height_0:
- # Horizontal ranges for tile 2
- ranges.extend(get_h_ranges(fm, strides, max(y0, height_0), x0, c0, y2, min(x2, width_0 - 1), c2))
- if x2 >= width_0 and y2 >= height_1:
- # Horizontal ranges for tile 3
- ranges.extend(get_h_ranges(fm, strides, max(y0, height_1), max(x0, width_0), c0, y2, x2, c2))
- return ranges
-
-
-def get_address_ranges(fm: NpuFeatureMap) -> List[Optional[NpuAddressRange]]:
- """Returns 4 adddress ranges, one for every tile, None if the tile is not in use"""
- strides = get_strides(fm)
- height, width, depth = fm.shape.height, fm.shape.width, fm.shape.depth
- height_0, height_1, width_0 = fm.tiles.height_0, fm.tiles.height_1, fm.tiles.width_0
- t0 = get_address_range(fm, strides, 0, 0, 0, min(height, height_0) - 1, min(width, width_0) - 1, depth - 1,)
- if width > width_0:
- t1 = get_address_range(fm, strides, 0, width_0, 0, min(height, height_1) - 1, width - 1, depth - 1)
- else:
- t1 = None
- if height > height_0:
- t2 = get_address_range(fm, strides, height_0, 0, 0, height - 1, min(width, width_0) - 1, depth - 1)
- else:
- t2 = None
- if t1 is not None and t2 is not None:
- t3 = get_address_range(fm, strides, height_1, width_0, 0, height - 1, width - 1, depth - 1)
- else:
- t3 = None
- return [t0, t1, t2, t3]
-
-
-# -------------------------------------------------------------------
-# DMA_WAIT/KERNEL_WAIT
-# -------------------------------------------------------------------
-
-
-Watermark = namedtuple("Watermark", ["npu", "dma"])
-
-
-def memory_range_set(range: NpuAddressRange) -> MemoryRangeSet:
- return MemoryRangeSet(range.region, range.address, range.address + range.length)
-
-
-def get_dma_memory_accesses(dma_op: NpuDmaOperation) -> MemoryAccessSet:
- """Returns the address that are read and written by the given DMA operation"""
- res = MemoryAccessSet()
- res.add(memory_range_set(dma_op.src), AccessDirection.Read)
- res.add(memory_range_set(dma_op.dest), AccessDirection.Write)
- return res
-
-
-def get_op_memory_accesses(npu_op: NpuBlockOperation, arch: ArchitectureFeatures) -> MemoryAccessSet:
- """Returns the addresses that are read and written by the given operation"""
- assert npu_op.ifm is not None and npu_op.ofm is not None
- # Read addresses
- read_ranges = get_address_ranges(npu_op.ifm)
- if has_ifm2(npu_op):
- assert npu_op.ifm2 is not None
- read_ranges.extend(get_address_ranges(npu_op.ifm2))
- read_ranges.extend(npu_op.weights)
- read_ranges.extend(npu_op.biases)
- if npu_op.activation is not None and npu_op.activation.op_type == NpuActivationOp.TABLE_LOOKUP:
- address = arch.available_shram_banks(True) * arch.shram_bank_size
- read_ranges.append(NpuAddressRange(region=BasePointerIndex.Mem2Mem, address=address, length=2048))
- # Written addresses
- write_ranges = get_address_ranges(npu_op.ofm)
- # Add write access to SHRAM, needed when LUTs can overwrite accumulator banks
- uses_lut = npu_op.activation is not None and npu_op.activation.op_type == NpuActivationOp.TABLE_LOOKUP
- written_shram_size = arch.available_shram_banks(uses_lut) * arch.shram_bank_size
- write_ranges.append(NpuAddressRange(region=BasePointerIndex.Mem2Mem, address=0, length=written_shram_size))
-
- res = MemoryAccessSet()
- for read_range in read_ranges:
- if read_range is not None:
- res.add(memory_range_set(read_range), AccessDirection.Read)
- for write_range in write_ranges:
- if write_range is not None:
- res.add(memory_range_set(write_range), AccessDirection.Write)
- return res
-
-
-def get_wait_dependency(
- arch: ArchitectureFeatures, npu_op_list: List[NpuOperation], memory_accesses, op_index: int, watermark: Watermark
-):
- """Used to calculate whether DMA wait or kernel wait operations are needed"""
- npu_op = npu_op_list[op_index]
- op_access = memory_accesses[npu_op]
- index = op_index - 1
-
- # NPU dependency tracking
- npu_outstanding = -1
- npu_ops = 0
- npu_index = watermark.npu
-
- # DMA dependency tracking
- dma_outstanding = -1
- dma_ops = 0
- dma_index = watermark.dma
-
- # Seek back in the command stream looking for NPU or DMA dependencies
- # but only as far as the first dependency or the watermarks (dependencies
- # before this point have been satisfied already).
- # The watermark moves to after the latest element we must wait for, not
- # the command that issues the wait.
- # NPU->NPU dependency is handled via blockdep.
- while (index >= npu_index) or (index >= dma_index):
- prev_op = npu_op_list[index]
- prev_access = memory_accesses[prev_op]
-
- # Check NPU consuming DMA output
- if is_dma_op(prev_op):
- if index >= dma_index:
- if not is_dma_op(npu_op):
- if (dma_outstanding == -1) and prev_access.conflicts(op_access):
- dma_outstanding = dma_ops
- dma_ops += 1 # Count DMA ops in the pipeline
- if dma_ops >= arch.max_outstanding_dma:
- dma_index = max(index + 1, dma_index)
- # Check DMA consuming NPU output
- else:
- if index >= npu_index:
- if is_dma_op(npu_op) and npu_outstanding == -1 and prev_access.conflicts(op_access):
- npu_outstanding = npu_ops
- npu_ops += 1 # Count NPU ops in the pipeline
- if npu_ops >= arch.max_outstanding_kernels:
- npu_index = max(index + 1, npu_index)
-
- index -= 1
-
- # Update DMA watermark if we didn't see any and the NPU pipeline is full
- if (dma_ops == 0) and (npu_ops >= arch.max_outstanding_kernels):
- dma_index = op_index
-
- # Bring the search watermark forwards as we complete for those dependencies
- watermark = Watermark(npu_index, dma_index)
- outstanding = Watermark(npu_outstanding, dma_outstanding)
-
- return watermark, outstanding
-
-
-def generate_cmd_waits(emit: CommandStreamEmitter, cmd_waits: Watermark):
- if cmd_waits.npu >= 0:
- emit.cmd_wait(cmd0.NPU_OP_KERNEL_WAIT, 0, cmd_waits.npu)
-
- if cmd_waits.dma >= 0:
- emit.cmd_wait(cmd0.NPU_OP_DMA_WAIT, 0, cmd_waits.dma)
-
-
-# -------------------------------------------------------------------
-# BLOCKDEP
-# -------------------------------------------------------------------
-
-
-def shape3d_size(shape: NpuShape3D) -> int:
- return shape.width * shape.height * shape.depth
-
-
-def shape3d_to_rect(shape: NpuShape3D) -> Rect:
- return Rect(0, 0, 0, shape.width - 1, shape.height - 1, shape.depth - 1)
-
-
-def get_ifm_ofm_block_depth(arch: ArchitectureFeatures, npu_op: NpuBlockOperation) -> int:
- # Note: NOT equivalent to the normal ifm block depth calculation since
- # it takes into account 'depthless' block operations by returning full
- # depth
- if npu_op.op_type == NpuOperationType.Conv2D:
- res = arch.calc_ifm_block_depth(npu_op.ifm.shape.depth, npu_op.ifm.data_type.size_in_bits())
- return res
- return npu_op.ofm.shape.depth
-
-
-def calc_blockdep(arch: ArchitectureFeatures, prev_op: Optional[NpuBlockOperation], npu_op: NpuBlockOperation,) -> int:
- """Calculates the value of the BLOCKDEP register"""
- if prev_op is None:
- return 0
- assert npu_op.ifm is not None
- assert prev_op.ofm is not None
- # Check if IFM or IFM2 overlaps with prev op's OFM
- prev_ofm_ranges = get_address_ranges(prev_op.ofm)
- ifm_ranges = get_address_ranges(npu_op.ifm)
- ifm_overlaps = range_lists_overlap(prev_ofm_ranges, ifm_ranges)
- if has_ifm2(npu_op):
- assert npu_op.ifm2 is not None
- ifm2_ranges = get_address_ranges(npu_op.ifm2)
- ifm2_overlaps = range_lists_overlap(prev_ofm_ranges, ifm2_ranges)
- else:
- ifm2_overlaps = False
- if ifm_overlaps and ifm2_overlaps:
- # Both IFM and IFM2 overlap (should be rare)
- return 0
- if not ifm_overlaps and not ifm2_overlaps:
- # No overlap between prev OFM and IFM/IFM2
- return ArchitectureFeatures.MAX_BLOCKDEP
- if ifm2_overlaps and shape3d_size(npu_op.ifm2.shape) < shape3d_size(npu_op.ifm.shape):
- # Prev OFM produces IFM2 which is broadcasted (this should be rare)
- return 0
- prev_block_config = prev_op.block_config
- block_config = npu_op.block_config
- overlapping_fm = npu_op.ifm if ifm_overlaps else npu_op.ifm2
- assert overlapping_fm is not None
-
- def intersects(ifm_start_coord: Tuple, ifm_end_coord: Tuple, ofm_start_coord: Tuple, ofm_end_coord: Tuple) -> bool:
- """Checks if the given IFM area overlaps with the given OFM area"""
- if overlapping_fm.shape == prev_op.ofm.shape and overlapping_fm.tiles == prev_op.ofm.tiles:
- # Common case: prev_op.ofm == op.ifm; in this case it suffices to check
- # if the xyz coordinates overlap, which is quick and easy
- return ArchitectureFeatures.intersects(ifm_start_coord, ifm_end_coord, ofm_start_coord, ofm_end_coord)
- # The OFM produces a part of the IFM (e.g. a stripe), or the IFM consumes part of the OFM.
- # In this case address comparison is needed between the two areas
- x0, y0, c0 = ifm_start_coord
- x1, y1, c1 = ifm_end_coord
- ifm_ranges = get_address_ranges_for_area(overlapping_fm, y0, x0, c0, y1, x1, c1)
- x0, y0, c0 = ofm_start_coord
- x1, y1, c1 = ofm_end_coord
- prev_ofm_ranges = get_address_ranges_for_area(prev_op.ofm, y0, x0, c0, y1, x1, c1)
- return range_lists_overlap(ifm_ranges, prev_ofm_ranges)
-
- prev_ofm_block = Block(prev_block_config.width, prev_block_config.height, prev_block_config.depth)
- prev_ofm_rect = shape3d_to_rect(prev_op.ofm.shape)
- cur_ifm_block_depth = get_ifm_ofm_block_depth(arch, npu_op)
- cur_ofm_block = Block(block_config.width, block_config.height, block_config.depth)
- cur_ofm_rect = shape3d_to_rect(npu_op.ofm.shape)
- cur_ifm_rect = shape3d_to_rect(npu_op.ifm.shape)
- cur_padLT = (0, 0) if npu_op.padding is None else (npu_op.padding.left, npu_op.padding.top)
- return arch.calc_block_dep(
- prev_ofm_rect,
- prev_ofm_block,
- cur_ifm_rect,
- cur_ofm_rect,
- cur_ifm_block_depth,
- cur_ofm_block,
- to_kernel(npu_op.kernel),
- cur_padLT,
- intersects=intersects,
- )
-
-
-# -------------------------------------------------------------------
# PRINT
# -------------------------------------------------------------------
@@ -1209,7 +852,7 @@ def generate_elementwise_op(emit: CommandStreamEmitter, npu_op: NpuElementWiseOp
emit, npu_op, NpuBlockTraversal.DEPTH_FIRST, arch, use_global_scale=use_global_scale, op_to_scale=op_to_scale
)
# Elementwise op specific
- if npu_op.sub_op_type not in unary_elementwise_ops:
+ if npu_op.sub_op_type not in UNARY_ELEMWISE_OPS:
# Binary operation; generate IFM2 registers
assert npu_op.ifm2 is not None
has_scalar = npu_op.ifm2_scalar is not None
@@ -1253,9 +896,15 @@ def generate_registers_for_op(emit: CommandStreamEmitter, npu_op: NpuOperation,
def generate_command_stream(
- emit: CommandStreamEmitter, npu_op_list: List[NpuOperation], arch: ArchitectureFeatures, add_to_debug_db=None
-):
- """Generates register commands for the given list of NPU operations"""
+ npu_op_list: List[NpuOperation], arch: ArchitectureFeatures, verbose: bool, add_to_debug_db=None,
+) -> List[int]:
+ """
+ Generates register commands for the given list of NPU operations.
+ Returns Ethos-U instructions, as a list of 32-bit integers.
+ """
+ emit = CommandStreamEmitter()
+ if verbose:
+ print_operations(npu_op_list)
# Calculate memory accesses for every operation
memory_accesses = {}
for npu_op in npu_op_list:
@@ -1285,39 +934,17 @@ def generate_command_stream(
add_to_debug_db(npu_op, emit.offset)
# Fill in final part of command stream:
emit.cmd_do_operation(cmd0.NPU_OP_STOP, param=0xFFFF)
-
-
-def generate_register_command_stream_for_sg(nng, sg, arch, verbose=False):
- """Generates command stream for the subgraph, adds it to sg.register_command_stream"""
- # Convert high level command stream to list of NpuOperation
- npu_op_list = []
- npu_op_to_cmd = dict() # map from npu op to high level command
- for cmd in sg.high_level_command_stream:
- if cmd.cmdtype == CommandType.NpuStripe and cmd.ps.npu_block_type == NpuBlockType.Default:
- print("Warning: Skipping register command stream generation for", cmd.ps)
- else:
- npu_op = convert_command_to_npu_op(cmd, arch)
- npu_op_list.append(npu_op)
- npu_op_to_cmd[npu_op] = cmd
- if verbose:
- print_operations(npu_op_list)
- # Generate register commands
- stream_id = DebugDatabase.add_stream(sg)
- DebugDatabase.set_stream_offset(sg, 0) # Default to zero, can only set during file writing
- emit = CommandStreamEmitter()
-
- def add_to_debug_db(npu_op: NpuOperation, offset: int):
- """Adds info to the debug database"""
- if not is_dma_op(npu_op):
- cmd = npu_op_to_cmd[npu_op]
- DebugDatabase.add_command(stream_id, offset, cmd.ps.primary_op)
-
- generate_command_stream(emit, npu_op_list, arch, add_to_debug_db)
- sg.register_command_stream = emit.to_list()
+ res = emit.to_list()
if verbose:
emit.print_cmds()
print("number of commands", len(emit.cmd_stream))
- print("command stream length in words", len(sg.register_command_stream))
+ print("command stream length in words", len(res))
+ return res
+
+
+# -------------------------------------------------------------------
+# EXTERNAL API
+# -------------------------------------------------------------------
def find_block_configs(npu_op: NpuOperation, npu_accelerator: NpuAccelerator) -> List[NpuShape3D]:
@@ -1342,7 +969,5 @@ def generate_register_command_stream(npu_op_list: List[NpuOperation], npu_accele
:return Ethos-U instructions, as a list of 32-bit integers
"""
accelerator = Accelerator.from_npu_accelerator(npu_accelerator)
- emit = CommandStreamEmitter()
arch = create_default_arch(accelerator)
- generate_command_stream(emit, npu_op_list, arch)
- return emit.to_list()
+ return generate_command_stream(npu_op_list, arch, verbose=False)
diff --git a/ethosu/vela/register_command_stream_util.py b/ethosu/vela/register_command_stream_util.py
new file mode 100644
index 00000000..ca7e6bc6
--- /dev/null
+++ b/ethosu/vela/register_command_stream_util.py
@@ -0,0 +1,543 @@
+# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Description:
+# Utility functions for code generation
+from typing import List
+from typing import NamedTuple
+from typing import Optional
+
+from . import numeric_util
+from .api import NpuActivationOp
+from .api import NpuAddressRange
+from .api import NpuBlockOperation
+from .api import NpuDmaOperation
+from .api import NpuElementWiseOp
+from .api import NpuFeatureMap
+from .api import NpuKernel
+from .api import NpuLayout
+from .api import NpuOperation
+from .api import NpuOperationType
+from .api import NpuPadding
+from .api import NpuShape3D
+from .architecture_features import ArchitectureFeatures
+from .architecture_features import Block
+from .architecture_features import Rect
+from .operation import Kernel
+from .operation import PointXYZ
+from ethosu.vela.range_set import AccessDirection
+from ethosu.vela.range_set import MemoryAccessSet
+from ethosu.vela.range_set import MemoryRangeSet
+
+# base address slot for memory to memory transfer
+BASE_PTR_INDEX_MEM2MEM = int((1 << 8) | (3 << 0))
+
+
+UNARY_ELEMWISE_OPS = set((NpuElementWiseOp.ABS, NpuElementWiseOp.LRELU, NpuElementWiseOp.CLZ,))
+
+
+def to_npu_kernel(kernel: Kernel) -> NpuKernel:
+ """Converts the given internally used kernel object to NpuKernel (of public API)"""
+ return NpuKernel(
+ kernel.width, kernel.height, kernel.stride.x, kernel.stride.y, kernel.dilation.x, kernel.dilation.y
+ )
+
+
+def to_kernel(kernel: Optional[NpuKernel]) -> Kernel:
+ """Converts the given public API object to Kernel (used internally)"""
+ if kernel is None:
+ return Kernel(1, 1)
+ return Kernel(kernel.width, kernel.height, kernel.stride_x, kernel.stride_y, kernel.dilation_x, kernel.dilation_y)
+
+
+def has_ifm2(npu_op: NpuBlockOperation) -> bool:
+ """Checks if op has non-scalar IFM2"""
+ return npu_op.ifm2 is not None and npu_op.ifm2_scalar is None
+
+
+def is_dma_op(npu_op: NpuOperation) -> bool:
+ """Checks if op is a DMA operation"""
+ return npu_op.op_type == NpuOperationType.Dma
+
+
+def shape3d_size(shape: NpuShape3D) -> int:
+ return shape.width * shape.height * shape.depth
+
+
+def shape3d_to_rect(shape: NpuShape3D) -> Rect:
+ return Rect(0, 0, 0, shape.width - 1, shape.height - 1, shape.depth - 1)
+
+
+# -------------------------------------------------------------------
+# ADDRESSING/STRIDES (helper functions)
+# -------------------------------------------------------------------
+
+
+def ranges_overlap(range1: NpuAddressRange, range2: NpuAddressRange) -> bool:
+ """Checks if the ranges overlap"""
+ return range1.region == range2.region and numeric_util.overlaps(
+ range1.address, range1.address + range1.length, range2.address, range2.address + range2.length
+ )
+
+
+def range_lists_overlap(list1: List[Optional[NpuAddressRange]], list2: List[Optional[NpuAddressRange]]) -> bool:
+ """Checks if there is any address overlap between list1 and list2"""
+ for range1 in list1:
+ if range1 is None:
+ continue
+ for range2 in list2:
+ if range2 is not None and ranges_overlap(range1, range2):
+ return True
+ return False
+
+
+def get_strides(fm: NpuFeatureMap) -> NpuShape3D:
+ """Calculates STRIDE_C/Y/X"""
+ if fm.strides is not None:
+ return fm.strides
+ elem_size = fm.data_type.size_in_bytes()
+ if fm.layout == NpuLayout.NHWC:
+ stride_c = elem_size
+ stride_x = fm.shape.depth * stride_c
+ stride_y = fm.shape.width * stride_x
+ else:
+ stride_x = 16 * elem_size
+ stride_c = stride_x * fm.shape.width
+ stride_y = elem_size * fm.shape.width * numeric_util.round_up(fm.shape.depth, 16)
+ return NpuShape3D(depth=stride_c, height=stride_y, width=stride_x)
+
+
+def get_address(fm: NpuFeatureMap, strides: NpuShape3D, y: int, x: int, c: int) -> int:
+ """Returns address of given coordinate"""
+ t = 0
+ BRICK = 16
+ stride_c = BRICK * fm.data_type.size_in_bytes() if fm.layout == NpuLayout.NHWC else strides.depth
+ stride_x = BRICK * fm.data_type.size_in_bytes() if fm.layout == NpuLayout.NHCWB16 else strides.width
+ if x >= fm.tiles.width_0:
+ x -= fm.tiles.width_0
+ t = 1
+ if y >= fm.tiles.height_1:
+ y -= fm.tiles.height_1
+ t += 2
+ elif y >= fm.tiles.height_0:
+ y -= fm.tiles.height_0
+ t += 2
+ elem_size = fm.data_type.size_in_bytes()
+ return (
+ fm.tiles.addresses[t] + y * strides.height + x * stride_x + (c // BRICK) * stride_c + int(c % BRICK) * elem_size
+ )
+
+
+def get_address_range(
+ fm: NpuFeatureMap, strides: NpuShape3D, y0: int, x0: int, c0: int, y1: int, x1: int, c1: int
+) -> NpuAddressRange:
+ """
+ Gets address range for (y0, x0, c0) - (y1, x1, c1) (inclusive, so the second coordinate is within the fm).
+ The begin and end coordinates must be within the same tile.
+ """
+ addr0 = get_address(fm, strides, y0, x0, c0)
+ addr1 = get_address(fm, strides, y1, x1, c1)
+ return NpuAddressRange(region=fm.region, address=addr0, length=addr1 - addr0 + fm.data_type.size_in_bytes())
+
+
+def get_h_ranges(
+ fm: NpuFeatureMap, strides: NpuShape3D, y0: int, x0: int, c0: int, y1: int, x1: int, c1: int
+) -> List[NpuAddressRange]:
+ """
+ Gets address ranges for (y0, x0, c0) - (y1, x1, c1) (inclusive, so the second coordinate is within the fm);
+ the begin and end coordinates must be within the same tile.
+ Divides the area in horizontal "stripes" of height 1, and returns the address ranges for these "stripes".
+ """
+ return [get_address_range(fm, strides, y, x0, c0, y, x1, c1) for y in range(y0, y1 + 1)]
+
+
+def get_address_ranges_for_area(fm: NpuFeatureMap, start: PointXYZ, end: PointXYZ) -> List[NpuAddressRange]:
+ """
+ Returns a list of adddress ranges that covers the area start - end (inclusive).
+ Divides the area in horizontal "stripes" of height 1, and returns the address ranges for these "stripes".
+
+ For example, for the area marked with X (in a feature map with 4 tiles) as input, this function would return
+ 6 address ranges: the address ranges for 1-height areas [AAA, BBB, CC, DD, EEE, FF]
+
+ .....|.... .....|....
+ t0 ..XXX|XX.. t1 t0 ..AAA|CC.. t1
+ ..XXX|XX.. ..BBB|DD..
+ -----+---- --> -----+----
+ t2 ..XXX|XX.. t3 t2 ..EEE|FF.. t3
+ .....|.... .....|....
+ """
+ strides = get_strides(fm)
+ height_0, height_1, width_0 = fm.tiles.height_0, fm.tiles.height_1, fm.tiles.width_0
+ h, w, c = fm.shape
+ y0, x0, c0 = start.y, start.x, start.z
+ y1, x1, c1 = min(end.y, h - 1), min(end.x, w - 1), min(end.z, c - 1)
+ ranges = []
+ if x0 < width_0 and y0 < height_0:
+ # Horizontal ranges for tile 0
+ ranges.extend(get_h_ranges(fm, strides, y0, x0, c0, min(y1, height_0 - 1), min(x1, width_0 - 1), c1))
+ if x1 >= width_0 and y0 < height_1:
+ # Horizontal ranges for tile 1
+ ranges.extend(get_h_ranges(fm, strides, y0, max(x0, width_0), c0, min(y1, height_1 - 1), x1, c1))
+ if x0 < width_0 and y1 >= height_0:
+ # Horizontal ranges for tile 2
+ ranges.extend(get_h_ranges(fm, strides, max(y0, height_0), x0, c0, y1, min(x1, width_0 - 1), c1))
+ if x1 >= width_0 and y1 >= height_1:
+ # Horizontal ranges for tile 3
+ ranges.extend(get_h_ranges(fm, strides, max(y0, height_1), max(x0, width_0), c0, y1, x1, c1))
+ return ranges
+
+
+def get_address_ranges(fm: NpuFeatureMap) -> List[Optional[NpuAddressRange]]:
+ """Returns 4 adddress ranges, one for every tile, None if the tile is not in use"""
+ strides = get_strides(fm)
+ height, width, depth = fm.shape.height, fm.shape.width, fm.shape.depth
+ height_0, height_1, width_0 = fm.tiles.height_0, fm.tiles.height_1, fm.tiles.width_0
+ t0 = get_address_range(fm, strides, 0, 0, 0, min(height, height_0) - 1, min(width, width_0) - 1, depth - 1,)
+ if width > width_0:
+ t1 = get_address_range(fm, strides, 0, width_0, 0, min(height, height_1) - 1, width - 1, depth - 1)
+ else:
+ t1 = None
+ if height > height_0:
+ t2 = get_address_range(fm, strides, height_0, 0, 0, height - 1, min(width, width_0) - 1, depth - 1)
+ else:
+ t2 = None
+ if t1 is not None and t2 is not None:
+ t3 = get_address_range(fm, strides, height_1, width_0, 0, height - 1, width - 1, depth - 1)
+ else:
+ t3 = None
+ return [t0, t1, t2, t3]
+
+
+# -------------------------------------------------------------------
+# DMA_WAIT/KERNEL_WAIT
+# -------------------------------------------------------------------
+
+
+class Watermark(NamedTuple):
+ npu: int
+ dma: int
+
+
+def memory_range_set(range: NpuAddressRange) -> MemoryRangeSet:
+ return MemoryRangeSet(range.region, range.address, range.address + range.length)
+
+
+def get_dma_memory_accesses(dma_op: NpuDmaOperation) -> MemoryAccessSet:
+ """Returns the address that are read and written by the given DMA operation"""
+ res = MemoryAccessSet()
+ res.add(memory_range_set(dma_op.src), AccessDirection.Read)
+ res.add(memory_range_set(dma_op.dest), AccessDirection.Write)
+ return res
+
+
+def get_op_memory_accesses(npu_op: NpuBlockOperation, arch: ArchitectureFeatures) -> MemoryAccessSet:
+ """Returns the addresses that are read and written by the given operation"""
+ assert npu_op.ifm is not None and npu_op.ofm is not None
+ # Read addresses
+ read_ranges = get_address_ranges(npu_op.ifm)
+ if has_ifm2(npu_op):
+ assert npu_op.ifm2 is not None
+ read_ranges.extend(get_address_ranges(npu_op.ifm2))
+ read_ranges.extend(npu_op.weights)
+ read_ranges.extend(npu_op.biases)
+ if npu_op.activation is not None and npu_op.activation.op_type == NpuActivationOp.TABLE_LOOKUP:
+ address = arch.available_shram_banks(True) * arch.shram_bank_size
+ read_ranges.append(NpuAddressRange(region=BASE_PTR_INDEX_MEM2MEM, address=address, length=2048))
+ # Written addresses
+ write_ranges = get_address_ranges(npu_op.ofm)
+ # Add write access to SHRAM, needed when LUTs can overwrite accumulator banks
+ uses_lut = npu_op.activation is not None and npu_op.activation.op_type == NpuActivationOp.TABLE_LOOKUP
+ written_shram_size = arch.available_shram_banks(uses_lut) * arch.shram_bank_size
+ write_ranges.append(NpuAddressRange(region=BASE_PTR_INDEX_MEM2MEM, address=0, length=written_shram_size))
+
+ res = MemoryAccessSet()
+ for read_range in read_ranges:
+ if read_range is not None:
+ res.add(memory_range_set(read_range), AccessDirection.Read)
+ for write_range in write_ranges:
+ if write_range is not None:
+ res.add(memory_range_set(write_range), AccessDirection.Write)
+ return res
+
+
+def get_wait_dependency(
+ arch: ArchitectureFeatures, npu_op_list: List[NpuOperation], memory_accesses, op_index: int, watermark: Watermark
+):
+ """Used to calculate whether DMA wait or kernel wait operations are needed"""
+ npu_op = npu_op_list[op_index]
+ op_access = memory_accesses[npu_op]
+ index = op_index - 1
+
+ # NPU dependency tracking
+ npu_outstanding = -1
+ npu_ops = 0
+ npu_index = watermark.npu
+
+ # DMA dependency tracking
+ dma_outstanding = -1
+ dma_ops = 0
+ dma_index = watermark.dma
+
+ # Seek back in the command stream looking for NPU or DMA dependencies
+ # but only as far as the first dependency or the watermarks (dependencies
+ # before this point have been satisfied already).
+ # The watermark moves to after the latest element we must wait for, not
+ # the command that issues the wait.
+ # NPU->NPU dependency is handled via blockdep.
+ while (index >= npu_index) or (index >= dma_index):
+ prev_op = npu_op_list[index]
+ prev_access = memory_accesses[prev_op]
+
+ # Check NPU consuming DMA output
+ if is_dma_op(prev_op):
+ if index >= dma_index:
+ if not is_dma_op(npu_op):
+ if (dma_outstanding == -1) and prev_access.conflicts(op_access):
+ dma_outstanding = dma_ops
+ dma_ops += 1 # Count DMA ops in the pipeline
+ if dma_ops >= arch.max_outstanding_dma:
+ dma_index = max(index + 1, dma_index)
+ # Check DMA consuming NPU output
+ else:
+ if index >= npu_index:
+ if is_dma_op(npu_op) and npu_outstanding == -1 and prev_access.conflicts(op_access):
+ npu_outstanding = npu_ops
+ npu_ops += 1 # Count NPU ops in the pipeline
+ if npu_ops >= arch.max_outstanding_kernels:
+ npu_index = max(index + 1, npu_index)
+
+ index -= 1
+
+ # Update DMA watermark if we didn't see any and the NPU pipeline is full
+ if (dma_ops == 0) and (npu_ops >= arch.max_outstanding_kernels):
+ dma_index = op_index
+
+ # Bring the search watermark forwards as we complete for those dependencies
+ watermark = Watermark(npu_index, dma_index)
+ outstanding = Watermark(npu_outstanding, dma_outstanding)
+
+ return watermark, outstanding
+
+
+# -------------------------------------------------------------------
+# BLOCKDEP
+# -------------------------------------------------------------------
+
+
+def get_ifm_ofm_block_depth(arch: ArchitectureFeatures, npu_op: NpuBlockOperation) -> int:
+ # Note: NOT equivalent to the normal ifm block depth calculation since
+ # it takes into account 'depthless' block operations by returning full
+ # depth
+ if npu_op.op_type == NpuOperationType.Conv2D:
+ res = arch.calc_ifm_block_depth(npu_op.ifm.shape.depth, npu_op.ifm.data_type.size_in_bits())
+ return res
+ return npu_op.ofm.shape.depth
+
+
+def coords_intersect(start_a: PointXYZ, end_a: PointXYZ, start_b: PointXYZ, end_b: PointXYZ) -> bool:
+ """Checks if the two areas overlap"""
+ start_x = max(start_a.x, start_b.x)
+ end_x = min(end_a.x, end_b.x)
+ start_y = max(start_a.y, start_b.y)
+ end_y = min(end_a.y, end_b.y)
+ start_z = max(start_a.z, start_b.z)
+ end_z = min(end_a.z, end_b.z)
+ return ((end_x - start_x) > 0) and ((end_y - start_y) > 0) and ((end_z - start_z) > 0)
+
+
+def intersects(
+ ifm: NpuFeatureMap,
+ ifm_start_coord: PointXYZ,
+ ifm_end_coord: PointXYZ,
+ prev_ofm: NpuFeatureMap,
+ ofm_start_coord: PointXYZ,
+ ofm_end_coord: PointXYZ,
+) -> bool:
+ """Checks if the given IFM area overlaps with the given OFM area"""
+ if ifm.shape == prev_ofm.shape and ifm.tiles == prev_ofm.tiles:
+ # Common case: prev_op.ofm == op.ifm; in this case it suffices to check
+ # if the xyz coordinates overlap, which is quick and easy
+ res = coords_intersect(ifm_start_coord, ifm_end_coord, ofm_start_coord, ofm_end_coord)
+ else:
+ # The OFM produces a part of the IFM (e.g. a stripe), or the IFM consumes part of the OFM.
+ # In this case, address comparison between the two areas is needed
+ ifm_ranges = get_address_ranges_for_area(ifm, ifm_start_coord, ifm_end_coord)
+ prev_ofm_ranges = get_address_ranges_for_area(prev_ofm, ofm_start_coord, ofm_end_coord)
+ res = range_lists_overlap(ifm_ranges, prev_ofm_ranges)
+ return res
+
+
+# Block job dependency:
+# Does the VOLUME of IFMs for block job B(0) overlap with VOLUME of OFMs block jobs A(8,9,10)
+#
+# A | B
+# ----------------------+------------------
+# .... 3,4,5,6,7,8,9,10 | 0,1,2,3,4,5,6,8 10 < JOB NUMBER
+# |<------->| dependency offset
+#
+
+
+def get_offset_block_coords(area: Rect, block: Block, offset: int) -> Optional[PointXYZ]:
+ """
+ Get the coordinates of a block offset from either the end (negative)
+ or the start (zero or positive) of the given 3D area
+ """
+ size = area.size()
+ # Dimensions of the region, in blocks
+ width_blocks = numeric_util.round_up_divide(size.width, block.width)
+ height_blocks = numeric_util.round_up_divide(size.height, block.height)
+ depth_blocks = numeric_util.round_up_divide(size.depth, block.depth)
+ total_blocks = width_blocks * height_blocks * depth_blocks
+ if offset < 0:
+ index = total_blocks + offset
+ else:
+ index = offset
+
+ if index >= total_blocks:
+ return None
+
+ # Coordinates of the indexed block
+ coord_z = block.depth * (index % depth_blocks)
+ coord_y = block.height * (index // (depth_blocks * width_blocks))
+ coord_x = block.width * ((index // depth_blocks) % width_blocks)
+
+ return PointXYZ(x=coord_x + area.x, y=coord_y + area.y, z=coord_z + area.z)
+
+
+def get_first_job_input_volume(
+ arch: ArchitectureFeatures,
+ ifm: Rect,
+ ofm: Rect,
+ ifm_block_depth,
+ ofm_block: Block,
+ kernel: Kernel,
+ padding: NpuPadding,
+ block_offset: int,
+):
+ # Get ifm block size (jobs are invisibly decomposed into subkernels)
+ ifm_block = arch.get_ifm_block_size(ifm_block_depth, ofm_block, kernel, arch.ofm_block_max)
+ ifm_depth_blocks = numeric_util.round_up_divide(ifm.size().depth, ifm_block_depth)
+
+ # Which OFM block are we calculating
+ ofm_coord = get_offset_block_coords(ofm, ofm_block, block_offset // ifm_depth_blocks)
+ if ofm_coord is None:
+ return None
+
+ # Coordinate of the source IFM block
+ ifm_coord_x = max(0, ofm_coord[0] * kernel.stride.x - padding.left)
+ ifm_coord_y = max(0, ofm_coord[1] * kernel.stride.y - padding.right)
+ ifm_coord_z = ifm.z + (block_offset % ifm_depth_blocks) * ifm_block.depth
+
+ # IFM block that will be sampled for the FIRST+block_offset job in the next operator's OFM
+ start_coord = PointXYZ(x=ifm_coord_x, y=ifm_coord_y, z=ifm_coord_z)
+ end_coord = PointXYZ(
+ x=start_coord[0] + ifm_block.width, y=start_coord[1] + ifm_block.height, z=start_coord[2] + ifm_block.depth,
+ )
+ return (start_coord, end_coord, 1) # start, end, total jobs
+
+
+def get_prev_job_output_volume(ofm: Rect, ofm_block: Block, block_offset: int):
+ assert block_offset >= 0
+
+ # Get OFM block's volume coordinates
+ start_coord = get_offset_block_coords(ofm, ofm_block, -1 - block_offset)
+ if start_coord is None:
+ return None
+ end_coord = PointXYZ(
+ x=start_coord.x + ofm_block.width, y=start_coord.y + ofm_block.height, z=start_coord.z + ofm_block.depth,
+ )
+ return (start_coord, end_coord, 1) # start, end, total jobs for this OFM block
+
+
+def calc_blockdep(arch: ArchitectureFeatures, prev_op: Optional[NpuBlockOperation], npu_op: NpuBlockOperation,) -> int:
+ """Calculates the value of the BLOCKDEP register"""
+ if prev_op is None:
+ return 0
+ assert npu_op.ifm is not None
+ assert prev_op.ofm is not None
+ # Check if IFM or IFM2 overlaps with prev op's OFM
+ prev_ofm_ranges = get_address_ranges(prev_op.ofm)
+ ifm_ranges = get_address_ranges(npu_op.ifm)
+ ifm_overlaps = range_lists_overlap(prev_ofm_ranges, ifm_ranges)
+ if has_ifm2(npu_op):
+ assert npu_op.ifm2 is not None
+ ifm2_ranges = get_address_ranges(npu_op.ifm2)
+ ifm2_overlaps = range_lists_overlap(prev_ofm_ranges, ifm2_ranges)
+ else:
+ ifm2_overlaps = False
+ if ifm_overlaps and ifm2_overlaps:
+ # Both IFM and IFM2 overlap (should be rare)
+ return 0
+ if not ifm_overlaps and not ifm2_overlaps:
+ # No overlap between prev OFM and IFM/IFM2
+ return ArchitectureFeatures.MAX_BLOCKDEP
+ if ifm2_overlaps and shape3d_size(npu_op.ifm2.shape) < shape3d_size(npu_op.ifm.shape):
+ # Prev OFM produces IFM2 which is broadcasted (this should be rare)
+ return 0
+ # Prev OFM overlaps with IFM or IFM2; calculate the blockdep
+ prev_block_config = prev_op.block_config
+ block_config = npu_op.block_config
+ overlapping_fm = npu_op.ifm if ifm_overlaps else npu_op.ifm2
+ assert overlapping_fm is not None
+
+ cur_ifm_block_depth = get_ifm_ofm_block_depth(arch, npu_op)
+ cur_ofm_block = Block(block_config.width, block_config.height, block_config.depth)
+ cur_ofm_rect = shape3d_to_rect(npu_op.ofm.shape)
+ cur_ifm_rect = shape3d_to_rect(npu_op.ifm.shape)
+ padding = NpuPadding(0, 0, 0, 0) if npu_op.padding is None else npu_op.padding
+ blockdep = ArchitectureFeatures.MAX_BLOCKDEP
+ kernel = to_kernel(npu_op.kernel)
+
+ prev_ofm_block = Block(prev_block_config.width, prev_block_config.height, prev_block_config.depth)
+ prev_ofm_rect = shape3d_to_rect(prev_op.ofm.shape)
+ # Iterate over the next BLOCKDEP inputs, checking to see if a sliding window
+ # of IFM area overlaps with any previous OFM block generation.
+ elapsed_jobs = 0
+ for forward_offset in range(ArchitectureFeatures.MAX_BLOCKDEP):
+ # This is the IFM block we want to sample from
+ in_area = get_first_job_input_volume(
+ arch, cur_ifm_rect, cur_ofm_rect, cur_ifm_block_depth, cur_ofm_block, kernel, padding, forward_offset
+ )
+ if in_area is None:
+ break
+
+ # Try several previous-OFM blocks in the past (they still might comprise multiple IFM jobs)
+ outstanding_jobs = 0
+ for block_offset in range(ArchitectureFeatures.MAX_BLOCKDEP):
+ # This is the OFM block being generated by the previous op
+ out_area = get_prev_job_output_volume(prev_ofm_rect, prev_ofm_block, block_offset)
+ if out_area is None:
+ break
+
+ # Block dependency is the max number of allowed outstanding jobs
+ # in the pipeline. Selected by determining how many jobs occur
+ # in between two operators' overlapping OFM->IFM block volumes
+ if intersects(overlapping_fm, in_area[0], in_area[1], prev_op.ofm, out_area[0], out_area[1]):
+ break
+ # Early exit if no intersections and we've seen enough jobs in the pipeline
+ elif outstanding_jobs > ArchitectureFeatures.MAX_BLOCKDEP:
+ break
+
+ # This OFM had this many jobs (accumulate over multiple OFM blocks)
+ outstanding_jobs += out_area[2]
+
+ blockdep = min(blockdep, elapsed_jobs + outstanding_jobs)
+ elapsed_jobs += in_area[2]
+ # Early exit if no intersections and we've seen enough jobs in the pipeline
+ if elapsed_jobs > ArchitectureFeatures.MAX_BLOCKDEP:
+ break
+
+ return blockdep
diff --git a/ethosu/vela/shared_buffer_allocation.py b/ethosu/vela/shared_buffer_allocation.py
index ee559625..21b048bc 100644
--- a/ethosu/vela/shared_buffer_allocation.py
+++ b/ethosu/vela/shared_buffer_allocation.py
@@ -28,10 +28,10 @@ from .architecture_features import SharedBufferArea
from .architecture_features import SHRAMElements
from .errors import VelaError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
-from .high_level_command_to_npu_op import to_kernel
from .operation import Kernel
from .operation import NpuBlockType
from .range_set import MemoryRangeSet
+from .register_command_stream_util import to_kernel
from .tensor import MemArea
diff --git a/ethosu/vela/test/extapi/test_extapi_generate_commands.py b/ethosu/vela/test/extapi/test_extapi_generate_commands.py
index 812991a9..b605dfc5 100644
--- a/ethosu/vela/test/extapi/test_extapi_generate_commands.py
+++ b/ethosu/vela/test/extapi/test_extapi_generate_commands.py
@@ -41,7 +41,7 @@ from ethosu.vela.api import NpuTileBox
from ethosu.vela.ethos_u55_regs.ethos_u55_regs import cmd0
from ethosu.vela.ethos_u55_regs.ethos_u55_regs import cmd1
from ethosu.vela.register_command_stream_generator import CmdMode
-from ethosu.vela.register_command_stream_generator import get_address_ranges
+from ethosu.vela.register_command_stream_util import get_address_ranges
def check_cmd0(cmd_stream, cmd, param):
diff --git a/ethosu/vela/test/test_register_command_generator.py b/ethosu/vela/test/test_register_command_stream_util.py
index 2760c860..985523fa 100644
--- a/ethosu/vela/test/test_register_command_generator.py
+++ b/ethosu/vela/test/test_register_command_stream_util.py
@@ -32,8 +32,8 @@ from ethosu.vela.api import NpuTileBox
from ethosu.vela.architecture_features import Accelerator
from ethosu.vela.architecture_features import create_default_arch
from ethosu.vela.register_command_stream_generator import calc_blockdep
-from ethosu.vela.register_command_stream_generator import get_address_ranges
from ethosu.vela.register_command_stream_generator import get_strides
+from ethosu.vela.register_command_stream_util import get_address_ranges
from ethosu.vela.test.extapi.test_extapi_generate_commands import create_feature_map