aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorerik.andersson@arm.com <erik.andersson@arm.com>2021-01-19 11:24:43 +0100
committererik.andersson@arm.com <erik.andersson@arm.com>2021-01-22 13:53:04 +0100
commit606063fb1c37ddb211940b8ad211f31824aaee60 (patch)
treef2038c06e54ca709d6460dfd23c9bb8dab2e0515
parent7b676498c2499428f238f68b0224dc3a8fbcb56e (diff)
downloadethos-u-vela-606063fb1c37ddb211940b8ad211f31824aaee60.tar.gz
MLBEDSW-3736: Replaced placeholder type annotations
Placeholder type annotations have been replaced to their corresponding types. Signed-off-by: erik.andersson@arm.com <erik.andersson@arm.com> Change-Id: I017b87174ceefbfa40c53b2bd450d7404b9f4f30
-rw-r--r--ethosu/vela/debug_database.py40
1 files changed, 24 insertions, 16 deletions
diff --git a/ethosu/vela/debug_database.py b/ethosu/vela/debug_database.py
index 6964808..006348c 100644
--- a/ethosu/vela/debug_database.py
+++ b/ethosu/vela/debug_database.py
@@ -18,38 +18,37 @@ import io
from typing import Any
from typing import Dict
from typing import List
+from typing import Tuple
+from typing import Union
import lxml.etree as xml
from . import numeric_util
from .operation import Operation
-UntypedDict = Dict[Any, Any]
-UntypedList = List[Any]
-
class DebugDatabase:
NULLREF = -1
show_warnings = False
SOURCE_TABLE = "source"
- _sourceUID: UntypedDict = {}
+ _sourceUID: Dict[Any, int] = {}
_sourceHeaders = ["id", "operator", "kernel_w", "kernel_h", "ofm_w", "ofm_h", "ofm_d"]
- _sourceTable: UntypedList = []
+ _sourceTable: List[List[Union[float, int, str]]] = []
OPTIMISED_TABLE = "optimised"
- _optimisedUID: UntypedDict = {}
+ _optimisedUID: Dict[Any, Tuple[int, int]] = {}
_optimisedHeaders = ["id", "source_id", "operator", "kernel_w", "kernel_h", "ofm_w", "ofm_h", "ofm_d"]
- _optimisedTable: UntypedList = []
+ _optimisedTable: List[List[Union[float, int, str]]] = []
QUEUE_TABLE = "queue"
_queueHeaders = ["offset", "cmdstream_id", "optimised_id"]
- _queueTable: UntypedList = []
+ _queueTable: List[List[int]] = []
STREAM_TABLE = "cmdstream"
- _streamUID: UntypedDict = {}
+ _streamUID: Dict[Any, int] = {}
_streamHeaders = ["id", "file_offset"]
- _streamTable: UntypedList = []
+ _streamTable: List[List[int]] = []
@classmethod
def add_source(cls, op: Operation):
@@ -58,7 +57,7 @@ class DebugDatabase:
cls._sourceUID[op] = uid
ofm_shape = numeric_util.full_shape(3, op.outputs[0].shape, 1)
cls._sourceTable.append(
- [uid, op.type, op.kernel.width, op.kernel.height, ofm_shape[-2], ofm_shape[-3], ofm_shape[-1]]
+ [uid, str(op.type), op.kernel.width, op.kernel.height, ofm_shape[-2], ofm_shape[-3], ofm_shape[-1]]
)
@classmethod
@@ -80,7 +79,16 @@ class DebugDatabase:
cls._optimisedUID[op] = (uid, src_uid)
ofm_shape = numeric_util.full_shape(3, op.outputs[0].shape, 1)
cls._optimisedTable.append(
- [uid, src_uid, op.type, op.kernel.width, op.kernel.height, ofm_shape[-2], ofm_shape[-3], ofm_shape[-1]]
+ [
+ uid,
+ src_uid,
+ str(op.type),
+ op.kernel.width,
+ op.kernel.height,
+ ofm_shape[-2],
+ ofm_shape[-3],
+ ofm_shape[-1],
+ ]
)
@classmethod
@@ -91,20 +99,20 @@ class DebugDatabase:
return uid
@classmethod
- def set_stream_offset(cls, key, file_offset):
+ def set_stream_offset(cls, key, file_offset: int):
assert key in cls._streamUID
uid = cls._streamUID[key]
cls._streamTable.append([uid, file_offset])
@classmethod
- def add_command(cls, stream_id, offset, op: Operation):
+ def add_command(cls, stream_id: int, offset: int, op: Operation):
assert stream_id < len(cls._streamUID)
assert op in cls._optimisedUID, "Optimised operator must exist before code generation"
optimised_id = cls._optimisedUID[op][0]
cls._queueTable.append([offset, stream_id, optimised_id])
@classmethod
- def _write_table(cls, root, name, headers, table):
+ def _write_table(cls, root: xml.Element, name: str, headers: List[str], table):
# Convert table to CSV
out = io.StringIO()
writer = csv.writer(out, quoting=csv.QUOTE_NONNUMERIC)
@@ -116,7 +124,7 @@ class DebugDatabase:
table.text = xml.CDATA(out.getvalue())
@classmethod
- def write(cls, file_path, input_file, output_file):
+ def write(cls, file_path: str, input_file: str, output_file: str):
root = xml.Element("debug", {"source": input_file, "optimised": output_file})
cls._write_table(root, cls.SOURCE_TABLE, cls._sourceHeaders, cls._sourceTable)