author     Dwight Lidman <dwight.lidman@arm.com>          2020-12-08 17:56:44 +0100
committer  patrik.gustavsson <patrik.gustavsson@arm.com>  2020-12-14 07:00:24 +0000
commit     9b43f846b144d39bfb0cf16853bf6901c74b6672 (patch)
tree       a530dce790bb8e54dad009e11ca4d49d54b52b1d
parent     94457b175b8646bce089c9924e99686587de8992 (diff)
download   ethos-u-vela-9b43f846b144d39bfb0cf16853bf6901c74b6672.tar.gz
MLBEDSW-3653: Fix type errors in annotated files
This commit corrects a number of type errors reported by mypy and refactors some parts of the code which are no longer necessary after making adjustments to satisfy mypy.

Signed-off-by: Dwight Lidman <dwight.lidman@arm.com>
Change-Id: I16b880b228e57f2a92fb8936f53e94886e0f9f44
-rw-r--r--  ethosu/vela/__init__.py                             3
-rw-r--r--  ethosu/vela/data_type.py                           29
-rw-r--r--  ethosu/vela/debug_database.py                      21
-rw-r--r--  ethosu/vela/driver_actions.py                       2
-rw-r--r--  ethosu/vela/high_level_command_stream.py           12
-rw-r--r--  ethosu/vela/high_level_command_to_npu_op.py        17
-rw-r--r--  ethosu/vela/lut.py                                  7
-rw-r--r--  ethosu/vela/operation.py                           13
-rw-r--r--  ethosu/vela/register_command_stream_generator.py   92
-rw-r--r--  ethosu/vela/register_command_stream_util.py        11
10 files changed, 120 insertions, 87 deletions
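
The fix repeated most often across these files is adding an explicit annotation where mypy cannot infer a type, typically for a container that starts out empty (the da_list change in driver_actions.py, the class-level tables in debug_database.py). A minimal, self-contained sketch of that pattern; the names below are illustrative, not taken from the diff:

from typing import Any, Dict, List

# mypy cannot infer an element type from an empty literal, so without the
# annotations it reports an error along the lines of "Need type annotation".
words: List[int] = []          # will collect 32-bit command words
tables: Dict[str, Any] = {}    # values deliberately left untyped

words.append(0x31504F43)       # example payload word
tables["source"] = words
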
diff --git a/ethosu/vela/__init__.py b/ethosu/vela/__init__.py
index 90376be3..77c171d0 100644
--- a/ethosu/vela/__init__.py
+++ b/ethosu/vela/__init__.py
@@ -14,6 +14,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ._version import __version__
-from .vela import main
-__all__ = [main, __version__]
+__all__ = ["main", __version__]
diff --git a/ethosu/vela/data_type.py b/ethosu/vela/data_type.py
index a4b7b537..3ad642ad 100644
--- a/ethosu/vela/data_type.py
+++ b/ethosu/vela/data_type.py
@@ -16,6 +16,7 @@
# Description:
# Defines the basic numeric type classes for tensors.
import enum
+from typing import Any
from .numeric_util import round_up_divide
@@ -43,6 +44,34 @@ class DataType:
__slots__ = "type", "bits"
+ int8: Any
+ int16: Any
+ int32: Any
+ int64: Any
+ uint8: Any
+ uint16: Any
+ uint32: Any
+ uint64: Any
+ quint4: Any
+ quint8: Any
+ quint12: Any
+ quint16: Any
+ quint32: Any
+ qint4: Any
+ qint8: Any
+ qint12: Any
+ qint16: Any
+ qint32: Any
+ float16: Any
+ float32: Any
+ float64: Any
+ string: Any
+ bool: Any
+ resource: Any
+ variant: Any
+ complex64: Any
+ complex128: Any
+
def __init__(self, type_, bits):
self.type = type_
self.bits = bits
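
The block of int8/uint8/... declarations typed as Any is there presumably because the concrete DataType members (DataType.int8 and friends) are attached to the class only after its body has executed, so mypy would otherwise reject accesses such as DataType.int8. A minimal sketch of the same pattern with an invented Colour class:

from typing import Any

class Colour:
    # Declared here so a type checker knows the attributes exist; the
    # actual instances are only attached after the class is defined.
    red: Any
    blue: Any

    def __init__(self, name: str):
        self.name = name

Colour.red = Colour("red")
Colour.blue = Colour("blue")
print(Colour.red.name, Colour.blue.name)
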
diff --git a/ethosu/vela/debug_database.py b/ethosu/vela/debug_database.py
index b5852cdc..4f0a50ae 100644
--- a/ethosu/vela/debug_database.py
+++ b/ethosu/vela/debug_database.py
@@ -15,6 +15,9 @@
# limitations under the License.
import csv
import io
+from typing import Any
+from typing import Dict
+from typing import List
import lxml.etree as xml
@@ -22,28 +25,32 @@ from . import numeric_util
from .operation import Operation
+UntypedDict = Dict[Any, Any]
+UntypedList = List[Any]
+
+
class DebugDatabase:
NULLREF = -1
show_warnings = False
SOURCE_TABLE = "source"
- _sourceUID = {}
+ _sourceUID: UntypedDict = {}
_sourceHeaders = ["id", "operator", "kernel_w", "kernel_h", "ofm_w", "ofm_h", "ofm_d"]
- _sourceTable = []
+ _sourceTable: UntypedList = []
OPTIMISED_TABLE = "optimised"
- _optimisedUID = {}
+ _optimisedUID: UntypedDict = {}
_optimisedHeaders = ["id", "source_id", "operator", "kernel_w", "kernel_h", "ofm_w", "ofm_h", "ofm_d"]
- _optimisedTable = []
+ _optimisedTable: UntypedList = []
QUEUE_TABLE = "queue"
_queueHeaders = ["offset", "cmdstream_id", "optimised_id"]
- _queueTable = []
+ _queueTable: UntypedList = []
STREAM_TABLE = "cmdstream"
- _streamUID = {}
+ _streamUID: UntypedDict = {}
_streamHeaders = ["id", "file_offset"]
- _streamTable = []
+ _streamTable: UntypedList = []
@classmethod
def add_source(cls, op: Operation):
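
UntypedDict and UntypedList are plain type aliases; they keep the class-level annotations short while stating explicitly that the key and value types are not pinned down. A small sketch of the alias-plus-class-variable pattern (Registry and its fields are made up for illustration):

from typing import Any, Dict, List

UntypedDict = Dict[Any, Any]
UntypedList = List[Any]

class Registry:
    # Class-level tables, annotated once via the aliases.
    _uid: UntypedDict = {}
    _rows: UntypedList = []

    @classmethod
    def add(cls, key: Any, row: Any) -> int:
        cls._uid[key] = len(cls._rows)
        cls._rows.append(row)
        return cls._uid[key]

print(Registry.add("conv1", ("Conv2D", 3, 3)))
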
diff --git a/ethosu/vela/driver_actions.py b/ethosu/vela/driver_actions.py
index 86bed110..5a85df06 100644
--- a/ethosu/vela/driver_actions.py
+++ b/ethosu/vela/driver_actions.py
@@ -117,7 +117,7 @@ def create_driver_payload(register_command_stream: List[int], arch: Architecture
"""Creates driver header and includes the given command
"""
# Prepare driver actions for this command tensor
- da_list = []
+ da_list: List[int] = []
emit_fourcc(da_list, "COP1")
emit_config(da_list, 0, 1, arch)
emit_cmd_stream_header(da_list, len(register_command_stream))
diff --git a/ethosu/vela/high_level_command_stream.py b/ethosu/vela/high_level_command_stream.py
index d057d17e..c45bc4e5 100644
--- a/ethosu/vela/high_level_command_stream.py
+++ b/ethosu/vela/high_level_command_stream.py
@@ -15,8 +15,6 @@
# limitations under the License.
# Description:
# Contains classes that hold commands for the high-level command stream (one command per DMA or NPU stripe).
-from enum import IntEnum
-
import numpy as np
from .architecture_features import Block
@@ -144,12 +142,6 @@ class Box:
__repr__ = __str__
-class CommandType(IntEnum):
- NpuStripe = 0
- DMA = 1
- Size = 2
-
-
class Command:
def get_ofm_y_range_for_pass(self, ps_requested):
return None
@@ -158,7 +150,7 @@ class Command:
return False
def get_operation_count(self):
- # returns numpy array of (DPU blocks, dma_ops). Should line up with the CommandType enum
+ # returns numpy array of (DPU blocks, dma_ops).
return np.array((0, 0))
@@ -185,7 +177,6 @@ class NpuStripe(Command):
pad_top=0,
pad_bottom=0,
):
- self.cmdtype = CommandType.NpuStripe
self.ps = ps
self.block_config = block_config
self.is_first = is_first
@@ -333,7 +324,6 @@ class NpuStripe(Command):
class DMA(Command):
def __init__(self, ps, in_tensor, out_tensor, box):
- self.cmdtype = CommandType.DMA
self.ps = ps
self.in_tensor = in_tensor
self.out_tensor = out_tensor
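
With the CommandType enum and the cmdtype field removed, the concrete subclass itself acts as the tag and callers switch on isinstance instead; mypy also uses the isinstance check to narrow the type, so subclass-only attributes can be accessed without complaint. A minimal sketch of that dispatch, using simplified stand-ins for Command, DMA and NpuStripe:

class Command:
    pass

class DMA(Command):
    def __init__(self, in_tensor: str, out_tensor: str):
        self.in_tensor = in_tensor
        self.out_tensor = out_tensor

class NpuStripe(Command):
    def __init__(self, block_config: tuple):
        self.block_config = block_config

def describe(cmd: Command) -> str:
    # Each isinstance check narrows cmd, so the subclass-only fields below
    # type-check; an enum tag would select the branch but not narrow.
    if isinstance(cmd, DMA):
        return f"DMA {cmd.in_tensor} -> {cmd.out_tensor}"
    if isinstance(cmd, NpuStripe):
        return f"NpuStripe block_config={cmd.block_config}"
    raise AssertionError(f"Unknown command {cmd}")

print(describe(DMA("weights", "sram")))
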
diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py
index 7db4931d..9e0ed010 100644
--- a/ethosu/vela/high_level_command_to_npu_op.py
+++ b/ethosu/vela/high_level_command_to_npu_op.py
@@ -48,7 +48,6 @@ from .data_type import DataType
from .debug_database import DebugDatabase
from .high_level_command_stream import Box
from .high_level_command_stream import Command
-from .high_level_command_stream import CommandType
from .high_level_command_stream import DMA
from .high_level_command_stream import NpuStripe
from .operation import NpuBlockType
@@ -56,7 +55,6 @@ from .operation import Op
from .operation import Operation
from .register_command_stream_generator import generate_command_stream
from .register_command_stream_util import BASE_PTR_INDEX_MEM2MEM
-from .register_command_stream_util import is_dma_op
from .register_command_stream_util import to_npu_kernel
from .register_command_stream_util import UNARY_ELEMWISE_OPS
from .tensor import MemType
@@ -168,7 +166,7 @@ def get_region(tens: Tensor, arch: ArchitectureFeatures) -> int:
else:
base_ptr_idx_map[MemType.Scratch_fast] = BasePointerIndex.ScratchTensor
- return int(base_ptr_idx_map[tens.mem_type])
+ return base_ptr_idx_map[tens.mem_type].value
def get_upscale(op: Operation) -> NpuResamplingMode:
@@ -461,9 +459,10 @@ def create_dma_op(cmd: DMA, arch: ArchitectureFeatures) -> NpuDmaOperation:
def convert_command_to_npu_op(cmd: Command, arch: ArchitectureFeatures) -> NpuOperation:
"""Converts the high level command to NpuOperation"""
- if cmd.cmdtype == CommandType.DMA:
+ npu_op: NpuOperation
+ if isinstance(cmd, DMA):
npu_op = create_dma_op(cmd, arch)
- elif cmd.cmdtype == CommandType.NpuStripe:
+ elif isinstance(cmd, NpuStripe):
npu_block_type = cmd.ps.primary_op.type.npu_block_type
if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct):
npu_op = create_npu_conv2d_op(cmd, arch)
@@ -475,8 +474,6 @@ def convert_command_to_npu_op(cmd: Command, arch: ArchitectureFeatures) -> NpuOp
npu_op = create_npu_elementwise_op(cmd, arch)
else:
assert 0, f"Unknown command type {npu_block_type}"
- # add a link to the high level command for debugging purposes
- npu_op.cmd = cmd
return npu_op
@@ -486,7 +483,7 @@ def generate_register_command_stream_for_sg(nng, sg, arch, verbose=False):
npu_op_list = []
npu_op_to_cmd = dict() # map from npu op to high level command
for cmd in sg.high_level_command_stream:
- if cmd.cmdtype == CommandType.NpuStripe and cmd.ps.npu_block_type == NpuBlockType.Default:
+ if isinstance(cmd, NpuStripe) and cmd.ps.npu_block_type == NpuBlockType.Default:
print("Warning: Skipping register command stream generation for", cmd.ps)
else:
npu_op = convert_command_to_npu_op(cmd, arch)
@@ -498,8 +495,8 @@ def generate_register_command_stream_for_sg(nng, sg, arch, verbose=False):
def add_to_debug_db(npu_op: NpuOperation, offset: int):
"""Adds info to the debug database"""
- if not is_dma_op(npu_op):
+ if not isinstance(npu_op, NpuDmaOperation):
cmd = npu_op_to_cmd[npu_op]
DebugDatabase.add_command(stream_id, offset, cmd.ps.primary_op)
- sg.register_command_stream = generate_command_stream(npu_op_list, arch, verbose, add_to_debug_db)
+ sg.register_command_stream = generate_command_stream(npu_op_list, arch, verbose, add_to_debug_db, npu_op_to_cmd)
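
Dropping the npu_op.cmd assignment removes a dynamically added attribute that the NpuOperation types never declare, the kind of assignment mypy flags; the debug link now lives in the npu_op_to_cmd mapping that is threaded through to generate_command_stream. A minimal sketch of keeping such a side mapping, with simplified placeholder types:

from typing import Dict, List

class HighLevelCommand:
    def __init__(self, name: str):
        self.name = name

class NpuOperation:
    pass

def convert(cmd: HighLevelCommand) -> NpuOperation:
    return NpuOperation()  # placeholder for the real lowering step

def lower(cmds: List[HighLevelCommand]) -> Dict[NpuOperation, HighLevelCommand]:
    op_to_cmd: Dict[NpuOperation, HighLevelCommand] = {}
    for cmd in cmds:
        op = convert(cmd)
        op_to_cmd[op] = cmd  # debug link kept outside the op object
    return op_to_cmd

mapping = lower([HighLevelCommand("stripe0"), HighLevelCommand("dma0")])
for op, cmd in mapping.items():
    print(type(op).__name__, "<-", cmd.name)
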
diff --git a/ethosu/vela/lut.py b/ethosu/vela/lut.py
index 8e28b953..8a23b51d 100644
--- a/ethosu/vela/lut.py
+++ b/ethosu/vela/lut.py
@@ -20,7 +20,8 @@ import uuid
import numpy as np
from . import numeric_util
-from .high_level_command_stream import CommandType
+from .high_level_command_stream import DMA
+from .high_level_command_stream import NpuStripe
from .tensor import create_const_tensor
from .tensor import create_equivalence_id
from .tensor import TensorPurpose
@@ -101,11 +102,11 @@ def optimize_high_level_cmd_stream(sg, arch):
lut_start = arch.shram_lut_address
lut_end = lut_start + arch.shram_lut_size
for cmd in sg.high_level_command_stream:
- if cmd.cmdtype == CommandType.NpuStripe and cmd.ps.lut_tensor is None and arch.shram_reserved_unused_banks == 0:
+ if isinstance(cmd, NpuStripe) and cmd.ps.lut_tensor is None and arch.shram_reserved_unused_banks == 0:
# The command overwrites the last 2 banks containing the LUT; next LUT operation will require DMA
# TODO: check the command's SHRAM usage in more detail to determine if the LUT is overwritten or not
lut_state = LUTState()
- if cmd.cmdtype != CommandType.DMA or cmd.out_tensor.purpose != TensorPurpose.LUT:
+ if not isinstance(cmd, DMA) or cmd.out_tensor.purpose != TensorPurpose.LUT:
# Non-LUT operation; leave untouched
cmd_stream.append(cmd)
continue
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 45fae217..32cba365 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -18,10 +18,17 @@
import copy
from collections import namedtuple
from enum import Enum
+from typing import Any
+from typing import Dict
+from typing import List
from typing import Optional
+from typing import TYPE_CHECKING
from .numeric_util import full_shape
+if TYPE_CHECKING:
+ from .tensor import Tensor
+
PointXY = namedtuple("PointXY", "x y")
PointXYZ = namedtuple("PointXYZ", "x y z")
@@ -392,9 +399,9 @@ class Operation:
def __init__(self, op_type: Op, name: str):
self.type = op_type
self.name = name
- self.attrs = {}
- self.inputs = []
- self.outputs = []
+ self.attrs: Dict[str, Any] = {}
+ self.inputs: List[Tensor] = []
+ self.outputs: List[Tensor] = []
self.flops = 0
self.run_on_npu = True
# Fused activation function. If not none: operator code.
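
The TYPE_CHECKING guard lets operation.py use Tensor in annotations without importing tensor.py at runtime, the usual way to annotate across modules that would otherwise import each other in a cycle. A minimal sketch of the pattern; here the annotations are quoted so they are never evaluated at runtime:

from typing import List, TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers, never at runtime, so it
    # cannot contribute to an import cycle.
    from .tensor import Tensor

class Operation:
    def __init__(self, name: str):
        self.name = name
        self.inputs: List["Tensor"] = []   # quoted: resolved only by the checker
        self.outputs: List["Tensor"] = []
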
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index d4947b1a..fa56d353 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -20,6 +20,7 @@
from collections import defaultdict
from enum import Enum
from enum import IntEnum
+from typing import Dict
from typing import List
from typing import Optional
@@ -33,6 +34,7 @@ from .api import NpuAddressRange
from .api import NpuBlockOperation
from .api import NpuBlockTraversal
from .api import NpuConv2DOperation
+from .api import NpuConvDepthWiseOperation
from .api import NpuDataType
from .api import NpuDmaOperation
from .api import NpuElementWiseOp
@@ -68,13 +70,13 @@ from .numeric_util import quantise_float32
from .numeric_util import round_away_zero
from .numeric_util import round_up_to_int
from .operation import NpuBlockType
+from .range_set import MemoryAccessSet
from .register_command_stream_util import calc_blockdep
from .register_command_stream_util import get_dma_memory_accesses
from .register_command_stream_util import get_op_memory_accesses
from .register_command_stream_util import get_strides
from .register_command_stream_util import get_wait_dependency
from .register_command_stream_util import has_ifm2
-from .register_command_stream_util import is_dma_op
from .register_command_stream_util import to_kernel
from .register_command_stream_util import UNARY_ELEMWISE_OPS
from .register_command_stream_util import Watermark
@@ -549,15 +551,13 @@ def generate_shram_registers_non_elementwise(emit: CommandStreamEmitter, shared_
def create_shared_buffer(npu_op: NpuBlockOperation, arch: ArchitectureFeatures) -> SharedBufferAllocation:
"""Creates shared buffer allocation for the given operation"""
- op_type = npu_op.op_type
- block_type = NpuBlockType.Default
- if op_type == NpuOperationType.Conv2D:
+ if isinstance(npu_op, NpuConv2DOperation):
block_type = NpuBlockType.ConvolutionMxN
- elif op_type == NpuOperationType.ConvDepthWise:
+ elif isinstance(npu_op, NpuConvDepthWiseOperation):
block_type = NpuBlockType.ConvolutionDepthWise
- elif op_type == NpuOperationType.Pooling:
+ elif isinstance(npu_op, NpuPoolingOperation):
block_type = NpuBlockType.ReduceSum if npu_op.sub_op_type == NpuPoolingOp.REDUCE_SUM else NpuBlockType.Pooling
- elif op_type == NpuOperationType.ElementWise:
+ elif isinstance(npu_op, NpuElementWiseOperation):
block_type = NpuBlockType.ElementWise
else:
assert 0, "Unsupported operation"
@@ -599,7 +599,7 @@ def generate_common(
generate_activation(emit, npu_op.activation, npu_op.ofm)
shared_buffer = create_shared_buffer(npu_op, arch)
generate_block_config(emit, npu_op, arch, shared_buffer)
- if npu_op.op_type == NpuOperationType.ElementWise:
+ if isinstance(npu_op, NpuElementWiseOperation):
generate_shram_registers_elementwise(emit, npu_op, arch, shared_buffer)
else:
generate_shram_registers_non_elementwise(emit, shared_buffer)
@@ -746,17 +746,20 @@ def print_feature_map(fm: NpuFeatureMap, name: str):
print(f" {stride_str}, tiles: w0={t.width_0}, h0={t.height_0}, h1={t.height_1}, base={addresses}")
-def print_operation(npu_op: NpuOperation, index: int = 0):
- pass_info = f", {npu_op.cmd}" if hasattr(npu_op, "cmd") else ""
- if is_dma_op(npu_op):
+def print_operation(npu_op: NpuOperation, index: int = 0, cmd=None):
+ pass_info = f", {cmd}" if cmd else ""
+ if isinstance(npu_op, NpuOperation) and not isinstance(npu_op, (NpuDmaOperation, NpuBlockOperation)):
+ print(f"{index} {npu_op.op_type.name}{pass_info}")
+ return
+ if isinstance(npu_op, NpuDmaOperation):
print(f"{index} DMA_START src={npu_op.src}, dest={npu_op.dest}{pass_info}")
return
k = None if npu_op.kernel is None else to_kernel(npu_op.kernel)
- if npu_op.op_type in (NpuOperationType.Pooling, NpuOperationType.ElementWise):
+ if isinstance(npu_op, (NpuPoolingOperation, NpuElementWiseOperation)):
print(f"{index} {npu_op.sub_op_type.name} {npu_op.op_type.name}:{pass_info}")
else:
if (
- npu_op.op_type == NpuOperationType.Conv2D
+ isinstance(npu_op, NpuConv2DOperation)
and k.elements_wh() * k.stride.x * k.stride.y * k.dilation.x * k.dilation.y == 1
):
fc = "FullyConnected "
@@ -783,16 +786,19 @@ def print_operation(npu_op: NpuOperation, index: int = 0):
if act.op_type != NpuActivationOp.NONE_OR_RELU or act.min is not None or act.max is not None:
lut = f", lut index={act.lookup_table_index}" if act.op_type == NpuActivationOp.TABLE_LOOKUP else ""
print(f" Activation: {act.op_type.name}, min={act.min}, max={act.max}{lut}")
- if npu_op.op_type == NpuOperationType.Conv2D:
+ if isinstance(npu_op, NpuConv2DOperation):
print(f" {npu_op.block_traversal}")
bh, bw, bc = npu_op.block_config
- rescale = f", rescale={npu_op.rescale}" if hasattr(npu_op, "rescale") else ""
+ rescale = (
+ f", rescale={npu_op.rescale}" if isinstance(npu_op, (NpuPoolingOperation, NpuElementWiseOperation)) else ""
+ )
print(f" Block config: h={bh},w={bw},c={bc}, {npu_op.ifm_upscale}, {npu_op.rounding_mode}{rescale}")
-def print_operations(npu_op_list: List[NpuOperation]):
+def print_operations(npu_op_list: List[NpuOperation], npu_op_to_cmd=None):
+ npu_op_to_cmd = dict() if npu_op_to_cmd is None else npu_op_to_cmd
for index, npu_op in enumerate(npu_op_list):
- print_operation(npu_op, index)
+ print_operation(npu_op, index, npu_op_to_cmd.get(npu_op))
# -------------------------------------------------------------------
@@ -802,16 +808,15 @@ def print_operations(npu_op_list: List[NpuOperation]):
def generate_operation_code(emit: CommandStreamEmitter, npu_op: NpuOperation):
"""Generates NPU_OP_* command"""
- op_type = npu_op.op_type
- if op_type == NpuOperationType.Dma:
+ if isinstance(npu_op, NpuDmaOperation):
emit.cmd_do_operation(cmd0.NPU_OP_DMA_START, npu_op.channel * 16 + npu_op.mode)
- elif op_type == NpuOperationType.Conv2D:
+ elif isinstance(npu_op, NpuConv2DOperation):
emit.cmd_do_operation(cmd0.NPU_OP_CONV)
- elif op_type == NpuOperationType.ConvDepthWise:
+ elif isinstance(npu_op, NpuConvDepthWiseOperation):
emit.cmd_do_operation(cmd0.NPU_OP_DEPTHWISE)
- elif op_type == NpuOperationType.Pooling:
+ elif isinstance(npu_op, NpuPoolingOperation):
emit.cmd_do_operation(cmd0.NPU_OP_POOL, param=pooling_op_map[npu_op.sub_op_type])
- elif op_type == NpuOperationType.ElementWise:
+ elif isinstance(npu_op, NpuElementWiseOperation):
emit.cmd_do_operation(cmd0.NPU_OP_ELEMENTWISE, param=elementwise_op_map[npu_op.sub_op_type])
else:
assert 0, "Unsupported operation"
@@ -822,7 +827,9 @@ def generate_conv2d_op(emit: CommandStreamEmitter, npu_op: NpuConv2DOperation, a
generate_common(emit, npu_op, npu_op.block_traversal, arch)
-def generate_conv_depthwise_op(emit: CommandStreamEmitter, npu_op: NpuPoolingOperation, arch: ArchitectureFeatures):
+def generate_conv_depthwise_op(
+ emit: CommandStreamEmitter, npu_op: NpuConvDepthWiseOperation, arch: ArchitectureFeatures
+):
"""Generates register commands for depthwise convolution operations"""
generate_common(emit, npu_op, NpuBlockTraversal.DEPTH_FIRST, arch)
@@ -880,23 +887,22 @@ def generate_registers_for_op(emit: CommandStreamEmitter, npu_op: NpuOperation,
Generates register commands for the given operation, but not the final NPU_OP_... command.
Returns the selected block config
"""
- op_type = npu_op.op_type
- if op_type == NpuOperationType.Conv2D:
+ if isinstance(npu_op, NpuConv2DOperation):
generate_conv2d_op(emit, npu_op, arch)
- elif op_type == NpuOperationType.ConvDepthWise:
+ elif isinstance(npu_op, NpuConvDepthWiseOperation):
generate_conv_depthwise_op(emit, npu_op, arch)
- elif op_type == NpuOperationType.Pooling:
+ elif isinstance(npu_op, NpuPoolingOperation):
generate_pooling_op(emit, npu_op, arch)
- elif op_type == NpuOperationType.ElementWise:
+ elif isinstance(npu_op, NpuElementWiseOperation):
generate_elementwise_op(emit, npu_op, arch)
- elif op_type == NpuOperationType.Dma:
+ elif isinstance(npu_op, NpuDmaOperation):
generate_dma_op(emit, npu_op)
else:
assert 0, "Unsupported operation"
def generate_command_stream(
- npu_op_list: List[NpuOperation], arch: ArchitectureFeatures, verbose: bool, add_to_debug_db=None,
+ npu_op_list: List[NpuOperation], arch: ArchitectureFeatures, verbose: bool, add_to_debug_db=None, npu_op_to_cmd=None
) -> List[int]:
"""
Generates register commands for the given list of NPU operations.
@@ -904,14 +910,16 @@ def generate_command_stream(
"""
emit = CommandStreamEmitter()
if verbose:
- print_operations(npu_op_list)
+ print_operations(npu_op_list, npu_op_to_cmd)
# Calculate memory accesses for every operation
- memory_accesses = {}
+ memory_accesses: Dict[NpuOperation, MemoryAccessSet] = {}
for npu_op in npu_op_list:
- if is_dma_op(npu_op):
+ if isinstance(npu_op, NpuDmaOperation):
memory_accesses[npu_op] = get_dma_memory_accesses(npu_op)
- else:
+ elif isinstance(npu_op, NpuBlockOperation):
memory_accesses[npu_op] = get_op_memory_accesses(npu_op, arch)
+ else:
+ assert 0, "Invalid operation type"
if arch.is_ethos_u65_system:
emit.cmd0_with_param(cmd0.NPU_SET_PARALLEL_MODE, arch.ncores - 1)
dep_watermark = Watermark(0, 0)
@@ -920,7 +928,7 @@ def generate_command_stream(
for op_index, npu_op in enumerate(npu_op_list):
dep_watermark, cmd_waits = get_wait_dependency(arch, npu_op_list, memory_accesses, op_index, dep_watermark)
generate_registers_for_op(emit, npu_op, arch)
- if not is_dma_op(npu_op):
+ if not isinstance(npu_op, NpuDmaOperation) and isinstance(npu_op, NpuBlockOperation):
# Generate BLOCKDEP
blockdep = calc_blockdep(arch, prev_op, npu_op)
blockdep = min(blockdep, arch.max_blockdep)
@@ -951,12 +959,12 @@ def find_block_configs(npu_op: NpuOperation, npu_accelerator: NpuAccelerator) ->
"""
Internal implementation of the public facing API for finding block configs.
"""
- if is_dma_op(npu_op):
- return []
- arch = create_default_arch(Accelerator.from_npu_accelerator(npu_accelerator))
- shared_buffer = create_shared_buffer(npu_op, arch)
- blocks = find_suitable_block_configs(arch, shared_buffer)
- return [NpuShape3D(height=block[0], width=block[1], depth=block[3]) for block in blocks]
+ if isinstance(npu_op, NpuBlockOperation):
+ arch = create_default_arch(Accelerator.from_npu_accelerator(npu_accelerator))
+ shared_buffer = create_shared_buffer(npu_op, arch)
+ blocks = find_suitable_block_configs(arch, shared_buffer)
+ return [NpuShape3D(height=block[0], width=block[1], depth=block[3]) for block in blocks]
+ return []
def generate_register_command_stream(npu_op_list: List[NpuOperation], npu_accelerator: NpuAccelerator) -> List[int]:
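
Both print_operations and generate_command_stream take the new npu_op_to_cmd mapping as an optional argument defaulting to None and substitute an empty dict inside the function, the usual way to avoid sharing a mutable default between calls. A small sketch of that idiom with illustrative names:

from typing import Dict, List, Optional

def print_report(values: List[int], labels: Optional[Dict[int, str]] = None) -> None:
    # A fresh dict per call; a mutable default like labels={} would be
    # shared across every invocation of the function.
    labels = {} if labels is None else labels
    for index, value in enumerate(values):
        print(index, labels.get(index, "<unlabelled>"), value)

print_report([3, 5])
print_report([3, 5], labels={0: "kernel_w", 1: "kernel_h"})
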
diff --git a/ethosu/vela/register_command_stream_util.py b/ethosu/vela/register_command_stream_util.py
index ce49fc29..55fa620c 100644
--- a/ethosu/vela/register_command_stream_util.py
+++ b/ethosu/vela/register_command_stream_util.py
@@ -68,11 +68,6 @@ def has_ifm2(npu_op: NpuBlockOperation) -> bool:
return npu_op.ifm2 is not None and npu_op.ifm2_scalar is None
-def is_dma_op(npu_op: NpuOperation) -> bool:
- """Checks if op is a DMA operation"""
- return npu_op.op_type == NpuOperationType.Dma
-
-
def shape3d_size(shape: NpuShape3D) -> int:
return shape.width * shape.height * shape.depth
@@ -302,9 +297,9 @@ def get_wait_dependency(
prev_access = memory_accesses[prev_op]
# Check NPU consuming DMA output
- if is_dma_op(prev_op):
+ if isinstance(prev_op, NpuDmaOperation):
if index >= dma_index:
- if not is_dma_op(npu_op):
+ if not isinstance(npu_op, NpuDmaOperation):
if (dma_outstanding == -1) and prev_access.conflicts(op_access):
dma_outstanding = dma_ops
dma_ops += 1 # Count DMA ops in the pipeline
@@ -313,7 +308,7 @@ def get_wait_dependency(
# Check DMA consuming NPU output
else:
if index >= npu_index:
- if is_dma_op(npu_op) and npu_outstanding == -1 and prev_access.conflicts(op_access):
+ if isinstance(npu_op, NpuDmaOperation) and npu_outstanding == -1 and prev_access.conflicts(op_access):
npu_outstanding = npu_ops
npu_ops += 1 # Count NPU ops in the pipeline
if npu_ops >= arch.max_outstanding_kernels: