diff options
Diffstat (limited to 'ethosu/vela/high_level_command_to_npu_op.py')
-rw-r--r-- | ethosu/vela/high_level_command_to_npu_op.py | 17 |
1 files changed, 7 insertions, 10 deletions
diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py index 7db4931d..9e0ed010 100644 --- a/ethosu/vela/high_level_command_to_npu_op.py +++ b/ethosu/vela/high_level_command_to_npu_op.py @@ -48,7 +48,6 @@ from .data_type import DataType from .debug_database import DebugDatabase from .high_level_command_stream import Box from .high_level_command_stream import Command -from .high_level_command_stream import CommandType from .high_level_command_stream import DMA from .high_level_command_stream import NpuStripe from .operation import NpuBlockType @@ -56,7 +55,6 @@ from .operation import Op from .operation import Operation from .register_command_stream_generator import generate_command_stream from .register_command_stream_util import BASE_PTR_INDEX_MEM2MEM -from .register_command_stream_util import is_dma_op from .register_command_stream_util import to_npu_kernel from .register_command_stream_util import UNARY_ELEMWISE_OPS from .tensor import MemType @@ -168,7 +166,7 @@ def get_region(tens: Tensor, arch: ArchitectureFeatures) -> int: else: base_ptr_idx_map[MemType.Scratch_fast] = BasePointerIndex.ScratchTensor - return int(base_ptr_idx_map[tens.mem_type]) + return base_ptr_idx_map[tens.mem_type].value def get_upscale(op: Operation) -> NpuResamplingMode: @@ -461,9 +459,10 @@ def create_dma_op(cmd: DMA, arch: ArchitectureFeatures) -> NpuDmaOperation: def convert_command_to_npu_op(cmd: Command, arch: ArchitectureFeatures) -> NpuOperation: """Converts the high level command to NpuOperation""" - if cmd.cmdtype == CommandType.DMA: + npu_op: NpuOperation + if isinstance(cmd, DMA): npu_op = create_dma_op(cmd, arch) - elif cmd.cmdtype == CommandType.NpuStripe: + elif isinstance(cmd, NpuStripe): npu_block_type = cmd.ps.primary_op.type.npu_block_type if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct): npu_op = create_npu_conv2d_op(cmd, arch) @@ -475,8 +474,6 @@ def convert_command_to_npu_op(cmd: Command, arch: ArchitectureFeatures) -> NpuOp npu_op = create_npu_elementwise_op(cmd, arch) else: assert 0, f"Unknown command type {npu_block_type}" - # add a link to the high level command for debugging purposes - npu_op.cmd = cmd return npu_op @@ -486,7 +483,7 @@ def generate_register_command_stream_for_sg(nng, sg, arch, verbose=False): npu_op_list = [] npu_op_to_cmd = dict() # map from npu op to high level command for cmd in sg.high_level_command_stream: - if cmd.cmdtype == CommandType.NpuStripe and cmd.ps.npu_block_type == NpuBlockType.Default: + if isinstance(cmd, NpuStripe) and cmd.ps.npu_block_type == NpuBlockType.Default: print("Warning: Skipping register command stream generation for", cmd.ps) else: npu_op = convert_command_to_npu_op(cmd, arch) @@ -498,8 +495,8 @@ def generate_register_command_stream_for_sg(nng, sg, arch, verbose=False): def add_to_debug_db(npu_op: NpuOperation, offset: int): """Adds info to the debug database""" - if not is_dma_op(npu_op): + if not isinstance(npu_op, NpuDmaOperation): cmd = npu_op_to_cmd[npu_op] DebugDatabase.add_command(stream_id, offset, cmd.ps.primary_op) - sg.register_command_stream = generate_command_stream(npu_op_list, arch, verbose, add_to_debug_db) + sg.register_command_stream = generate_command_stream(npu_op_list, arch, verbose, add_to_debug_db, npu_op_to_cmd) |