diff options
author | Dwight Lidman <dwight.lidman@arm.com> | 2020-12-08 17:56:44 +0100 |
---|---|---|
committer | patrik.gustavsson <patrik.gustavsson@arm.com> | 2020-12-14 07:00:24 +0000 |
commit | 9b43f846b144d39bfb0cf16853bf6901c74b6672 (patch) | |
tree | a530dce790bb8e54dad009e11ca4d49d54b52b1d /ethosu/vela/register_command_stream_generator.py | |
parent | 94457b175b8646bce089c9924e99686587de8992 (diff) | |
download | ethos-u-vela-9b43f846b144d39bfb0cf16853bf6901c74b6672.tar.gz |
MLBEDSW-3653: Fix type errors in annotated files
This commit corrects a number of type errors
reported by mypy and refactors some parts of
the code which are no longer necessary after
making adjustments to satisfy mypy.
Signed-off-by: Dwight Lidman <dwight.lidman@arm.com>
Change-Id: I16b880b228e57f2a92fb8936f53e94886e0f9f44
Diffstat (limited to 'ethosu/vela/register_command_stream_generator.py')
-rw-r--r-- | ethosu/vela/register_command_stream_generator.py | 92 |
1 files changed, 50 insertions, 42 deletions
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py index d4947b1a..fa56d353 100644 --- a/ethosu/vela/register_command_stream_generator.py +++ b/ethosu/vela/register_command_stream_generator.py @@ -20,6 +20,7 @@ from collections import defaultdict from enum import Enum from enum import IntEnum +from typing import Dict from typing import List from typing import Optional @@ -33,6 +34,7 @@ from .api import NpuAddressRange from .api import NpuBlockOperation from .api import NpuBlockTraversal from .api import NpuConv2DOperation +from .api import NpuConvDepthWiseOperation from .api import NpuDataType from .api import NpuDmaOperation from .api import NpuElementWiseOp @@ -68,13 +70,13 @@ from .numeric_util import quantise_float32 from .numeric_util import round_away_zero from .numeric_util import round_up_to_int from .operation import NpuBlockType +from .range_set import MemoryAccessSet from .register_command_stream_util import calc_blockdep from .register_command_stream_util import get_dma_memory_accesses from .register_command_stream_util import get_op_memory_accesses from .register_command_stream_util import get_strides from .register_command_stream_util import get_wait_dependency from .register_command_stream_util import has_ifm2 -from .register_command_stream_util import is_dma_op from .register_command_stream_util import to_kernel from .register_command_stream_util import UNARY_ELEMWISE_OPS from .register_command_stream_util import Watermark @@ -549,15 +551,13 @@ def generate_shram_registers_non_elementwise(emit: CommandStreamEmitter, shared_ def create_shared_buffer(npu_op: NpuBlockOperation, arch: ArchitectureFeatures) -> SharedBufferAllocation: """Creates shared buffer allocation for the given operation""" - op_type = npu_op.op_type - block_type = NpuBlockType.Default - if op_type == NpuOperationType.Conv2D: + if isinstance(npu_op, NpuConv2DOperation): block_type = NpuBlockType.ConvolutionMxN - elif op_type == NpuOperationType.ConvDepthWise: + elif isinstance(npu_op, NpuConvDepthWiseOperation): block_type = NpuBlockType.ConvolutionDepthWise - elif op_type == NpuOperationType.Pooling: + elif isinstance(npu_op, NpuPoolingOperation): block_type = NpuBlockType.ReduceSum if npu_op.sub_op_type == NpuPoolingOp.REDUCE_SUM else NpuBlockType.Pooling - elif op_type == NpuOperationType.ElementWise: + elif isinstance(npu_op, NpuElementWiseOperation): block_type = NpuBlockType.ElementWise else: assert 0, "Unsupported operation" @@ -599,7 +599,7 @@ def generate_common( generate_activation(emit, npu_op.activation, npu_op.ofm) shared_buffer = create_shared_buffer(npu_op, arch) generate_block_config(emit, npu_op, arch, shared_buffer) - if npu_op.op_type == NpuOperationType.ElementWise: + if isinstance(npu_op, NpuElementWiseOperation): generate_shram_registers_elementwise(emit, npu_op, arch, shared_buffer) else: generate_shram_registers_non_elementwise(emit, shared_buffer) @@ -746,17 +746,20 @@ def print_feature_map(fm: NpuFeatureMap, name: str): print(f" {stride_str}, tiles: w0={t.width_0}, h0={t.height_0}, h1={t.height_1}, base={addresses}") -def print_operation(npu_op: NpuOperation, index: int = 0): - pass_info = f", {npu_op.cmd}" if hasattr(npu_op, "cmd") else "" - if is_dma_op(npu_op): +def print_operation(npu_op: NpuOperation, index: int = 0, cmd=None): + pass_info = f", {cmd}" if cmd else "" + if isinstance(npu_op, NpuOperation) and not isinstance(npu_op, (NpuDmaOperation, NpuBlockOperation)): + print(f"{index} {npu_op.op_type.name}{pass_info}") + return + if isinstance(npu_op, NpuDmaOperation): print(f"{index} DMA_START src={npu_op.src}, dest={npu_op.dest}{pass_info}") return k = None if npu_op.kernel is None else to_kernel(npu_op.kernel) - if npu_op.op_type in (NpuOperationType.Pooling, NpuOperationType.ElementWise): + if isinstance(npu_op, (NpuPoolingOperation, NpuElementWiseOperation)): print(f"{index} {npu_op.sub_op_type.name} {npu_op.op_type.name}:{pass_info}") else: if ( - npu_op.op_type == NpuOperationType.Conv2D + isinstance(npu_op, NpuConv2DOperation) and k.elements_wh() * k.stride.x * k.stride.y * k.dilation.x * k.dilation.y == 1 ): fc = "FullyConnected " @@ -783,16 +786,19 @@ def print_operation(npu_op: NpuOperation, index: int = 0): if act.op_type != NpuActivationOp.NONE_OR_RELU or act.min is not None or act.max is not None: lut = f", lut index={act.lookup_table_index}" if act.op_type == NpuActivationOp.TABLE_LOOKUP else "" print(f" Activation: {act.op_type.name}, min={act.min}, max={act.max}{lut}") - if npu_op.op_type == NpuOperationType.Conv2D: + if isinstance(npu_op, NpuConv2DOperation): print(f" {npu_op.block_traversal}") bh, bw, bc = npu_op.block_config - rescale = f", rescale={npu_op.rescale}" if hasattr(npu_op, "rescale") else "" + rescale = ( + f", rescale={npu_op.rescale}" if isinstance(npu_op, (NpuPoolingOperation, NpuElementWiseOperation)) else "" + ) print(f" Block config: h={bh},w={bw},c={bc}, {npu_op.ifm_upscale}, {npu_op.rounding_mode}{rescale}") -def print_operations(npu_op_list: List[NpuOperation]): +def print_operations(npu_op_list: List[NpuOperation], npu_op_to_cmd=None): + npu_op_to_cmd = dict() if npu_op_to_cmd is None else npu_op_to_cmd for index, npu_op in enumerate(npu_op_list): - print_operation(npu_op, index) + print_operation(npu_op, index, npu_op_to_cmd.get(npu_op)) # ------------------------------------------------------------------- @@ -802,16 +808,15 @@ def print_operations(npu_op_list: List[NpuOperation]): def generate_operation_code(emit: CommandStreamEmitter, npu_op: NpuOperation): """Generates NPU_OP_* command""" - op_type = npu_op.op_type - if op_type == NpuOperationType.Dma: + if isinstance(npu_op, NpuDmaOperation): emit.cmd_do_operation(cmd0.NPU_OP_DMA_START, npu_op.channel * 16 + npu_op.mode) - elif op_type == NpuOperationType.Conv2D: + elif isinstance(npu_op, NpuConv2DOperation): emit.cmd_do_operation(cmd0.NPU_OP_CONV) - elif op_type == NpuOperationType.ConvDepthWise: + elif isinstance(npu_op, NpuConvDepthWiseOperation): emit.cmd_do_operation(cmd0.NPU_OP_DEPTHWISE) - elif op_type == NpuOperationType.Pooling: + elif isinstance(npu_op, NpuPoolingOperation): emit.cmd_do_operation(cmd0.NPU_OP_POOL, param=pooling_op_map[npu_op.sub_op_type]) - elif op_type == NpuOperationType.ElementWise: + elif isinstance(npu_op, NpuElementWiseOperation): emit.cmd_do_operation(cmd0.NPU_OP_ELEMENTWISE, param=elementwise_op_map[npu_op.sub_op_type]) else: assert 0, "Unsupported operation" @@ -822,7 +827,9 @@ def generate_conv2d_op(emit: CommandStreamEmitter, npu_op: NpuConv2DOperation, a generate_common(emit, npu_op, npu_op.block_traversal, arch) -def generate_conv_depthwise_op(emit: CommandStreamEmitter, npu_op: NpuPoolingOperation, arch: ArchitectureFeatures): +def generate_conv_depthwise_op( + emit: CommandStreamEmitter, npu_op: NpuConvDepthWiseOperation, arch: ArchitectureFeatures +): """Generates register commands for depthwise convolution operations""" generate_common(emit, npu_op, NpuBlockTraversal.DEPTH_FIRST, arch) @@ -880,23 +887,22 @@ def generate_registers_for_op(emit: CommandStreamEmitter, npu_op: NpuOperation, Generates register commands for the given operation, but not the final NPU_OP_... command. Returns the selected block config """ - op_type = npu_op.op_type - if op_type == NpuOperationType.Conv2D: + if isinstance(npu_op, NpuConv2DOperation): generate_conv2d_op(emit, npu_op, arch) - elif op_type == NpuOperationType.ConvDepthWise: + elif isinstance(npu_op, NpuConvDepthWiseOperation): generate_conv_depthwise_op(emit, npu_op, arch) - elif op_type == NpuOperationType.Pooling: + elif isinstance(npu_op, NpuPoolingOperation): generate_pooling_op(emit, npu_op, arch) - elif op_type == NpuOperationType.ElementWise: + elif isinstance(npu_op, NpuElementWiseOperation): generate_elementwise_op(emit, npu_op, arch) - elif op_type == NpuOperationType.Dma: + elif isinstance(npu_op, NpuDmaOperation): generate_dma_op(emit, npu_op) else: assert 0, "Unsupported operation" def generate_command_stream( - npu_op_list: List[NpuOperation], arch: ArchitectureFeatures, verbose: bool, add_to_debug_db=None, + npu_op_list: List[NpuOperation], arch: ArchitectureFeatures, verbose: bool, add_to_debug_db=None, npu_op_to_cmd=None ) -> List[int]: """ Generates register commands for the given list of NPU operations. @@ -904,14 +910,16 @@ def generate_command_stream( """ emit = CommandStreamEmitter() if verbose: - print_operations(npu_op_list) + print_operations(npu_op_list, npu_op_to_cmd) # Calculate memory accesses for every operation - memory_accesses = {} + memory_accesses: Dict[NpuOperation, MemoryAccessSet] = {} for npu_op in npu_op_list: - if is_dma_op(npu_op): + if isinstance(npu_op, NpuDmaOperation): memory_accesses[npu_op] = get_dma_memory_accesses(npu_op) - else: + elif isinstance(npu_op, NpuBlockOperation): memory_accesses[npu_op] = get_op_memory_accesses(npu_op, arch) + else: + assert 0, "Invalid operation type" if arch.is_ethos_u65_system: emit.cmd0_with_param(cmd0.NPU_SET_PARALLEL_MODE, arch.ncores - 1) dep_watermark = Watermark(0, 0) @@ -920,7 +928,7 @@ def generate_command_stream( for op_index, npu_op in enumerate(npu_op_list): dep_watermark, cmd_waits = get_wait_dependency(arch, npu_op_list, memory_accesses, op_index, dep_watermark) generate_registers_for_op(emit, npu_op, arch) - if not is_dma_op(npu_op): + if not isinstance(npu_op, NpuDmaOperation) and isinstance(npu_op, NpuBlockOperation): # Generate BLOCKDEP blockdep = calc_blockdep(arch, prev_op, npu_op) blockdep = min(blockdep, arch.max_blockdep) @@ -951,12 +959,12 @@ def find_block_configs(npu_op: NpuOperation, npu_accelerator: NpuAccelerator) -> """ Internal implementation of the public facing API for finding block configs. """ - if is_dma_op(npu_op): - return [] - arch = create_default_arch(Accelerator.from_npu_accelerator(npu_accelerator)) - shared_buffer = create_shared_buffer(npu_op, arch) - blocks = find_suitable_block_configs(arch, shared_buffer) - return [NpuShape3D(height=block[0], width=block[1], depth=block[3]) for block in blocks] + if isinstance(npu_op, NpuBlockOperation): + arch = create_default_arch(Accelerator.from_npu_accelerator(npu_accelerator)) + shared_buffer = create_shared_buffer(npu_op, arch) + blocks = find_suitable_block_configs(arch, shared_buffer) + return [NpuShape3D(height=block[0], width=block[1], depth=block[3]) for block in blocks] + return [] def generate_register_command_stream(npu_op_list: List[NpuOperation], npu_accelerator: NpuAccelerator) -> List[int]: |