aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorDmitrii Agibov <dmitrii.agibov@arm.com>2022-09-08 14:24:39 +0100
committerDmitrii Agibov <dmitrii.agibov@arm.com>2022-09-09 17:21:48 +0100
commitf5b293d0927506c2a979a091bf0d07ecc78fa181 (patch)
tree4de585b7cb6ed34da8237063752270189a730a41 /tests
parentcde0c6ee140bd108849bff40467d8f18ffc332ef (diff)
downloadmlia-f5b293d0927506c2a979a091bf0d07ecc78fa181.tar.gz
MLIA-386 Simplify typing in the source code
- Enable deferred annotations evaluation - Use builtin types for type hints whenever possible - Use | syntax for union types - Rename mlia.core._typing into mlia.core.typing Change-Id: I3f6ffc02fa069c589bdd9e8bddbccd504285427a
Diffstat (limited to 'tests')
-rw-r--r--tests/test_backend_application.py5
-rw-r--r--tests/test_backend_common.py16
-rw-r--r--tests/test_backend_fs.py5
-rw-r--r--tests/test_backend_manager.py45
-rw-r--r--tests/test_backend_output_consumer.py9
-rw-r--r--tests/test_backend_system.py12
-rw-r--r--tests/test_cli_commands.py7
-rw-r--r--tests/test_cli_config.py7
-rw-r--r--tests/test_cli_helpers.py14
-rw-r--r--tests/test_cli_logging.py5
-rw-r--r--tests/test_cli_main.py7
-rw-r--r--tests/test_cli_options.py9
-rw-r--r--tests/test_core_advice_generation.py4
-rw-r--r--tests/test_core_reporting.py11
-rw-r--r--tests/test_devices_ethosu_advice_generation.py6
-rw-r--r--tests/test_devices_ethosu_config.py5
-rw-r--r--tests/test_devices_ethosu_data_analysis.py4
-rw-r--r--tests/test_devices_ethosu_reporters.py16
-rw-r--r--tests/test_devices_tosa_advice_generation.py4
-rw-r--r--tests/test_devices_tosa_data_analysis.py4
-rw-r--r--tests/test_devices_tosa_operators.py5
-rw-r--r--tests/test_nn_tensorflow_optimizations_clustering.py12
-rw-r--r--tests/test_nn_tensorflow_optimizations_pruning.py8
-rw-r--r--tests/test_nn_tensorflow_optimizations_select.py6
-rw-r--r--tests/test_nn_tensorflow_tflite_metrics.py5
-rw-r--r--tests/test_tools_metadata_common.py12
-rw-r--r--tests/test_tools_metadata_corstone.py20
-rw-r--r--tests/test_utils_console.py6
-rw-r--r--tests/test_utils_download.py11
-rw-r--r--tests/test_utils_logging.py9
-rw-r--r--tests/test_utils_types.py9
-rw-r--r--tests/utils/common.py4
32 files changed, 152 insertions, 150 deletions
diff --git a/tests/test_backend_application.py b/tests/test_backend_application.py
index 6860ecb..9606802 100644
--- a/tests/test_backend_application.py
+++ b/tests/test_backend_application.py
@@ -2,11 +2,12 @@
# SPDX-License-Identifier: Apache-2.0
# pylint: disable=no-self-use
"""Tests for the application backend."""
+from __future__ import annotations
+
from collections import Counter
from contextlib import ExitStack as does_not_raise
from pathlib import Path
from typing import Any
-from typing import List
from unittest.mock import MagicMock
import pytest
@@ -289,7 +290,7 @@ class TestApplication:
),
)
def test_remove_unused_params(
- self, config: ApplicationConfig, expected_params: List[Param]
+ self, config: ApplicationConfig, expected_params: list[Param]
) -> None:
"""Test mod remove_unused_parameter."""
application = Application(config)
diff --git a/tests/test_backend_common.py b/tests/test_backend_common.py
index 0533ef6..d11261e 100644
--- a/tests/test_backend_common.py
+++ b/tests/test_backend_common.py
@@ -2,16 +2,14 @@
# SPDX-License-Identifier: Apache-2.0
# pylint: disable=no-self-use,protected-access
"""Tests for the common backend module."""
+from __future__ import annotations
+
from contextlib import ExitStack as does_not_raise
from pathlib import Path
from typing import Any
from typing import cast
-from typing import Dict
from typing import IO
from typing import List
-from typing import Optional
-from typing import Tuple
-from typing import Union
from unittest.mock import MagicMock
import pytest
@@ -62,7 +60,7 @@ def test_load_config(
) -> None:
"""Test load_config."""
with expected_exception:
- configs: List[Optional[Union[Path, IO[bytes]]]] = (
+ configs: list[Path | IO[bytes] | None] = (
[None]
if not filename
else [
@@ -283,8 +281,8 @@ class TestBackend:
def test_resolved_parameters(
self,
class_: type,
- config: Dict,
- expected_output: List[Tuple[Optional[str], Param]],
+ config: dict,
+ expected_output: list[tuple[str | None, Param]],
) -> None:
"""Test command building."""
backend = class_(config)
@@ -343,7 +341,7 @@ class TestBackend:
],
)
def test__parse_raw_parameter(
- self, input_param: str, expected: Tuple[str, Optional[str]]
+ self, input_param: str, expected: tuple[str, str | None]
) -> None:
"""Test internal method of parsing a single raw parameter."""
assert parse_raw_parameter(input_param) == expected
@@ -476,7 +474,7 @@ class TestCommand:
],
],
)
- def test_validate_params(self, params: List[Param], expected_error: Any) -> None:
+ def test_validate_params(self, params: list[Param], expected_error: Any) -> None:
"""Test command validation function."""
with expected_error:
Command([], params)
diff --git a/tests/test_backend_fs.py b/tests/test_backend_fs.py
index 7423222..21226a9 100644
--- a/tests/test_backend_fs.py
+++ b/tests/test_backend_fs.py
@@ -2,10 +2,11 @@
# SPDX-License-Identifier: Apache-2.0
# pylint: disable=no-self-use
"""Module for testing fs.py."""
+from __future__ import annotations
+
from contextlib import ExitStack as does_not_raise
from pathlib import Path
from typing import Any
-from typing import Union
from unittest.mock import MagicMock
import pytest
@@ -108,7 +109,7 @@ def test_recreate_directory(tmpdir: Any) -> None:
def write_to_file(
- write_directory: Any, write_mode: str, write_text: Union[str, bytes]
+ write_directory: Any, write_mode: str, write_text: str | bytes
) -> Path:
"""Write some text to a temporary test file."""
tmpdir_path = Path(write_directory)
diff --git a/tests/test_backend_manager.py b/tests/test_backend_manager.py
index 1b5fea1..a1e9198 100644
--- a/tests/test_backend_manager.py
+++ b/tests/test_backend_manager.py
@@ -1,16 +1,13 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for module backend/manager."""
+from __future__ import annotations
+
import base64
import json
from contextlib import ExitStack as does_not_raise
from pathlib import Path
from typing import Any
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Set
-from typing import Tuple
from unittest.mock import MagicMock
from unittest.mock import PropertyMock
@@ -35,7 +32,7 @@ from mlia.backend.output_consumer import Base64OutputConsumer
from mlia.backend.system import get_system
-def _mock_encode_b64(data: Dict[str, int]) -> str:
+def _mock_encode_b64(data: dict[str, int]) -> str:
"""
Encode the given data into a mock base64-encoded string of JSON.
@@ -138,7 +135,7 @@ def _mock_encode_b64(data: Dict[str, int]) -> str:
],
)
def test_generic_inference_output_parser(
- data: Dict[str, int], is_ready: bool, result: Dict, missed_keys: Set[str]
+ data: dict[str, int], is_ready: bool, result: dict, missed_keys: set[str]
) -> None:
"""Test generic runner output parser."""
parser = GenericInferenceOutputParser()
@@ -157,8 +154,8 @@ class TestBackendRunner:
@staticmethod
def _setup_backends(
monkeypatch: pytest.MonkeyPatch,
- available_systems: Optional[List[str]] = None,
- available_apps: Optional[List[str]] = None,
+ available_systems: list[str] | None = None,
+ available_apps: list[str] | None = None,
) -> None:
"""Set up backend metadata."""
@@ -196,7 +193,7 @@ class TestBackendRunner:
)
def test_is_system_installed(
self,
- available_systems: List,
+ available_systems: list,
system: str,
installed: bool,
monkeypatch: pytest.MonkeyPatch,
@@ -217,8 +214,8 @@ class TestBackendRunner:
)
def test_installed_systems(
self,
- available_systems: List[str],
- systems: List[str],
+ available_systems: list[str],
+ systems: list[str],
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Test method installed_systems."""
@@ -250,8 +247,8 @@ class TestBackendRunner:
)
def test_systems_installed(
self,
- available_systems: List[str],
- systems: List[str],
+ available_systems: list[str],
+ systems: list[str],
expected_result: bool,
monkeypatch: pytest.MonkeyPatch,
) -> None:
@@ -274,8 +271,8 @@ class TestBackendRunner:
)
def test_applications_installed(
self,
- available_apps: List[str],
- applications: List[str],
+ available_apps: list[str],
+ applications: list[str],
expected_result: bool,
monkeypatch: pytest.MonkeyPatch,
) -> None:
@@ -297,8 +294,8 @@ class TestBackendRunner:
)
def test_get_installed_applications(
self,
- available_apps: List[str],
- applications: List[str],
+ available_apps: list[str],
+ applications: list[str],
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Test method get_installed_applications."""
@@ -337,7 +334,7 @@ class TestBackendRunner:
)
def test_is_application_installed(
self,
- available_apps: List[str],
+ available_apps: list[str],
application: str,
installed: bool,
monkeypatch: pytest.MonkeyPatch,
@@ -377,7 +374,7 @@ class TestBackendRunner:
def test_run_application_local(
monkeypatch: pytest.MonkeyPatch,
execution_params: ExecutionParams,
- expected_command: List[str],
+ expected_command: list[str],
) -> None:
"""Test method run_application with local systems."""
run_app = MagicMock(wraps=run_application)
@@ -491,8 +488,8 @@ class TestBackendRunner:
)
def test_estimate_performance(
device: DeviceInfo,
- system: Tuple[str, bool],
- application: Tuple[str, bool],
+ system: tuple[str, bool],
+ application: tuple[str, bool],
backend: str,
expected_error: Any,
test_tflite_model: Path,
@@ -588,7 +585,7 @@ def test_estimate_performance_invalid_output(
)
-def create_mock_process(stdout: List[str], stderr: List[str]) -> MagicMock:
+def create_mock_process(stdout: list[str], stderr: list[str]) -> MagicMock:
"""Mock underlying process."""
mock_process = MagicMock()
mock_process.poll.return_value = 0
@@ -597,7 +594,7 @@ def create_mock_process(stdout: List[str], stderr: List[str]) -> MagicMock:
return mock_process
-def create_mock_context(stdout: List[str]) -> ExecutionContext:
+def create_mock_context(stdout: list[str]) -> ExecutionContext:
"""Mock ExecutionContext."""
ctx = ExecutionContext(
app=get_application("application_1")[0],
diff --git a/tests/test_backend_output_consumer.py b/tests/test_backend_output_consumer.py
index 881112e..2ecb07f 100644
--- a/tests/test_backend_output_consumer.py
+++ b/tests/test_backend_output_consumer.py
@@ -1,10 +1,11 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the output parsing."""
+from __future__ import annotations
+
import base64
import json
from typing import Any
-from typing import Dict
import pytest
@@ -42,7 +43,7 @@ REGEX_CONFIG = {
"FloatValue": {"pattern": r"Float.*: (.*)", "type": "float"},
}
-EMPTY_REGEX_CONFIG: Dict[str, Dict[str, Any]] = {}
+EMPTY_REGEX_CONFIG: dict[str, dict[str, Any]] = {}
EXPECTED_METRICS_ALL = {
"FirstString": "My awesome string!",
@@ -63,7 +64,7 @@ EXPECTED_METRICS_PARTIAL = {
EXPECTED_METRICS_PARTIAL,
],
)
-def test_base64_output_consumer(expected_metrics: Dict) -> None:
+def test_base64_output_consumer(expected_metrics: dict) -> None:
"""
Make sure the Base64OutputConsumer yields valid results.
@@ -73,7 +74,7 @@ def test_base64_output_consumer(expected_metrics: Dict) -> None:
parser = Base64OutputConsumer()
assert isinstance(parser, OutputConsumer)
- def create_base64_output(expected_metrics: Dict) -> bytearray:
+ def create_base64_output(expected_metrics: dict) -> bytearray:
json_str = json.dumps(expected_metrics, indent=4)
json_b64 = base64.b64encode(json_str.encode("utf-8"))
return (
diff --git a/tests/test_backend_system.py b/tests/test_backend_system.py
index 13347c6..7a8b1de 100644
--- a/tests/test_backend_system.py
+++ b/tests/test_backend_system.py
@@ -1,14 +1,12 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for system backend."""
+from __future__ import annotations
+
from contextlib import ExitStack as does_not_raise
from pathlib import Path
from typing import Any
from typing import Callable
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Tuple
from unittest.mock import MagicMock
import pytest
@@ -27,12 +25,12 @@ from mlia.backend.system import System
def dummy_resolver(
- values: Optional[Dict[str, str]] = None
-) -> Callable[[str, str, List[Tuple[Optional[str], Param]]], str]:
+ values: dict[str, str] | None = None
+) -> Callable[[str, str, list[tuple[str | None, Param]]], str]:
"""Return dummy parameter resolver implementation."""
# pylint: disable=unused-argument
def resolver(
- param: str, cmd: str, param_values: List[Tuple[Optional[str], Param]]
+ param: str, cmd: str, param_values: list[tuple[str | None, Param]]
) -> str:
"""Implement dummy parameter resolver."""
return values.get(param, "") if values else ""
diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py
index bf17339..eaa08e6 100644
--- a/tests/test_cli_commands.py
+++ b/tests/test_cli_commands.py
@@ -1,9 +1,10 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for cli.commands module."""
+from __future__ import annotations
+
from pathlib import Path
from typing import Any
-from typing import Optional
from unittest.mock import call
from unittest.mock import MagicMock
@@ -165,7 +166,7 @@ def test_backend_command_action_status(installation_manager_mock: MagicMock) ->
def test_backend_command_action_add_downoad(
installation_manager_mock: MagicMock,
i_agree_to_the_contained_eula: bool,
- backend_name: Optional[str],
+ backend_name: str | None,
expected_calls: Any,
) -> None:
"""Test backend command "install" with download option."""
@@ -183,7 +184,7 @@ def test_backend_command_action_add_downoad(
def test_backend_command_action_install_from_path(
installation_manager_mock: MagicMock,
tmp_path: Path,
- backend_name: Optional[str],
+ backend_name: str | None,
) -> None:
"""Test backend command "install" with backend path."""
backend(backend_action="install", path=tmp_path, name=backend_name)
diff --git a/tests/test_cli_config.py b/tests/test_cli_config.py
index 6d19eec..1a7cb3f 100644
--- a/tests/test_cli_config.py
+++ b/tests/test_cli_config.py
@@ -1,7 +1,8 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for cli.config module."""
-from typing import List
+from __future__ import annotations
+
from unittest.mock import MagicMock
import pytest
@@ -30,8 +31,8 @@ from mlia.cli.config import is_corstone_backend
)
def test_get_default_backends(
monkeypatch: pytest.MonkeyPatch,
- available_backends: List[str],
- expected_default_backends: List[str],
+ available_backends: list[str],
+ expected_default_backends: list[str],
) -> None:
"""Test function get_default backends."""
monkeypatch.setattr(
diff --git a/tests/test_cli_helpers.py b/tests/test_cli_helpers.py
index 2c52885..c8aeebe 100644
--- a/tests/test_cli_helpers.py
+++ b/tests/test_cli_helpers.py
@@ -1,9 +1,9 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the helper classes."""
+from __future__ import annotations
+
from typing import Any
-from typing import Dict
-from typing import List
import pytest
@@ -67,9 +67,9 @@ class TestCliActionResolver:
],
)
def test_apply_optimizations(
- args: Dict[str, Any],
- params: Dict[str, Any],
- expected_result: List[str],
+ args: dict[str, Any],
+ params: dict[str, Any],
+ expected_result: list[str],
) -> None:
"""Test action resolving for applying optimizations."""
resolver = CLIActionResolver(args)
@@ -127,7 +127,7 @@ class TestCliActionResolver:
],
)
def test_check_performance(
- args: Dict[str, Any], expected_result: List[str]
+ args: dict[str, Any], expected_result: list[str]
) -> None:
"""Test check performance info."""
resolver = CLIActionResolver(args)
@@ -158,7 +158,7 @@ class TestCliActionResolver:
],
)
def test_check_operator_compatibility(
- args: Dict[str, Any], expected_result: List[str]
+ args: dict[str, Any], expected_result: list[str]
) -> None:
"""Test checking operator compatibility info."""
resolver = CLIActionResolver(args)
diff --git a/tests/test_cli_logging.py b/tests/test_cli_logging.py
index 5d26551..1e2cc85 100644
--- a/tests/test_cli_logging.py
+++ b/tests/test_cli_logging.py
@@ -1,9 +1,10 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the module cli.logging."""
+from __future__ import annotations
+
import logging
from pathlib import Path
-from typing import Optional
import pytest
@@ -78,7 +79,7 @@ def test_setup_logging(
def check_log_assertions(
- logs_dir_path: Optional[Path], expected_log_file_content: str
+ logs_dir_path: Path | None, expected_log_file_content: str
) -> None:
"""Test assertions for log file."""
if logs_dir_path is not None:
diff --git a/tests/test_cli_main.py b/tests/test_cli_main.py
index 28abc7b..78adc53 100644
--- a/tests/test_cli_main.py
+++ b/tests/test_cli_main.py
@@ -1,12 +1,13 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for main module."""
+from __future__ import annotations
+
import argparse
from functools import wraps
from pathlib import Path
from typing import Any
from typing import Callable
-from typing import List
from unittest.mock import ANY
from unittest.mock import call
from unittest.mock import MagicMock
@@ -252,7 +253,7 @@ def test_default_command(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> Non
],
)
def test_commands_execution(
- monkeypatch: pytest.MonkeyPatch, params: List[str], expected_call: Any
+ monkeypatch: pytest.MonkeyPatch, params: list[str], expected_call: Any
) -> None:
"""Test calling commands from the main function."""
mock = MagicMock()
@@ -320,7 +321,7 @@ def test_verbose_output(
capsys: pytest.CaptureFixture,
verbose: bool,
exc_mock: MagicMock,
- expected_output: List[str],
+ expected_output: list[str],
) -> None:
"""Test flag --verbose."""
diff --git a/tests/test_cli_options.py b/tests/test_cli_options.py
index a441e58..f898146 100644
--- a/tests/test_cli_options.py
+++ b/tests/test_cli_options.py
@@ -1,13 +1,12 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for module options."""
+from __future__ import annotations
+
import argparse
from contextlib import ExitStack as does_not_raise
from pathlib import Path
from typing import Any
-from typing import Dict
-from typing import List
-from typing import Optional
import pytest
@@ -137,7 +136,7 @@ def test_parse_optimization_parameters(
],
],
)
-def test_get_target_opts(args: Optional[Dict], expected_opts: List[str]) -> None:
+def test_get_target_opts(args: dict | None, expected_opts: list[str]) -> None:
"""Test getting target options."""
assert get_target_profile_opts(args) == expected_opts
@@ -153,7 +152,7 @@ def test_get_target_opts(args: Optional[Dict], expected_opts: List[str]) -> None
[["--output", "some_folder/report.csv"], "some_folder/report.csv"],
],
)
-def test_output_options(output_parameters: List[str], expected_path: str) -> None:
+def test_output_options(output_parameters: list[str], expected_path: str) -> None:
"""Test output options resolving."""
parser = argparse.ArgumentParser()
add_output_options(parser)
diff --git a/tests/test_core_advice_generation.py b/tests/test_core_advice_generation.py
index 05db698..f5e2960 100644
--- a/tests/test_core_advice_generation.py
+++ b/tests/test_core_advice_generation.py
@@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for module advice_generation."""
-from typing import List
+from __future__ import annotations
import pytest
@@ -46,7 +46,7 @@ def test_advice_generation() -> None:
)
def test_advice_category_decorator(
category: AdviceCategory,
- expected_advice: List[Advice],
+ expected_advice: list[Advice],
dummy_context: Context,
) -> None:
"""Test for advice_category decorator."""
diff --git a/tests/test_core_reporting.py b/tests/test_core_reporting.py
index d7a6ade..da7998c 100644
--- a/tests/test_core_reporting.py
+++ b/tests/test_core_reporting.py
@@ -1,13 +1,10 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for reporting module."""
-from typing import List
-from typing import Optional
+from __future__ import annotations
import pytest
-from mlia.core._typing import OutputFormat
-from mlia.core._typing import PathOrFileLike
from mlia.core.reporting import BytesCell
from mlia.core.reporting import Cell
from mlia.core.reporting import ClockCell
@@ -19,6 +16,8 @@ from mlia.core.reporting import ReportItem
from mlia.core.reporting import resolve_output_format
from mlia.core.reporting import SingleRow
from mlia.core.reporting import Table
+from mlia.core.typing import OutputFormat
+from mlia.core.typing import PathOrFileLike
from mlia.utils.console import remove_ascii_codes
@@ -370,7 +369,7 @@ def test_nested_report_representation(
report: NestedReport,
expected_plain_text: str,
expected_json_data: dict,
- expected_csv_data: List,
+ expected_csv_data: list,
) -> None:
"""Test representation of the NestedReport."""
plain_text = report.to_plain_text()
@@ -429,7 +428,7 @@ Single row example:
],
)
def test_resolve_output_format(
- output: Optional[PathOrFileLike], expected_output_format: OutputFormat
+ output: PathOrFileLike | None, expected_output_format: OutputFormat
) -> None:
"""Test function resolve_output_format."""
assert resolve_output_format(output) == expected_output_format
diff --git a/tests/test_devices_ethosu_advice_generation.py b/tests/test_devices_ethosu_advice_generation.py
index 5d37376..5a49089 100644
--- a/tests/test_devices_ethosu_advice_generation.py
+++ b/tests/test_devices_ethosu_advice_generation.py
@@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for Ethos-U advice generation."""
-from typing import List
+from __future__ import annotations
import pytest
@@ -363,7 +363,7 @@ from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
def test_ethosu_advice_producer(
tmpdir: str,
input_data: DataItem,
- expected_advice: List[Advice],
+ expected_advice: list[Advice],
advice_category: AdviceCategory,
action_resolver: ActionResolver,
) -> None:
@@ -468,7 +468,7 @@ def test_ethosu_static_advice_producer(
tmpdir: str,
advice_category: AdviceCategory,
action_resolver: ActionResolver,
- expected_advice: List[Advice],
+ expected_advice: list[Advice],
) -> None:
"""Test static advice generation."""
producer = EthosUStaticAdviceProducer()
diff --git a/tests/test_devices_ethosu_config.py b/tests/test_devices_ethosu_config.py
index 49c999a..d4e043f 100644
--- a/tests/test_devices_ethosu_config.py
+++ b/tests/test_devices_ethosu_config.py
@@ -1,9 +1,10 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for config module."""
+from __future__ import annotations
+
from contextlib import ExitStack as does_not_raise
from typing import Any
-from typing import Dict
from unittest.mock import MagicMock
import pytest
@@ -113,7 +114,7 @@ def test_get_target() -> None:
],
)
def test_ethosu_configuration(
- monkeypatch: pytest.MonkeyPatch, profile_data: Dict[str, Any], expected_error: Any
+ monkeypatch: pytest.MonkeyPatch, profile_data: dict[str, Any], expected_error: Any
) -> None:
"""Test creating Ethos-U configuration."""
monkeypatch.setattr(
diff --git a/tests/test_devices_ethosu_data_analysis.py b/tests/test_devices_ethosu_data_analysis.py
index 4b1d38b..26aae76 100644
--- a/tests/test_devices_ethosu_data_analysis.py
+++ b/tests/test_devices_ethosu_data_analysis.py
@@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for Ethos-U data analysis module."""
-from typing import List
+from __future__ import annotations
import pytest
@@ -139,7 +139,7 @@ def test_perf_metrics_diff() -> None:
],
)
def test_ethos_u_data_analyzer(
- input_data: DataItem, expected_facts: List[Fact]
+ input_data: DataItem, expected_facts: list[Fact]
) -> None:
"""Test Ethos-U data analyzer."""
analyzer = EthosUDataAnalyzer()
diff --git a/tests/test_devices_ethosu_reporters.py b/tests/test_devices_ethosu_reporters.py
index a63db1c..f8a7d86 100644
--- a/tests/test_devices_ethosu_reporters.py
+++ b/tests/test_devices_ethosu_reporters.py
@@ -1,14 +1,14 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for reports module."""
+from __future__ import annotations
+
import json
import sys
from contextlib import ExitStack as doesnt_raise
from pathlib import Path
from typing import Any
from typing import Callable
-from typing import Dict
-from typing import List
from typing import Literal
import pytest
@@ -91,7 +91,7 @@ from mlia.utils.console import remove_ascii_codes
)
def test_report(
data: Any,
- formatters: List[Callable],
+ formatters: list[Callable],
fmt: Literal["plain_text", "json", "csv"],
output: Any,
expected_error: Any,
@@ -202,10 +202,10 @@ Operators:
],
)
def test_report_operators(
- ops: List[Operator],
+ ops: list[Operator],
expected_plain_text: str,
- expected_json_dict: Dict,
- expected_csv_list: List,
+ expected_json_dict: dict,
+ expected_csv_list: list,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Test report_operatos formatter."""
@@ -380,8 +380,8 @@ def test_report_operators(
def test_report_device_details(
device: EthosUConfiguration,
expected_plain_text: str,
- expected_json_dict: Dict,
- expected_csv_list: List,
+ expected_json_dict: dict,
+ expected_csv_list: list,
) -> None:
"""Test report_operatos formatter."""
report = report_device_details(device)
diff --git a/tests/test_devices_tosa_advice_generation.py b/tests/test_devices_tosa_advice_generation.py
index 018ba57..1b97c8b 100644
--- a/tests/test_devices_tosa_advice_generation.py
+++ b/tests/test_devices_tosa_advice_generation.py
@@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for advice generation."""
-from typing import List
+from __future__ import annotations
import pytest
@@ -40,7 +40,7 @@ def test_tosa_advice_producer(
tmpdir: str,
input_data: DataItem,
advice_category: AdviceCategory,
- expected_advice: List[Advice],
+ expected_advice: list[Advice],
) -> None:
"""Test TOSA advice producer."""
producer = TOSAAdviceProducer()
diff --git a/tests/test_devices_tosa_data_analysis.py b/tests/test_devices_tosa_data_analysis.py
index 60bcee8..ff95978 100644
--- a/tests/test_devices_tosa_data_analysis.py
+++ b/tests/test_devices_tosa_data_analysis.py
@@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for TOSA data analysis module."""
-from typing import List
+from __future__ import annotations
import pytest
@@ -26,7 +26,7 @@ from mlia.devices.tosa.operators import TOSACompatibilityInfo
],
],
)
-def test_tosa_data_analyzer(input_data: DataItem, expected_facts: List[Fact]) -> None:
+def test_tosa_data_analyzer(input_data: DataItem, expected_facts: list[Fact]) -> None:
"""Test TOSA data analyzer."""
analyzer = TOSADataAnalyzer()
analyzer.analyze_data(input_data)
diff --git a/tests/test_devices_tosa_operators.py b/tests/test_devices_tosa_operators.py
index b7736d2..d4372aa 100644
--- a/tests/test_devices_tosa_operators.py
+++ b/tests/test_devices_tosa_operators.py
@@ -1,10 +1,11 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for TOSA compatibility."""
+from __future__ import annotations
+
from pathlib import Path
from types import SimpleNamespace
from typing import Any
-from typing import Optional
from unittest.mock import MagicMock
import pytest
@@ -15,7 +16,7 @@ from mlia.devices.tosa.operators import TOSACompatibilityInfo
def replace_get_tosa_checker_with_mock(
- monkeypatch: pytest.MonkeyPatch, mock: Optional[MagicMock]
+ monkeypatch: pytest.MonkeyPatch, mock: MagicMock | None
) -> None:
"""Replace TOSA checker with mock."""
monkeypatch.setattr(
diff --git a/tests/test_nn_tensorflow_optimizations_clustering.py b/tests/test_nn_tensorflow_optimizations_clustering.py
index c12a1e8..13dfb31 100644
--- a/tests/test_nn_tensorflow_optimizations_clustering.py
+++ b/tests/test_nn_tensorflow_optimizations_clustering.py
@@ -1,9 +1,9 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Test for module optimizations/clustering."""
+from __future__ import annotations
+
from pathlib import Path
-from typing import List
-from typing import Optional
import pytest
import tensorflow as tf
@@ -21,7 +21,7 @@ from tests.utils.common import train_model
def _prune_model(
- model: tf.keras.Model, target_sparsity: float, layers_to_prune: Optional[List[str]]
+ model: tf.keras.Model, target_sparsity: float, layers_to_prune: list[str] | None
) -> tf.keras.Model:
x_train, y_train = get_dataset()
batch_size = 1
@@ -47,7 +47,7 @@ def _prune_model(
def _test_num_unique_weights(
metrics: TFLiteMetrics,
target_num_clusters: int,
- layers_to_cluster: Optional[List[str]],
+ layers_to_cluster: list[str] | None,
) -> None:
clustered_uniqueness_dict = metrics.num_unique_weights(
ReportClusterMode.NUM_CLUSTERS_PER_AXIS
@@ -71,7 +71,7 @@ def _test_num_unique_weights(
def _test_sparsity(
metrics: TFLiteMetrics,
target_sparsity: float,
- layers_to_cluster: Optional[List[str]],
+ layers_to_cluster: list[str] | None,
) -> None:
pruned_sparsity_dict = metrics.sparsity_per_layer()
num_sparse_layers = 0
@@ -95,7 +95,7 @@ def _test_sparsity(
def test_cluster_simple_model_fully(
target_num_clusters: int,
sparsity_aware: bool,
- layers_to_cluster: Optional[List[str]],
+ layers_to_cluster: list[str] | None,
tmp_path: Path,
test_keras_model: Path,
) -> None:
diff --git a/tests/test_nn_tensorflow_optimizations_pruning.py b/tests/test_nn_tensorflow_optimizations_pruning.py
index 5d92f5e..d97b3d3 100644
--- a/tests/test_nn_tensorflow_optimizations_pruning.py
+++ b/tests/test_nn_tensorflow_optimizations_pruning.py
@@ -1,9 +1,9 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Test for module optimizations/pruning."""
+from __future__ import annotations
+
from pathlib import Path
-from typing import List
-from typing import Optional
import pytest
import tensorflow as tf
@@ -21,7 +21,7 @@ from tests.utils.common import train_model
def _test_sparsity(
metrics: TFLiteMetrics,
target_sparsity: float,
- layers_to_prune: Optional[List[str]],
+ layers_to_prune: list[str] | None,
) -> None:
pruned_sparsity_dict = metrics.sparsity_per_layer()
num_sparse_layers = 0
@@ -62,7 +62,7 @@ def _get_tflite_metrics(
def test_prune_simple_model_fully(
target_sparsity: float,
mock_data: bool,
- layers_to_prune: Optional[List[str]],
+ layers_to_prune: list[str] | None,
tmp_path: Path,
test_keras_model: Path,
) -> None:
diff --git a/tests/test_nn_tensorflow_optimizations_select.py b/tests/test_nn_tensorflow_optimizations_select.py
index 5cac8ba..e22a9d8 100644
--- a/tests/test_nn_tensorflow_optimizations_select.py
+++ b/tests/test_nn_tensorflow_optimizations_select.py
@@ -1,11 +1,11 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for module select."""
+from __future__ import annotations
+
from contextlib import ExitStack as does_not_raise
from pathlib import Path
from typing import Any
-from typing import List
-from typing import Tuple
import pytest
import tensorflow as tf
@@ -187,7 +187,7 @@ def test_get_optimizer(
],
)
def test_optimization_settings_create_from(
- params: List[Tuple[str, float]], expected_result: List[OptimizationSettings]
+ params: list[tuple[str, float]], expected_result: list[OptimizationSettings]
) -> None:
"""Test creating settings from parsed params."""
assert OptimizationSettings.create_from(params) == expected_result
diff --git a/tests/test_nn_tensorflow_tflite_metrics.py b/tests/test_nn_tensorflow_tflite_metrics.py
index 00eacef..a5e7736 100644
--- a/tests/test_nn_tensorflow_tflite_metrics.py
+++ b/tests/test_nn_tensorflow_tflite_metrics.py
@@ -1,12 +1,13 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Test for module utils/tflite_metrics."""
+from __future__ import annotations
+
import os
import tempfile
from math import isclose
from pathlib import Path
from typing import Generator
-from typing import List
import numpy as np
import pytest
@@ -31,7 +32,7 @@ def _dummy_keras_model() -> tf.keras.Model:
def _sparse_binary_keras_model() -> tf.keras.Model:
- def get_sparse_weights(shape: List[int]) -> np.ndarray:
+ def get_sparse_weights(shape: list[int]) -> np.ndarray:
weights = np.zeros(shape)
with np.nditer(weights, op_flags=["writeonly"]) as weight_iterator:
for idx, value in enumerate(weight_iterator):
diff --git a/tests/test_tools_metadata_common.py b/tests/test_tools_metadata_common.py
index 7663b83..69bc3e5 100644
--- a/tests/test_tools_metadata_common.py
+++ b/tests/test_tools_metadata_common.py
@@ -1,10 +1,10 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for commmon installation related functions."""
+from __future__ import annotations
+
from pathlib import Path
from typing import Any
-from typing import List
-from typing import Optional
from unittest.mock import call
from unittest.mock import MagicMock
from unittest.mock import PropertyMock
@@ -22,7 +22,7 @@ def get_installation_mock(
name: str,
already_installed: bool = False,
could_be_installed: bool = False,
- supported_install_type: Optional[type] = None,
+ supported_install_type: type | None = None,
) -> MagicMock:
"""Get mock instance for the installation."""
mock = MagicMock(spec=Installation)
@@ -81,7 +81,7 @@ def _could_be_installed_from_mock() -> MagicMock:
def get_installation_manager(
noninteractive: bool,
- installations: List[Any],
+ installations: list[Any],
monkeypatch: pytest.MonkeyPatch,
yes_response: bool = True,
) -> DefaultInstallationManager:
@@ -146,7 +146,7 @@ def test_installation_manager_download_and_install(
install_mock: MagicMock,
noninteractive: bool,
eula_agreement: bool,
- backend_name: Optional[str],
+ backend_name: str | None,
expected_call: Any,
monkeypatch: pytest.MonkeyPatch,
) -> None:
@@ -183,7 +183,7 @@ def test_installation_manager_download_and_install(
def test_installation_manager_install_from(
install_mock: MagicMock,
noninteractive: bool,
- backend_name: Optional[str],
+ backend_name: str | None,
expected_call: Any,
monkeypatch: pytest.MonkeyPatch,
) -> None:
diff --git a/tests/test_tools_metadata_corstone.py b/tests/test_tools_metadata_corstone.py
index 017d0c7..e2b2ae5 100644
--- a/tests/test_tools_metadata_corstone.py
+++ b/tests/test_tools_metadata_corstone.py
@@ -1,10 +1,10 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for Corstone related installation functions.."""
+from __future__ import annotations
+
import tarfile
from pathlib import Path
-from typing import List
-from typing import Optional
from unittest.mock import MagicMock
import pytest
@@ -44,12 +44,12 @@ def get_backend_installation( # pylint: disable=too-many-arguments
backend_runner_mock: MagicMock = MagicMock(),
name: str = "test_name",
description: str = "test_description",
- download_artifact: Optional[MagicMock] = None,
+ download_artifact: MagicMock | None = None,
path_checker: PathChecker = MagicMock(),
- apps_resources: Optional[List[str]] = None,
- system_config: Optional[str] = None,
+ apps_resources: list[str] | None = None,
+ system_config: str | None = None,
backend_installer: BackendInstaller = MagicMock(),
- supported_platforms: Optional[List[str]] = None,
+ supported_platforms: list[str] | None = None,
) -> BackendInstallation:
"""Get backend installation."""
return BackendInstallation(
@@ -79,7 +79,7 @@ def get_backend_installation( # pylint: disable=too-many-arguments
)
def test_could_be_installed_depends_on_platform(
platform: str,
- supported_platforms: Optional[List[str]],
+ supported_platforms: list[str] | None,
expected_result: bool,
monkeypatch: pytest.MonkeyPatch,
) -> None:
@@ -309,7 +309,7 @@ def test_backend_installation_download_and_install(
],
)
def test_corstone_path_checker_valid_path(
- tmp_path: Path, dir_content: List[str], expected_result: Optional[str]
+ tmp_path: Path, dir_content: list[str], expected_result: str | None
) -> None:
"""Test Corstone path checker valid scenario."""
path_checker = PackagePathChecker(["file1.txt", "file2.txt"], "models")
@@ -333,7 +333,7 @@ def test_corstone_path_checker_valid_path(
@pytest.mark.parametrize("system_config", [None, "system_config"])
@pytest.mark.parametrize("copy_source", [True, False])
def test_static_path_checker(
- tmp_path: Path, copy_source: bool, system_config: Optional[str]
+ tmp_path: Path, copy_source: bool, system_config: str | None
) -> None:
"""Test static path checker."""
static_checker = StaticPathChecker(
@@ -404,7 +404,7 @@ def test_corstone_300_installer(
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
eula_agreement: bool,
- expected_command: List[str],
+ expected_command: list[str],
) -> None:
"""Test Corstone-300 installer."""
command_mock = MagicMock()
diff --git a/tests/test_utils_console.py b/tests/test_utils_console.py
index 36975f8..5b01403 100644
--- a/tests/test_utils_console.py
+++ b/tests/test_utils_console.py
@@ -1,9 +1,9 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for console utility functions."""
+from __future__ import annotations
+
from typing import Iterable
-from typing import List
-from typing import Optional
import pytest
@@ -44,7 +44,7 @@ from mlia.utils.console import remove_ascii_codes
],
)
def test_produce_table(
- rows: Iterable, headers: Optional[List[str]], table_style: str, expected_result: str
+ rows: Iterable, headers: list[str] | None, table_style: str, expected_result: str
) -> None:
"""Test produce_table function."""
result = produce_table(rows, headers, table_style)
diff --git a/tests/test_utils_download.py b/tests/test_utils_download.py
index 4f8e2dc..28af74f 100644
--- a/tests/test_utils_download.py
+++ b/tests/test_utils_download.py
@@ -1,11 +1,12 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for download functionality."""
+from __future__ import annotations
+
from contextlib import ExitStack as does_not_raise
from pathlib import Path
from typing import Any
from typing import Iterable
-from typing import Optional
from unittest.mock import MagicMock
from unittest.mock import PropertyMock
@@ -17,7 +18,7 @@ from mlia.utils.download import DownloadArtifact
def response_mock(
- content_length: Optional[str], content_chunks: Iterable[bytes]
+ content_length: str | None, content_chunks: Iterable[bytes]
) -> MagicMock:
"""Mock response object."""
mock = MagicMock(spec=requests.Response)
@@ -59,9 +60,9 @@ def test_download(
show_progress: bool,
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
- content_length: Optional[str],
+ content_length: str | None,
content_chunks: Iterable[bytes],
- label: Optional[str],
+ label: str | None,
) -> None:
"""Test function download."""
monkeypatch.setattr(
@@ -97,7 +98,7 @@ def test_download(
)
def test_download_artifact_download_to(
monkeypatch: pytest.MonkeyPatch,
- content_length: Optional[str],
+ content_length: str | None,
content_chunks: Iterable[bytes],
sha256_hash: str,
expected_error: Any,
diff --git a/tests/test_utils_logging.py b/tests/test_utils_logging.py
index 75ebceb..1e212b2 100644
--- a/tests/test_utils_logging.py
+++ b/tests/test_utils_logging.py
@@ -1,12 +1,13 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the logging utility functions."""
+from __future__ import annotations
+
import logging
import sys
from contextlib import ExitStack as does_not_raise
from pathlib import Path
from typing import Any
-from typing import Optional
import pytest
@@ -43,9 +44,9 @@ from mlia.cli.logging import create_log_handler
],
)
def test_create_log_handler(
- file_path: Optional[Path],
- stream: Optional[Any],
- log_filter: Optional[logging.Filter],
+ file_path: Path | None,
+ stream: Any | None,
+ log_filter: logging.Filter | None,
delay: bool,
expected_error: Any,
expected_class: type,
diff --git a/tests/test_utils_types.py b/tests/test_utils_types.py
index 4909efe..f7e0de8 100644
--- a/tests/test_utils_types.py
+++ b/tests/test_utils_types.py
@@ -1,9 +1,10 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the types related utility functions."""
+from __future__ import annotations
+
from typing import Any
from typing import Iterable
-from typing import Optional
import pytest
@@ -42,7 +43,7 @@ def test_is_number(value: str, expected_result: bool) -> None:
],
)
def test_is_list(
- data: Any, cls: type, elem_num: Optional[int], expected_result: bool
+ data: Any, cls: type, elem_num: int | None, expected_result: bool
) -> None:
"""Test function is_list."""
assert is_list_of(data, cls, elem_num) == expected_result
@@ -70,8 +71,6 @@ def test_only_one_selected(options: Iterable[bool], expected_result: bool) -> No
[None, 11, 11],
],
)
-def test_parse_int(
- value: Any, default: Optional[int], expected_int: Optional[int]
-) -> None:
+def test_parse_int(value: Any, default: int | None, expected_int: int | None) -> None:
"""Test function parse_int."""
assert parse_int(value, default) == expected_int
diff --git a/tests/utils/common.py b/tests/utils/common.py
index 932343e..616a407 100644
--- a/tests/utils/common.py
+++ b/tests/utils/common.py
@@ -1,13 +1,13 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Common test utils module."""
-from typing import Tuple
+from __future__ import annotations
import numpy as np
import tensorflow as tf
-def get_dataset() -> Tuple[np.ndarray, np.ndarray]:
+def get_dataset() -> tuple[np.ndarray, np.ndarray]:
"""Return sample dataset."""
mnist = tf.keras.datasets.mnist
(x_train, y_train), _ = mnist.load_data()