aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/high_level_command_to_npu_op.py
diff options
context:
space:
mode:
authorDwight Lidman <dwight.lidman@arm.com>2020-12-08 17:56:44 +0100
committerpatrik.gustavsson <patrik.gustavsson@arm.com>2020-12-14 07:00:24 +0000
commit9b43f846b144d39bfb0cf16853bf6901c74b6672 (patch)
treea530dce790bb8e54dad009e11ca4d49d54b52b1d /ethosu/vela/high_level_command_to_npu_op.py
parent94457b175b8646bce089c9924e99686587de8992 (diff)
downloadethos-u-vela-9b43f846b144d39bfb0cf16853bf6901c74b6672.tar.gz
MLBEDSW-3653: Fix type errors in annotated files
This commit corrects a number of type errors reported by mypy and refactors some parts of the code which are no longer necessary after making adjustments to satisfy mypy. Signed-off-by: Dwight Lidman <dwight.lidman@arm.com> Change-Id: I16b880b228e57f2a92fb8936f53e94886e0f9f44
Diffstat (limited to 'ethosu/vela/high_level_command_to_npu_op.py')
-rw-r--r--ethosu/vela/high_level_command_to_npu_op.py17
1 files changed, 7 insertions, 10 deletions
diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py
index 7db4931d..9e0ed010 100644
--- a/ethosu/vela/high_level_command_to_npu_op.py
+++ b/ethosu/vela/high_level_command_to_npu_op.py
@@ -48,7 +48,6 @@ from .data_type import DataType
from .debug_database import DebugDatabase
from .high_level_command_stream import Box
from .high_level_command_stream import Command
-from .high_level_command_stream import CommandType
from .high_level_command_stream import DMA
from .high_level_command_stream import NpuStripe
from .operation import NpuBlockType
@@ -56,7 +55,6 @@ from .operation import Op
from .operation import Operation
from .register_command_stream_generator import generate_command_stream
from .register_command_stream_util import BASE_PTR_INDEX_MEM2MEM
-from .register_command_stream_util import is_dma_op
from .register_command_stream_util import to_npu_kernel
from .register_command_stream_util import UNARY_ELEMWISE_OPS
from .tensor import MemType
@@ -168,7 +166,7 @@ def get_region(tens: Tensor, arch: ArchitectureFeatures) -> int:
else:
base_ptr_idx_map[MemType.Scratch_fast] = BasePointerIndex.ScratchTensor
- return int(base_ptr_idx_map[tens.mem_type])
+ return base_ptr_idx_map[tens.mem_type].value
def get_upscale(op: Operation) -> NpuResamplingMode:
@@ -461,9 +459,10 @@ def create_dma_op(cmd: DMA, arch: ArchitectureFeatures) -> NpuDmaOperation:
def convert_command_to_npu_op(cmd: Command, arch: ArchitectureFeatures) -> NpuOperation:
"""Converts the high level command to NpuOperation"""
- if cmd.cmdtype == CommandType.DMA:
+ npu_op: NpuOperation
+ if isinstance(cmd, DMA):
npu_op = create_dma_op(cmd, arch)
- elif cmd.cmdtype == CommandType.NpuStripe:
+ elif isinstance(cmd, NpuStripe):
npu_block_type = cmd.ps.primary_op.type.npu_block_type
if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct):
npu_op = create_npu_conv2d_op(cmd, arch)
@@ -475,8 +474,6 @@ def convert_command_to_npu_op(cmd: Command, arch: ArchitectureFeatures) -> NpuOp
npu_op = create_npu_elementwise_op(cmd, arch)
else:
assert 0, f"Unknown command type {npu_block_type}"
- # add a link to the high level command for debugging purposes
- npu_op.cmd = cmd
return npu_op
@@ -486,7 +483,7 @@ def generate_register_command_stream_for_sg(nng, sg, arch, verbose=False):
npu_op_list = []
npu_op_to_cmd = dict() # map from npu op to high level command
for cmd in sg.high_level_command_stream:
- if cmd.cmdtype == CommandType.NpuStripe and cmd.ps.npu_block_type == NpuBlockType.Default:
+ if isinstance(cmd, NpuStripe) and cmd.ps.npu_block_type == NpuBlockType.Default:
print("Warning: Skipping register command stream generation for", cmd.ps)
else:
npu_op = convert_command_to_npu_op(cmd, arch)
@@ -498,8 +495,8 @@ def generate_register_command_stream_for_sg(nng, sg, arch, verbose=False):
def add_to_debug_db(npu_op: NpuOperation, offset: int):
"""Adds info to the debug database"""
- if not is_dma_op(npu_op):
+ if not isinstance(npu_op, NpuDmaOperation):
cmd = npu_op_to_cmd[npu_op]
DebugDatabase.add_command(stream_id, offset, cmd.ps.primary_op)
- sg.register_command_stream = generate_command_stream(npu_op_list, arch, verbose, add_to_debug_db)
+ sg.register_command_stream = generate_command_stream(npu_op_list, arch, verbose, add_to_debug_db, npu_op_to_cmd)