author     Dwight Lidman <dwight.lidman@arm.com>           2020-12-08 17:56:44 +0100
committer  patrik.gustavsson <patrik.gustavsson@arm.com>   2020-12-14 07:00:24 +0000
commit     9b43f846b144d39bfb0cf16853bf6901c74b6672 (patch)
tree       a530dce790bb8e54dad009e11ca4d49d54b52b1d
parent     94457b175b8646bce089c9924e99686587de8992 (diff)
MLBEDSW-3653: Fix type errors in annotated files
This commit corrects a number of type errors reported by mypy
and refactors away parts of the code that are no longer necessary
after the adjustments made to satisfy mypy.
Signed-off-by: Dwight Lidman <dwight.lidman@arm.com>
Change-Id: I16b880b228e57f2a92fb8936f53e94886e0f9f44
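
The central pattern in this change is replacing tag-enum dispatch with isinstance() dispatch: mypy cannot narrow a type from a comparison such as cmd.cmdtype == CommandType.DMA, but it can narrow from isinstance(cmd, DMA), after which attribute access on the subclass type-checks without casts. A minimal sketch of the pattern, with illustrative class names rather than the actual vela API:

    from typing import List


    class Command:
        """Base class for high-level commands (illustrative, not the vela API)."""


    class DMA(Command):
        def __init__(self, src: str, dest: str):
            self.src = src
            self.dest = dest


    class NpuStripe(Command):
        def __init__(self, block_config: List[int]):
            self.block_config = block_config


    def describe(cmd: Command) -> str:
        # isinstance() narrows cmd to the subclass, so the attribute access in
        # each branch type-checks; a tag test like cmd.cmdtype == CommandType.DMA
        # gives mypy no such information.
        if isinstance(cmd, DMA):
            return f"DMA {cmd.src} -> {cmd.dest}"
        if isinstance(cmd, NpuStripe):
            return f"NpuStripe, block config {cmd.block_config}"
        raise TypeError(f"unknown command {cmd!r}")


    print(describe(DMA("weights", "shram")))  # prints: DMA weights -> shram

This is also why the CommandType enum and the is_dma_op() helper are deleted in the diff below: once the checks are isinstance() based, the tag fields carry no extra information.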
-rw-r--r--  ethosu/vela/__init__.py                           |  3
-rw-r--r--  ethosu/vela/data_type.py                          | 29
-rw-r--r--  ethosu/vela/debug_database.py                     | 21
-rw-r--r--  ethosu/vela/driver_actions.py                     |  2
-rw-r--r--  ethosu/vela/high_level_command_stream.py          | 12
-rw-r--r--  ethosu/vela/high_level_command_to_npu_op.py       | 17
-rw-r--r--  ethosu/vela/lut.py                                |  7
-rw-r--r--  ethosu/vela/operation.py                          | 13
-rw-r--r--  ethosu/vela/register_command_stream_generator.py  | 92
-rw-r--r--  ethosu/vela/register_command_stream_util.py       | 11
10 files changed, 120 insertions, 87 deletions
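
A second recurring fix is annotating attributes that mypy cannot otherwise see. In data_type.py the concrete type singletons (DataType.int8 and friends) are attached to the class only after its definition, so the diff adds a block of class-level Any annotations to declare them. A stripped-down sketch of that pattern, using two attributes and placeholder values instead of vela's 27 real ones:

    from typing import Any


    class DataType:
        __slots__ = "type", "bits"

        # Bare annotations create no class attribute and store nothing in the
        # class dict, so they coexist with __slots__; they only tell mypy that
        # DataType.int8 etc. exist, silencing "has no attribute" errors.
        int8: Any
        uint8: Any

        def __init__(self, type_, bits):
            self.type = type_
            self.bits = bits


    # The real module builds these after the class body, which is why the
    # annotations are needed; the values here are placeholders.
    DataType.int8 = DataType(0, 8)
    DataType.uint8 = DataType(1, 8)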
diff --git a/ethosu/vela/__init__.py b/ethosu/vela/__init__.py
index 90376be3..77c171d0 100644
--- a/ethosu/vela/__init__.py
+++ b/ethosu/vela/__init__.py
@@ -14,6 +14,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from ._version import __version__
-from .vela import main
-__all__ = [main, __version__]
+__all__ = ["main", __version__]
 
diff --git a/ethosu/vela/data_type.py b/ethosu/vela/data_type.py
index a4b7b537..3ad642ad 100644
--- a/ethosu/vela/data_type.py
+++ b/ethosu/vela/data_type.py
@@ -16,6 +16,7 @@
 # Description:
 # Defines the basic numeric type classes for tensors.
 import enum
+from typing import Any
 
 from .numeric_util import round_up_divide
 
@@ -43,6 +44,34 @@ class DataType:
 
     __slots__ = "type", "bits"
 
+    int8: Any
+    int16: Any
+    int32: Any
+    int64: Any
+    uint8: Any
+    uint16: Any
+    uint32: Any
+    uint64: Any
+    quint4: Any
+    quint8: Any
+    quint12: Any
+    quint16: Any
+    quint32: Any
+    qint4: Any
+    qint8: Any
+    qint12: Any
+    qint16: Any
+    qint32: Any
+    float16: Any
+    float32: Any
+    float64: Any
+    string: Any
+    bool: Any
+    resource: Any
+    variant: Any
+    complex64: Any
+    complex128: Any
+
     def __init__(self, type_, bits):
         self.type = type_
         self.bits = bits
diff --git a/ethosu/vela/debug_database.py b/ethosu/vela/debug_database.py
index b5852cdc..4f0a50ae 100644
--- a/ethosu/vela/debug_database.py
+++ b/ethosu/vela/debug_database.py
@@ -15,6 +15,9 @@
 # limitations under the License.
 import csv
 import io
+from typing import Any
+from typing import Dict
+from typing import List
 
 import lxml.etree as xml
 
@@ -22,28 +25,32 @@ from . import numeric_util
 from .operation import Operation
 
 
+UntypedDict = Dict[Any, Any]
+UntypedList = List[Any]
+
+
 class DebugDatabase:
     NULLREF = -1
     show_warnings = False
 
     SOURCE_TABLE = "source"
-    _sourceUID = {}
+    _sourceUID: UntypedDict = {}
     _sourceHeaders = ["id", "operator", "kernel_w", "kernel_h", "ofm_w", "ofm_h", "ofm_d"]
-    _sourceTable = []
+    _sourceTable: UntypedList = []
 
     OPTIMISED_TABLE = "optimised"
-    _optimisedUID = {}
+    _optimisedUID: UntypedDict = {}
    _optimisedHeaders = ["id", "source_id", "operator", "kernel_w", "kernel_h", "ofm_w", "ofm_h", "ofm_d"]
-    _optimisedTable = []
+    _optimisedTable: UntypedList = []
 
     QUEUE_TABLE = "queue"
     _queueHeaders = ["offset", "cmdstream_id", "optimised_id"]
-    _queueTable = []
+    _queueTable: UntypedList = []
 
     STREAM_TABLE = "cmdstream"
-    _streamUID = {}
+    _streamUID: UntypedDict = {}
     _streamHeaders = ["id", "file_offset"]
-    _streamTable = []
+    _streamTable: UntypedList = []
 
     @classmethod
     def add_source(cls, op: Operation):
diff --git a/ethosu/vela/driver_actions.py b/ethosu/vela/driver_actions.py
index 86bed110..5a85df06 100644
--- a/ethosu/vela/driver_actions.py
+++ b/ethosu/vela/driver_actions.py
@@ -117,7 +117,7 @@ def create_driver_payload(register_command_stream: List[int], arch: Architecture
     """Creates driver header and includes the given command
     """
     # Prepare driver actions for this command tensor
-    da_list = []
+    da_list: List[int] = []
     emit_fourcc(da_list, "COP1")
     emit_config(da_list, 0, 1, arch)
     emit_cmd_stream_header(da_list, len(register_command_stream))
diff --git a/ethosu/vela/high_level_command_stream.py b/ethosu/vela/high_level_command_stream.py
index d057d17e..c45bc4e5 100644
--- a/ethosu/vela/high_level_command_stream.py
+++ b/ethosu/vela/high_level_command_stream.py
@@ -15,8 +15,6 @@
 # limitations under the License.
 # Description:
 # Contains classes that hold commands for the high-level command stream (one command per DMA or NPU stripe).
-from enum import IntEnum
-
 import numpy as np
 
 from .architecture_features import Block
 
@@ -144,12 +142,6 @@ class Box:
     __repr__ = __str__
 
 
-class CommandType(IntEnum):
-    NpuStripe = 0
-    DMA = 1
-    Size = 2
-
-
 class Command:
     def get_ofm_y_range_for_pass(self, ps_requested):
         return None
@@ -158,7 +150,7 @@ class Command:
         return False
 
     def get_operation_count(self):
-        # returns numpy array of (DPU blocks, dma_ops). Should line up with the CommandType enum
+        # returns numpy array of (DPU blocks, dma_ops).
         return np.array((0, 0))
 
@@ -185,7 +177,6 @@ class NpuStripe(Command):
         pad_top=0,
         pad_bottom=0,
     ):
-        self.cmdtype = CommandType.NpuStripe
         self.ps = ps
         self.block_config = block_config
         self.is_first = is_first
@@ -333,7 +324,6 @@ class NpuStripe(Command):
 
 class DMA(Command):
     def __init__(self, ps, in_tensor, out_tensor, box):
-        self.cmdtype = CommandType.DMA
         self.ps = ps
         self.in_tensor = in_tensor
         self.out_tensor = out_tensor
diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py
index 7db4931d..9e0ed010 100644
--- a/ethosu/vela/high_level_command_to_npu_op.py
+++ b/ethosu/vela/high_level_command_to_npu_op.py
@@ -48,7 +48,6 @@ from .data_type import DataType
 from .debug_database import DebugDatabase
 from .high_level_command_stream import Box
 from .high_level_command_stream import Command
-from .high_level_command_stream import CommandType
 from .high_level_command_stream import DMA
 from .high_level_command_stream import NpuStripe
 from .operation import NpuBlockType
@@ -56,7 +55,6 @@ from .operation import Op
 from .operation import Operation
 from .register_command_stream_generator import generate_command_stream
 from .register_command_stream_util import BASE_PTR_INDEX_MEM2MEM
-from .register_command_stream_util import is_dma_op
 from .register_command_stream_util import to_npu_kernel
 from .register_command_stream_util import UNARY_ELEMWISE_OPS
 from .tensor import MemType
@@ -168,7 +166,7 @@ def get_region(tens: Tensor, arch: ArchitectureFeatures) -> int:
     else:
         base_ptr_idx_map[MemType.Scratch_fast] = BasePointerIndex.ScratchTensor
 
-    return int(base_ptr_idx_map[tens.mem_type])
+    return base_ptr_idx_map[tens.mem_type].value
 
 
 def get_upscale(op: Operation) -> NpuResamplingMode:
@@ -461,9 +459,10 @@ def create_dma_op(cmd: DMA, arch: ArchitectureFeatures) -> NpuDmaOperation:
 
 def convert_command_to_npu_op(cmd: Command, arch: ArchitectureFeatures) -> NpuOperation:
     """Converts the high level command to NpuOperation"""
-    if cmd.cmdtype == CommandType.DMA:
+    npu_op: NpuOperation
+    if isinstance(cmd, DMA):
         npu_op = create_dma_op(cmd, arch)
-    elif cmd.cmdtype == CommandType.NpuStripe:
+    elif isinstance(cmd, NpuStripe):
         npu_block_type = cmd.ps.primary_op.type.npu_block_type
         if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct):
             npu_op = create_npu_conv2d_op(cmd, arch)
@@ -475,8 +474,6 @@ def convert_command_to_npu_op(cmd: Command, arch: ArchitectureFeatures) -> NpuOp
             npu_op = create_npu_elementwise_op(cmd, arch)
         else:
             assert 0, f"Unknown command type {npu_block_type}"
-    # add a link to the high level command for debugging purposes
-    npu_op.cmd = cmd
     return npu_op
 
@@ -486,7 +483,7 @@ def generate_register_command_stream_for_sg(nng, sg, arch, verbose=False):
     npu_op_list = []
     npu_op_to_cmd = dict()  # map from npu op to high level command
     for cmd in sg.high_level_command_stream:
-        if cmd.cmdtype == CommandType.NpuStripe and cmd.ps.npu_block_type == NpuBlockType.Default:
+        if isinstance(cmd, NpuStripe) and cmd.ps.npu_block_type == NpuBlockType.Default:
             print("Warning: Skipping register command stream generation for", cmd.ps)
         else:
             npu_op = convert_command_to_npu_op(cmd, arch)
@@ -498,8 +495,8 @@ def generate_register_command_stream_for_sg(nng, sg, arch, verbose=False):
 
     def add_to_debug_db(npu_op: NpuOperation, offset: int):
         """Adds info to the debug database"""
-        if not is_dma_op(npu_op):
+        if not isinstance(npu_op, NpuDmaOperation):
             cmd = npu_op_to_cmd[npu_op]
             DebugDatabase.add_command(stream_id, offset, cmd.ps.primary_op)
 
-    sg.register_command_stream = generate_command_stream(npu_op_list, arch, verbose, add_to_debug_db)
+    sg.register_command_stream = generate_command_stream(npu_op_list, arch, verbose, add_to_debug_db, npu_op_to_cmd)
diff --git a/ethosu/vela/lut.py b/ethosu/vela/lut.py
index 8e28b953..8a23b51d 100644
--- a/ethosu/vela/lut.py
+++ b/ethosu/vela/lut.py
@@ -20,7 +20,8 @@ import uuid
 import numpy as np
 
 from . import numeric_util
-from .high_level_command_stream import CommandType
+from .high_level_command_stream import DMA
+from .high_level_command_stream import NpuStripe
 from .tensor import create_const_tensor
 from .tensor import create_equivalence_id
 from .tensor import TensorPurpose
@@ -101,11 +102,11 @@ def optimize_high_level_cmd_stream(sg, arch):
     lut_start = arch.shram_lut_address
     lut_end = lut_start + arch.shram_lut_size
     for cmd in sg.high_level_command_stream:
-        if cmd.cmdtype == CommandType.NpuStripe and cmd.ps.lut_tensor is None and arch.shram_reserved_unused_banks == 0:
+        if isinstance(cmd, NpuStripe) and cmd.ps.lut_tensor is None and arch.shram_reserved_unused_banks == 0:
             # The command overwrites the last 2 banks containing the LUT; next LUT operation will require DMA
             # TODO: check the command's SHRAM usage in more detail to determine if the LUT is overwritten or not
             lut_state = LUTState()
-        if cmd.cmdtype != CommandType.DMA or cmd.out_tensor.purpose != TensorPurpose.LUT:
+        if not isinstance(cmd, DMA) or cmd.out_tensor.purpose != TensorPurpose.LUT:
             # Non-LUT operation; leave untouched
             cmd_stream.append(cmd)
             continue
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 45fae217..32cba365 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -18,10 +18,17 @@ import copy
 from collections import namedtuple
 from enum import Enum
+from typing import Any
+from typing import Dict
+from typing import List
 from typing import Optional
+from typing import TYPE_CHECKING
 
 from .numeric_util import full_shape
 
+if TYPE_CHECKING:
+    from .tensor import Tensor
+
 PointXY = namedtuple("PointXY", "x y")
 PointXYZ = namedtuple("PointXYZ", "x y z")
 
@@ -392,9 +399,9 @@ class Operation:
     def __init__(self, op_type: Op, name: str):
         self.type = op_type
         self.name = name
-        self.attrs = {}
-        self.inputs = []
-        self.outputs = []
+        self.attrs: Dict[str, Any] = {}
+        self.inputs: List[Tensor] = []
+        self.outputs: List[Tensor] = []
         self.flops = 0
         self.run_on_npu = True
         # Fused activation function. If not none: operator code.
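
The operation.py hunk above combines two further mypy staples: importing Tensor only under TYPE_CHECKING, because the import is needed solely for annotations and would otherwise be circular at runtime, and annotating empty containers, whose element types mypy cannot infer from the literal. A self-contained sketch, with tensor_module standing in for vela's .tensor module:

    from typing import Any, Dict, List, TYPE_CHECKING

    if TYPE_CHECKING:
        # Evaluated only by the type checker, never at runtime, so a circular
        # (or, in this sketch, nonexistent) module import is harmless.
        from tensor_module import Tensor  # hypothetical module for this sketch


    class Operation:
        def __init__(self, op_type: int, name: str):
            self.type = op_type
            self.name = name
            # Bare {} and [] give mypy nothing to infer element types from;
            # the annotations pin them down for later insertions.
            self.attrs: Dict[str, Any] = {}
            self.inputs: List["Tensor"] = []  # string form defers evaluation
            self.outputs: List["Tensor"] = []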
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index d4947b1a..fa56d353 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -20,6 +20,7 @@ from collections import defaultdict
 from enum import Enum
 from enum import IntEnum
+from typing import Dict
 from typing import List
 from typing import Optional
 
@@ -33,6 +34,7 @@ from .api import NpuAddressRange
 from .api import NpuBlockOperation
 from .api import NpuBlockTraversal
 from .api import NpuConv2DOperation
+from .api import NpuConvDepthWiseOperation
 from .api import NpuDataType
 from .api import NpuDmaOperation
 from .api import NpuElementWiseOp
@@ -68,13 +70,13 @@ from .numeric_util import quantise_float32
 from .numeric_util import round_away_zero
 from .numeric_util import round_up_to_int
 from .operation import NpuBlockType
+from .range_set import MemoryAccessSet
 from .register_command_stream_util import calc_blockdep
 from .register_command_stream_util import get_dma_memory_accesses
 from .register_command_stream_util import get_op_memory_accesses
 from .register_command_stream_util import get_strides
 from .register_command_stream_util import get_wait_dependency
 from .register_command_stream_util import has_ifm2
-from .register_command_stream_util import is_dma_op
 from .register_command_stream_util import to_kernel
 from .register_command_stream_util import UNARY_ELEMWISE_OPS
 from .register_command_stream_util import Watermark
@@ -549,15 +551,13 @@ def generate_shram_registers_non_elementwise(emit: CommandStreamEmitter, shared_
 
 def create_shared_buffer(npu_op: NpuBlockOperation, arch: ArchitectureFeatures) -> SharedBufferAllocation:
     """Creates shared buffer allocation for the given operation"""
-    op_type = npu_op.op_type
-    block_type = NpuBlockType.Default
-    if op_type == NpuOperationType.Conv2D:
+    if isinstance(npu_op, NpuConv2DOperation):
         block_type = NpuBlockType.ConvolutionMxN
-    elif op_type == NpuOperationType.ConvDepthWise:
+    elif isinstance(npu_op, NpuConvDepthWiseOperation):
         block_type = NpuBlockType.ConvolutionDepthWise
-    elif op_type == NpuOperationType.Pooling:
+    elif isinstance(npu_op, NpuPoolingOperation):
         block_type = NpuBlockType.ReduceSum if npu_op.sub_op_type == NpuPoolingOp.REDUCE_SUM else NpuBlockType.Pooling
-    elif op_type == NpuOperationType.ElementWise:
+    elif isinstance(npu_op, NpuElementWiseOperation):
         block_type = NpuBlockType.ElementWise
     else:
         assert 0, "Unsupported operation"
@@ -599,7 +599,7 @@ def generate_common(
     generate_activation(emit, npu_op.activation, npu_op.ofm)
     shared_buffer = create_shared_buffer(npu_op, arch)
     generate_block_config(emit, npu_op, arch, shared_buffer)
-    if npu_op.op_type == NpuOperationType.ElementWise:
+    if isinstance(npu_op, NpuElementWiseOperation):
         generate_shram_registers_elementwise(emit, npu_op, arch, shared_buffer)
     else:
         generate_shram_registers_non_elementwise(emit, shared_buffer)
@@ -746,17 +746,20 @@ def print_feature_map(fm: NpuFeatureMap, name: str):
         print(f"        {stride_str}, tiles: w0={t.width_0}, h0={t.height_0}, h1={t.height_1}, base={addresses}")
 
 
-def print_operation(npu_op: NpuOperation, index: int = 0):
-    pass_info = f", {npu_op.cmd}" if hasattr(npu_op, "cmd") else ""
-    if is_dma_op(npu_op):
+def print_operation(npu_op: NpuOperation, index: int = 0, cmd=None):
+    pass_info = f", {cmd}" if cmd else ""
+    if isinstance(npu_op, NpuOperation) and not isinstance(npu_op, (NpuDmaOperation, NpuBlockOperation)):
+        print(f"{index} {npu_op.op_type.name}{pass_info}")
+        return
+    if isinstance(npu_op, NpuDmaOperation):
         print(f"{index} DMA_START src={npu_op.src}, dest={npu_op.dest}{pass_info}")
         return
     k = None if npu_op.kernel is None else to_kernel(npu_op.kernel)
-    if npu_op.op_type in (NpuOperationType.Pooling, NpuOperationType.ElementWise):
+    if isinstance(npu_op, (NpuPoolingOperation, NpuElementWiseOperation)):
         print(f"{index} {npu_op.sub_op_type.name} {npu_op.op_type.name}:{pass_info}")
     else:
         if (
-            npu_op.op_type == NpuOperationType.Conv2D
+            isinstance(npu_op, NpuConv2DOperation)
             and k.elements_wh() * k.stride.x * k.stride.y * k.dilation.x * k.dilation.y == 1
         ):
             fc = "FullyConnected "
@@ -783,16 +786,19 @@ def print_operation(npu_op: NpuOperation, index: int = 0):
     if act.op_type != NpuActivationOp.NONE_OR_RELU or act.min is not None or act.max is not None:
         lut = f", lut index={act.lookup_table_index}" if act.op_type == NpuActivationOp.TABLE_LOOKUP else ""
         print(f"    Activation: {act.op_type.name}, min={act.min}, max={act.max}{lut}")
-    if npu_op.op_type == NpuOperationType.Conv2D:
+    if isinstance(npu_op, NpuConv2DOperation):
         print(f"    {npu_op.block_traversal}")
     bh, bw, bc = npu_op.block_config
-    rescale = f", rescale={npu_op.rescale}" if hasattr(npu_op, "rescale") else ""
+    rescale = (
+        f", rescale={npu_op.rescale}" if isinstance(npu_op, (NpuPoolingOperation, NpuElementWiseOperation)) else ""
+    )
     print(f"    Block config: h={bh},w={bw},c={bc}, {npu_op.ifm_upscale}, {npu_op.rounding_mode}{rescale}")
 
 
-def print_operations(npu_op_list: List[NpuOperation]):
+def print_operations(npu_op_list: List[NpuOperation], npu_op_to_cmd=None):
+    npu_op_to_cmd = dict() if npu_op_to_cmd is None else npu_op_to_cmd
     for index, npu_op in enumerate(npu_op_list):
-        print_operation(npu_op, index)
+        print_operation(npu_op, index, npu_op_to_cmd.get(npu_op))
 
 
 # -------------------------------------------------------------------
@@ -802,16 +808,15 @@ def print_operations(npu_op_list: List[NpuOperation]):
 
 def generate_operation_code(emit: CommandStreamEmitter, npu_op: NpuOperation):
     """Generates NPU_OP_* command"""
-    op_type = npu_op.op_type
-    if op_type == NpuOperationType.Dma:
+    if isinstance(npu_op, NpuDmaOperation):
         emit.cmd_do_operation(cmd0.NPU_OP_DMA_START, npu_op.channel * 16 + npu_op.mode)
-    elif op_type == NpuOperationType.Conv2D:
+    elif isinstance(npu_op, NpuConv2DOperation):
         emit.cmd_do_operation(cmd0.NPU_OP_CONV)
-    elif op_type == NpuOperationType.ConvDepthWise:
+    elif isinstance(npu_op, NpuConvDepthWiseOperation):
         emit.cmd_do_operation(cmd0.NPU_OP_DEPTHWISE)
-    elif op_type == NpuOperationType.Pooling:
+    elif isinstance(npu_op, NpuPoolingOperation):
         emit.cmd_do_operation(cmd0.NPU_OP_POOL, param=pooling_op_map[npu_op.sub_op_type])
-    elif op_type == NpuOperationType.ElementWise:
+    elif isinstance(npu_op, NpuElementWiseOperation):
         emit.cmd_do_operation(cmd0.NPU_OP_ELEMENTWISE, param=elementwise_op_map[npu_op.sub_op_type])
     else:
         assert 0, "Unsupported operation"
@@ -822,7 +827,9 @@ def generate_conv2d_op(emit: CommandStreamEmitter, npu_op: NpuConv2DOperation, a
     generate_common(emit, npu_op, npu_op.block_traversal, arch)
 
 
-def generate_conv_depthwise_op(emit: CommandStreamEmitter, npu_op: NpuPoolingOperation, arch: ArchitectureFeatures):
+def generate_conv_depthwise_op(
+    emit: CommandStreamEmitter, npu_op: NpuConvDepthWiseOperation, arch: ArchitectureFeatures
+):
     """Generates register commands for depthwise convolution operations"""
     generate_common(emit, npu_op, NpuBlockTraversal.DEPTH_FIRST, arch)
 
@@ -880,23 +887,22 @@ def generate_registers_for_op(emit: CommandStreamEmitter, npu_op: NpuOperation,
     Generates register commands for the given operation, but not the final NPU_OP_... command.
     Returns the selected block config
     """
-    op_type = npu_op.op_type
-    if op_type == NpuOperationType.Conv2D:
+    if isinstance(npu_op, NpuConv2DOperation):
         generate_conv2d_op(emit, npu_op, arch)
-    elif op_type == NpuOperationType.ConvDepthWise:
+    elif isinstance(npu_op, NpuConvDepthWiseOperation):
         generate_conv_depthwise_op(emit, npu_op, arch)
-    elif op_type == NpuOperationType.Pooling:
+    elif isinstance(npu_op, NpuPoolingOperation):
         generate_pooling_op(emit, npu_op, arch)
-    elif op_type == NpuOperationType.ElementWise:
+    elif isinstance(npu_op, NpuElementWiseOperation):
         generate_elementwise_op(emit, npu_op, arch)
-    elif op_type == NpuOperationType.Dma:
+    elif isinstance(npu_op, NpuDmaOperation):
         generate_dma_op(emit, npu_op)
     else:
         assert 0, "Unsupported operation"
 
 
 def generate_command_stream(
-    npu_op_list: List[NpuOperation], arch: ArchitectureFeatures, verbose: bool, add_to_debug_db=None,
+    npu_op_list: List[NpuOperation], arch: ArchitectureFeatures, verbose: bool, add_to_debug_db=None, npu_op_to_cmd=None
 ) -> List[int]:
     """
     Generates register commands for the given list of NPU operations.
@@ -904,14 +910,16 @@ def generate_command_stream(
     """
     emit = CommandStreamEmitter()
     if verbose:
-        print_operations(npu_op_list)
+        print_operations(npu_op_list, npu_op_to_cmd)
     # Calculate memory accesses for every operation
-    memory_accesses = {}
+    memory_accesses: Dict[NpuOperation, MemoryAccessSet] = {}
     for npu_op in npu_op_list:
-        if is_dma_op(npu_op):
+        if isinstance(npu_op, NpuDmaOperation):
             memory_accesses[npu_op] = get_dma_memory_accesses(npu_op)
-        else:
+        elif isinstance(npu_op, NpuBlockOperation):
             memory_accesses[npu_op] = get_op_memory_accesses(npu_op, arch)
+        else:
+            assert 0, "Invalid operation type"
     if arch.is_ethos_u65_system:
         emit.cmd0_with_param(cmd0.NPU_SET_PARALLEL_MODE, arch.ncores - 1)
     dep_watermark = Watermark(0, 0)
@@ -920,7 +928,7 @@ def generate_command_stream(
     for op_index, npu_op in enumerate(npu_op_list):
         dep_watermark, cmd_waits = get_wait_dependency(arch, npu_op_list, memory_accesses, op_index, dep_watermark)
         generate_registers_for_op(emit, npu_op, arch)
-        if not is_dma_op(npu_op):
+        if not isinstance(npu_op, NpuDmaOperation) and isinstance(npu_op, NpuBlockOperation):
             # Generate BLOCKDEP
             blockdep = calc_blockdep(arch, prev_op, npu_op)
             blockdep = min(blockdep, arch.max_blockdep)
@@ -951,12 +959,12 @@ def find_block_configs(npu_op: NpuOperation, npu_accelerator: NpuAccelerator) ->
     """
     Internal implementation of the public facing API for finding block configs.
     """
-    if is_dma_op(npu_op):
-        return []
-    arch = create_default_arch(Accelerator.from_npu_accelerator(npu_accelerator))
-    shared_buffer = create_shared_buffer(npu_op, arch)
-    blocks = find_suitable_block_configs(arch, shared_buffer)
-    return [NpuShape3D(height=block[0], width=block[1], depth=block[3]) for block in blocks]
+    if isinstance(npu_op, NpuBlockOperation):
+        arch = create_default_arch(Accelerator.from_npu_accelerator(npu_accelerator))
+        shared_buffer = create_shared_buffer(npu_op, arch)
+        blocks = find_suitable_block_configs(arch, shared_buffer)
+        return [NpuShape3D(height=block[0], width=block[1], depth=block[3]) for block in blocks]
+    return []
 
 
 def generate_register_command_stream(npu_op_list: List[NpuOperation], npu_accelerator: NpuAccelerator) -> List[int]:
diff --git a/ethosu/vela/register_command_stream_util.py b/ethosu/vela/register_command_stream_util.py
index ce49fc29..55fa620c 100644
--- a/ethosu/vela/register_command_stream_util.py
+++ b/ethosu/vela/register_command_stream_util.py
@@ -68,11 +68,6 @@ def has_ifm2(npu_op: NpuBlockOperation) -> bool:
     return npu_op.ifm2 is not None and npu_op.ifm2_scalar is None
 
 
-def is_dma_op(npu_op: NpuOperation) -> bool:
-    """Checks if op is a DMA operation"""
-    return npu_op.op_type == NpuOperationType.Dma
-
-
 def shape3d_size(shape: NpuShape3D) -> int:
     return shape.width * shape.height * shape.depth
 
@@ -302,9 +297,9 @@ def get_wait_dependency(
         prev_access = memory_accesses[prev_op]
 
         # Check NPU consuming DMA output
-        if is_dma_op(prev_op):
+        if isinstance(prev_op, NpuDmaOperation):
             if index >= dma_index:
-                if not is_dma_op(npu_op):
+                if not isinstance(npu_op, NpuDmaOperation):
                     if (dma_outstanding == -1) and prev_access.conflicts(op_access):
                         dma_outstanding = dma_ops
             dma_ops += 1  # Count DMA ops in the pipeline
@@ -313,7 +308,7 @@ def get_wait_dependency(
         # Check DMA consuming NPU output
         else:
             if index >= npu_index:
-                if is_dma_op(npu_op) and npu_outstanding == -1 and prev_access.conflicts(op_access):
+                if isinstance(npu_op, NpuDmaOperation) and npu_outstanding == -1 and prev_access.conflicts(op_access):
                     npu_outstanding = npu_ops
                 npu_ops += 1  # Count NPU ops in the pipeline
             if npu_ops >= arch.max_outstanding_kernels: