diff options
author | Dmitrii Agibov <dmitrii.agibov@arm.com> | 2022-09-08 14:24:39 +0100 |
---|---|---|
committer | Dmitrii Agibov <dmitrii.agibov@arm.com> | 2022-09-09 17:21:48 +0100 |
commit | f5b293d0927506c2a979a091bf0d07ecc78fa181 (patch) | |
tree | 4de585b7cb6ed34da8237063752270189a730a41 /tests | |
parent | cde0c6ee140bd108849bff40467d8f18ffc332ef (diff) | |
download | mlia-f5b293d0927506c2a979a091bf0d07ecc78fa181.tar.gz |
MLIA-386 Simplify typing in the source code
- Enable deferred annotations evaluation
- Use builtin types for type hints whenever possible
- Use | syntax for union types
- Rename mlia.core._typing into mlia.core.typing
Change-Id: I3f6ffc02fa069c589bdd9e8bddbccd504285427a
Diffstat (limited to 'tests')
32 files changed, 152 insertions, 150 deletions
diff --git a/tests/test_backend_application.py b/tests/test_backend_application.py index 6860ecb..9606802 100644 --- a/tests/test_backend_application.py +++ b/tests/test_backend_application.py @@ -2,11 +2,12 @@ # SPDX-License-Identifier: Apache-2.0 # pylint: disable=no-self-use """Tests for the application backend.""" +from __future__ import annotations + from collections import Counter from contextlib import ExitStack as does_not_raise from pathlib import Path from typing import Any -from typing import List from unittest.mock import MagicMock import pytest @@ -289,7 +290,7 @@ class TestApplication: ), ) def test_remove_unused_params( - self, config: ApplicationConfig, expected_params: List[Param] + self, config: ApplicationConfig, expected_params: list[Param] ) -> None: """Test mod remove_unused_parameter.""" application = Application(config) diff --git a/tests/test_backend_common.py b/tests/test_backend_common.py index 0533ef6..d11261e 100644 --- a/tests/test_backend_common.py +++ b/tests/test_backend_common.py @@ -2,16 +2,14 @@ # SPDX-License-Identifier: Apache-2.0 # pylint: disable=no-self-use,protected-access """Tests for the common backend module.""" +from __future__ import annotations + from contextlib import ExitStack as does_not_raise from pathlib import Path from typing import Any from typing import cast -from typing import Dict from typing import IO from typing import List -from typing import Optional -from typing import Tuple -from typing import Union from unittest.mock import MagicMock import pytest @@ -62,7 +60,7 @@ def test_load_config( ) -> None: """Test load_config.""" with expected_exception: - configs: List[Optional[Union[Path, IO[bytes]]]] = ( + configs: list[Path | IO[bytes] | None] = ( [None] if not filename else [ @@ -283,8 +281,8 @@ class TestBackend: def test_resolved_parameters( self, class_: type, - config: Dict, - expected_output: List[Tuple[Optional[str], Param]], + config: dict, + expected_output: list[tuple[str | None, Param]], ) -> None: """Test command building.""" backend = class_(config) @@ -343,7 +341,7 @@ class TestBackend: ], ) def test__parse_raw_parameter( - self, input_param: str, expected: Tuple[str, Optional[str]] + self, input_param: str, expected: tuple[str, str | None] ) -> None: """Test internal method of parsing a single raw parameter.""" assert parse_raw_parameter(input_param) == expected @@ -476,7 +474,7 @@ class TestCommand: ], ], ) - def test_validate_params(self, params: List[Param], expected_error: Any) -> None: + def test_validate_params(self, params: list[Param], expected_error: Any) -> None: """Test command validation function.""" with expected_error: Command([], params) diff --git a/tests/test_backend_fs.py b/tests/test_backend_fs.py index 7423222..21226a9 100644 --- a/tests/test_backend_fs.py +++ b/tests/test_backend_fs.py @@ -2,10 +2,11 @@ # SPDX-License-Identifier: Apache-2.0 # pylint: disable=no-self-use """Module for testing fs.py.""" +from __future__ import annotations + from contextlib import ExitStack as does_not_raise from pathlib import Path from typing import Any -from typing import Union from unittest.mock import MagicMock import pytest @@ -108,7 +109,7 @@ def test_recreate_directory(tmpdir: Any) -> None: def write_to_file( - write_directory: Any, write_mode: str, write_text: Union[str, bytes] + write_directory: Any, write_mode: str, write_text: str | bytes ) -> Path: """Write some text to a temporary test file.""" tmpdir_path = Path(write_directory) diff --git a/tests/test_backend_manager.py b/tests/test_backend_manager.py index 1b5fea1..a1e9198 100644 --- a/tests/test_backend_manager.py +++ b/tests/test_backend_manager.py @@ -1,16 +1,13 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for module backend/manager.""" +from __future__ import annotations + import base64 import json from contextlib import ExitStack as does_not_raise from pathlib import Path from typing import Any -from typing import Dict -from typing import List -from typing import Optional -from typing import Set -from typing import Tuple from unittest.mock import MagicMock from unittest.mock import PropertyMock @@ -35,7 +32,7 @@ from mlia.backend.output_consumer import Base64OutputConsumer from mlia.backend.system import get_system -def _mock_encode_b64(data: Dict[str, int]) -> str: +def _mock_encode_b64(data: dict[str, int]) -> str: """ Encode the given data into a mock base64-encoded string of JSON. @@ -138,7 +135,7 @@ def _mock_encode_b64(data: Dict[str, int]) -> str: ], ) def test_generic_inference_output_parser( - data: Dict[str, int], is_ready: bool, result: Dict, missed_keys: Set[str] + data: dict[str, int], is_ready: bool, result: dict, missed_keys: set[str] ) -> None: """Test generic runner output parser.""" parser = GenericInferenceOutputParser() @@ -157,8 +154,8 @@ class TestBackendRunner: @staticmethod def _setup_backends( monkeypatch: pytest.MonkeyPatch, - available_systems: Optional[List[str]] = None, - available_apps: Optional[List[str]] = None, + available_systems: list[str] | None = None, + available_apps: list[str] | None = None, ) -> None: """Set up backend metadata.""" @@ -196,7 +193,7 @@ class TestBackendRunner: ) def test_is_system_installed( self, - available_systems: List, + available_systems: list, system: str, installed: bool, monkeypatch: pytest.MonkeyPatch, @@ -217,8 +214,8 @@ class TestBackendRunner: ) def test_installed_systems( self, - available_systems: List[str], - systems: List[str], + available_systems: list[str], + systems: list[str], monkeypatch: pytest.MonkeyPatch, ) -> None: """Test method installed_systems.""" @@ -250,8 +247,8 @@ class TestBackendRunner: ) def test_systems_installed( self, - available_systems: List[str], - systems: List[str], + available_systems: list[str], + systems: list[str], expected_result: bool, monkeypatch: pytest.MonkeyPatch, ) -> None: @@ -274,8 +271,8 @@ class TestBackendRunner: ) def test_applications_installed( self, - available_apps: List[str], - applications: List[str], + available_apps: list[str], + applications: list[str], expected_result: bool, monkeypatch: pytest.MonkeyPatch, ) -> None: @@ -297,8 +294,8 @@ class TestBackendRunner: ) def test_get_installed_applications( self, - available_apps: List[str], - applications: List[str], + available_apps: list[str], + applications: list[str], monkeypatch: pytest.MonkeyPatch, ) -> None: """Test method get_installed_applications.""" @@ -337,7 +334,7 @@ class TestBackendRunner: ) def test_is_application_installed( self, - available_apps: List[str], + available_apps: list[str], application: str, installed: bool, monkeypatch: pytest.MonkeyPatch, @@ -377,7 +374,7 @@ class TestBackendRunner: def test_run_application_local( monkeypatch: pytest.MonkeyPatch, execution_params: ExecutionParams, - expected_command: List[str], + expected_command: list[str], ) -> None: """Test method run_application with local systems.""" run_app = MagicMock(wraps=run_application) @@ -491,8 +488,8 @@ class TestBackendRunner: ) def test_estimate_performance( device: DeviceInfo, - system: Tuple[str, bool], - application: Tuple[str, bool], + system: tuple[str, bool], + application: tuple[str, bool], backend: str, expected_error: Any, test_tflite_model: Path, @@ -588,7 +585,7 @@ def test_estimate_performance_invalid_output( ) -def create_mock_process(stdout: List[str], stderr: List[str]) -> MagicMock: +def create_mock_process(stdout: list[str], stderr: list[str]) -> MagicMock: """Mock underlying process.""" mock_process = MagicMock() mock_process.poll.return_value = 0 @@ -597,7 +594,7 @@ def create_mock_process(stdout: List[str], stderr: List[str]) -> MagicMock: return mock_process -def create_mock_context(stdout: List[str]) -> ExecutionContext: +def create_mock_context(stdout: list[str]) -> ExecutionContext: """Mock ExecutionContext.""" ctx = ExecutionContext( app=get_application("application_1")[0], diff --git a/tests/test_backend_output_consumer.py b/tests/test_backend_output_consumer.py index 881112e..2ecb07f 100644 --- a/tests/test_backend_output_consumer.py +++ b/tests/test_backend_output_consumer.py @@ -1,10 +1,11 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the output parsing.""" +from __future__ import annotations + import base64 import json from typing import Any -from typing import Dict import pytest @@ -42,7 +43,7 @@ REGEX_CONFIG = { "FloatValue": {"pattern": r"Float.*: (.*)", "type": "float"}, } -EMPTY_REGEX_CONFIG: Dict[str, Dict[str, Any]] = {} +EMPTY_REGEX_CONFIG: dict[str, dict[str, Any]] = {} EXPECTED_METRICS_ALL = { "FirstString": "My awesome string!", @@ -63,7 +64,7 @@ EXPECTED_METRICS_PARTIAL = { EXPECTED_METRICS_PARTIAL, ], ) -def test_base64_output_consumer(expected_metrics: Dict) -> None: +def test_base64_output_consumer(expected_metrics: dict) -> None: """ Make sure the Base64OutputConsumer yields valid results. @@ -73,7 +74,7 @@ def test_base64_output_consumer(expected_metrics: Dict) -> None: parser = Base64OutputConsumer() assert isinstance(parser, OutputConsumer) - def create_base64_output(expected_metrics: Dict) -> bytearray: + def create_base64_output(expected_metrics: dict) -> bytearray: json_str = json.dumps(expected_metrics, indent=4) json_b64 = base64.b64encode(json_str.encode("utf-8")) return ( diff --git a/tests/test_backend_system.py b/tests/test_backend_system.py index 13347c6..7a8b1de 100644 --- a/tests/test_backend_system.py +++ b/tests/test_backend_system.py @@ -1,14 +1,12 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for system backend.""" +from __future__ import annotations + from contextlib import ExitStack as does_not_raise from pathlib import Path from typing import Any from typing import Callable -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple from unittest.mock import MagicMock import pytest @@ -27,12 +25,12 @@ from mlia.backend.system import System def dummy_resolver( - values: Optional[Dict[str, str]] = None -) -> Callable[[str, str, List[Tuple[Optional[str], Param]]], str]: + values: dict[str, str] | None = None +) -> Callable[[str, str, list[tuple[str | None, Param]]], str]: """Return dummy parameter resolver implementation.""" # pylint: disable=unused-argument def resolver( - param: str, cmd: str, param_values: List[Tuple[Optional[str], Param]] + param: str, cmd: str, param_values: list[tuple[str | None, Param]] ) -> str: """Implement dummy parameter resolver.""" return values.get(param, "") if values else "" diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py index bf17339..eaa08e6 100644 --- a/tests/test_cli_commands.py +++ b/tests/test_cli_commands.py @@ -1,9 +1,10 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for cli.commands module.""" +from __future__ import annotations + from pathlib import Path from typing import Any -from typing import Optional from unittest.mock import call from unittest.mock import MagicMock @@ -165,7 +166,7 @@ def test_backend_command_action_status(installation_manager_mock: MagicMock) -> def test_backend_command_action_add_downoad( installation_manager_mock: MagicMock, i_agree_to_the_contained_eula: bool, - backend_name: Optional[str], + backend_name: str | None, expected_calls: Any, ) -> None: """Test backend command "install" with download option.""" @@ -183,7 +184,7 @@ def test_backend_command_action_add_downoad( def test_backend_command_action_install_from_path( installation_manager_mock: MagicMock, tmp_path: Path, - backend_name: Optional[str], + backend_name: str | None, ) -> None: """Test backend command "install" with backend path.""" backend(backend_action="install", path=tmp_path, name=backend_name) diff --git a/tests/test_cli_config.py b/tests/test_cli_config.py index 6d19eec..1a7cb3f 100644 --- a/tests/test_cli_config.py +++ b/tests/test_cli_config.py @@ -1,7 +1,8 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for cli.config module.""" -from typing import List +from __future__ import annotations + from unittest.mock import MagicMock import pytest @@ -30,8 +31,8 @@ from mlia.cli.config import is_corstone_backend ) def test_get_default_backends( monkeypatch: pytest.MonkeyPatch, - available_backends: List[str], - expected_default_backends: List[str], + available_backends: list[str], + expected_default_backends: list[str], ) -> None: """Test function get_default backends.""" monkeypatch.setattr( diff --git a/tests/test_cli_helpers.py b/tests/test_cli_helpers.py index 2c52885..c8aeebe 100644 --- a/tests/test_cli_helpers.py +++ b/tests/test_cli_helpers.py @@ -1,9 +1,9 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the helper classes.""" +from __future__ import annotations + from typing import Any -from typing import Dict -from typing import List import pytest @@ -67,9 +67,9 @@ class TestCliActionResolver: ], ) def test_apply_optimizations( - args: Dict[str, Any], - params: Dict[str, Any], - expected_result: List[str], + args: dict[str, Any], + params: dict[str, Any], + expected_result: list[str], ) -> None: """Test action resolving for applying optimizations.""" resolver = CLIActionResolver(args) @@ -127,7 +127,7 @@ class TestCliActionResolver: ], ) def test_check_performance( - args: Dict[str, Any], expected_result: List[str] + args: dict[str, Any], expected_result: list[str] ) -> None: """Test check performance info.""" resolver = CLIActionResolver(args) @@ -158,7 +158,7 @@ class TestCliActionResolver: ], ) def test_check_operator_compatibility( - args: Dict[str, Any], expected_result: List[str] + args: dict[str, Any], expected_result: list[str] ) -> None: """Test checking operator compatibility info.""" resolver = CLIActionResolver(args) diff --git a/tests/test_cli_logging.py b/tests/test_cli_logging.py index 5d26551..1e2cc85 100644 --- a/tests/test_cli_logging.py +++ b/tests/test_cli_logging.py @@ -1,9 +1,10 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the module cli.logging.""" +from __future__ import annotations + import logging from pathlib import Path -from typing import Optional import pytest @@ -78,7 +79,7 @@ def test_setup_logging( def check_log_assertions( - logs_dir_path: Optional[Path], expected_log_file_content: str + logs_dir_path: Path | None, expected_log_file_content: str ) -> None: """Test assertions for log file.""" if logs_dir_path is not None: diff --git a/tests/test_cli_main.py b/tests/test_cli_main.py index 28abc7b..78adc53 100644 --- a/tests/test_cli_main.py +++ b/tests/test_cli_main.py @@ -1,12 +1,13 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for main module.""" +from __future__ import annotations + import argparse from functools import wraps from pathlib import Path from typing import Any from typing import Callable -from typing import List from unittest.mock import ANY from unittest.mock import call from unittest.mock import MagicMock @@ -252,7 +253,7 @@ def test_default_command(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> Non ], ) def test_commands_execution( - monkeypatch: pytest.MonkeyPatch, params: List[str], expected_call: Any + monkeypatch: pytest.MonkeyPatch, params: list[str], expected_call: Any ) -> None: """Test calling commands from the main function.""" mock = MagicMock() @@ -320,7 +321,7 @@ def test_verbose_output( capsys: pytest.CaptureFixture, verbose: bool, exc_mock: MagicMock, - expected_output: List[str], + expected_output: list[str], ) -> None: """Test flag --verbose.""" diff --git a/tests/test_cli_options.py b/tests/test_cli_options.py index a441e58..f898146 100644 --- a/tests/test_cli_options.py +++ b/tests/test_cli_options.py @@ -1,13 +1,12 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for module options.""" +from __future__ import annotations + import argparse from contextlib import ExitStack as does_not_raise from pathlib import Path from typing import Any -from typing import Dict -from typing import List -from typing import Optional import pytest @@ -137,7 +136,7 @@ def test_parse_optimization_parameters( ], ], ) -def test_get_target_opts(args: Optional[Dict], expected_opts: List[str]) -> None: +def test_get_target_opts(args: dict | None, expected_opts: list[str]) -> None: """Test getting target options.""" assert get_target_profile_opts(args) == expected_opts @@ -153,7 +152,7 @@ def test_get_target_opts(args: Optional[Dict], expected_opts: List[str]) -> None [["--output", "some_folder/report.csv"], "some_folder/report.csv"], ], ) -def test_output_options(output_parameters: List[str], expected_path: str) -> None: +def test_output_options(output_parameters: list[str], expected_path: str) -> None: """Test output options resolving.""" parser = argparse.ArgumentParser() add_output_options(parser) diff --git a/tests/test_core_advice_generation.py b/tests/test_core_advice_generation.py index 05db698..f5e2960 100644 --- a/tests/test_core_advice_generation.py +++ b/tests/test_core_advice_generation.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for module advice_generation.""" -from typing import List +from __future__ import annotations import pytest @@ -46,7 +46,7 @@ def test_advice_generation() -> None: ) def test_advice_category_decorator( category: AdviceCategory, - expected_advice: List[Advice], + expected_advice: list[Advice], dummy_context: Context, ) -> None: """Test for advice_category decorator.""" diff --git a/tests/test_core_reporting.py b/tests/test_core_reporting.py index d7a6ade..da7998c 100644 --- a/tests/test_core_reporting.py +++ b/tests/test_core_reporting.py @@ -1,13 +1,10 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for reporting module.""" -from typing import List -from typing import Optional +from __future__ import annotations import pytest -from mlia.core._typing import OutputFormat -from mlia.core._typing import PathOrFileLike from mlia.core.reporting import BytesCell from mlia.core.reporting import Cell from mlia.core.reporting import ClockCell @@ -19,6 +16,8 @@ from mlia.core.reporting import ReportItem from mlia.core.reporting import resolve_output_format from mlia.core.reporting import SingleRow from mlia.core.reporting import Table +from mlia.core.typing import OutputFormat +from mlia.core.typing import PathOrFileLike from mlia.utils.console import remove_ascii_codes @@ -370,7 +369,7 @@ def test_nested_report_representation( report: NestedReport, expected_plain_text: str, expected_json_data: dict, - expected_csv_data: List, + expected_csv_data: list, ) -> None: """Test representation of the NestedReport.""" plain_text = report.to_plain_text() @@ -429,7 +428,7 @@ Single row example: ], ) def test_resolve_output_format( - output: Optional[PathOrFileLike], expected_output_format: OutputFormat + output: PathOrFileLike | None, expected_output_format: OutputFormat ) -> None: """Test function resolve_output_format.""" assert resolve_output_format(output) == expected_output_format diff --git a/tests/test_devices_ethosu_advice_generation.py b/tests/test_devices_ethosu_advice_generation.py index 5d37376..5a49089 100644 --- a/tests/test_devices_ethosu_advice_generation.py +++ b/tests/test_devices_ethosu_advice_generation.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for Ethos-U advice generation.""" -from typing import List +from __future__ import annotations import pytest @@ -363,7 +363,7 @@ from mlia.nn.tensorflow.optimizations.select import OptimizationSettings def test_ethosu_advice_producer( tmpdir: str, input_data: DataItem, - expected_advice: List[Advice], + expected_advice: list[Advice], advice_category: AdviceCategory, action_resolver: ActionResolver, ) -> None: @@ -468,7 +468,7 @@ def test_ethosu_static_advice_producer( tmpdir: str, advice_category: AdviceCategory, action_resolver: ActionResolver, - expected_advice: List[Advice], + expected_advice: list[Advice], ) -> None: """Test static advice generation.""" producer = EthosUStaticAdviceProducer() diff --git a/tests/test_devices_ethosu_config.py b/tests/test_devices_ethosu_config.py index 49c999a..d4e043f 100644 --- a/tests/test_devices_ethosu_config.py +++ b/tests/test_devices_ethosu_config.py @@ -1,9 +1,10 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for config module.""" +from __future__ import annotations + from contextlib import ExitStack as does_not_raise from typing import Any -from typing import Dict from unittest.mock import MagicMock import pytest @@ -113,7 +114,7 @@ def test_get_target() -> None: ], ) def test_ethosu_configuration( - monkeypatch: pytest.MonkeyPatch, profile_data: Dict[str, Any], expected_error: Any + monkeypatch: pytest.MonkeyPatch, profile_data: dict[str, Any], expected_error: Any ) -> None: """Test creating Ethos-U configuration.""" monkeypatch.setattr( diff --git a/tests/test_devices_ethosu_data_analysis.py b/tests/test_devices_ethosu_data_analysis.py index 4b1d38b..26aae76 100644 --- a/tests/test_devices_ethosu_data_analysis.py +++ b/tests/test_devices_ethosu_data_analysis.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for Ethos-U data analysis module.""" -from typing import List +from __future__ import annotations import pytest @@ -139,7 +139,7 @@ def test_perf_metrics_diff() -> None: ], ) def test_ethos_u_data_analyzer( - input_data: DataItem, expected_facts: List[Fact] + input_data: DataItem, expected_facts: list[Fact] ) -> None: """Test Ethos-U data analyzer.""" analyzer = EthosUDataAnalyzer() diff --git a/tests/test_devices_ethosu_reporters.py b/tests/test_devices_ethosu_reporters.py index a63db1c..f8a7d86 100644 --- a/tests/test_devices_ethosu_reporters.py +++ b/tests/test_devices_ethosu_reporters.py @@ -1,14 +1,14 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for reports module.""" +from __future__ import annotations + import json import sys from contextlib import ExitStack as doesnt_raise from pathlib import Path from typing import Any from typing import Callable -from typing import Dict -from typing import List from typing import Literal import pytest @@ -91,7 +91,7 @@ from mlia.utils.console import remove_ascii_codes ) def test_report( data: Any, - formatters: List[Callable], + formatters: list[Callable], fmt: Literal["plain_text", "json", "csv"], output: Any, expected_error: Any, @@ -202,10 +202,10 @@ Operators: ], ) def test_report_operators( - ops: List[Operator], + ops: list[Operator], expected_plain_text: str, - expected_json_dict: Dict, - expected_csv_list: List, + expected_json_dict: dict, + expected_csv_list: list, monkeypatch: pytest.MonkeyPatch, ) -> None: """Test report_operatos formatter.""" @@ -380,8 +380,8 @@ def test_report_operators( def test_report_device_details( device: EthosUConfiguration, expected_plain_text: str, - expected_json_dict: Dict, - expected_csv_list: List, + expected_json_dict: dict, + expected_csv_list: list, ) -> None: """Test report_operatos formatter.""" report = report_device_details(device) diff --git a/tests/test_devices_tosa_advice_generation.py b/tests/test_devices_tosa_advice_generation.py index 018ba57..1b97c8b 100644 --- a/tests/test_devices_tosa_advice_generation.py +++ b/tests/test_devices_tosa_advice_generation.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for advice generation.""" -from typing import List +from __future__ import annotations import pytest @@ -40,7 +40,7 @@ def test_tosa_advice_producer( tmpdir: str, input_data: DataItem, advice_category: AdviceCategory, - expected_advice: List[Advice], + expected_advice: list[Advice], ) -> None: """Test TOSA advice producer.""" producer = TOSAAdviceProducer() diff --git a/tests/test_devices_tosa_data_analysis.py b/tests/test_devices_tosa_data_analysis.py index 60bcee8..ff95978 100644 --- a/tests/test_devices_tosa_data_analysis.py +++ b/tests/test_devices_tosa_data_analysis.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for TOSA data analysis module.""" -from typing import List +from __future__ import annotations import pytest @@ -26,7 +26,7 @@ from mlia.devices.tosa.operators import TOSACompatibilityInfo ], ], ) -def test_tosa_data_analyzer(input_data: DataItem, expected_facts: List[Fact]) -> None: +def test_tosa_data_analyzer(input_data: DataItem, expected_facts: list[Fact]) -> None: """Test TOSA data analyzer.""" analyzer = TOSADataAnalyzer() analyzer.analyze_data(input_data) diff --git a/tests/test_devices_tosa_operators.py b/tests/test_devices_tosa_operators.py index b7736d2..d4372aa 100644 --- a/tests/test_devices_tosa_operators.py +++ b/tests/test_devices_tosa_operators.py @@ -1,10 +1,11 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for TOSA compatibility.""" +from __future__ import annotations + from pathlib import Path from types import SimpleNamespace from typing import Any -from typing import Optional from unittest.mock import MagicMock import pytest @@ -15,7 +16,7 @@ from mlia.devices.tosa.operators import TOSACompatibilityInfo def replace_get_tosa_checker_with_mock( - monkeypatch: pytest.MonkeyPatch, mock: Optional[MagicMock] + monkeypatch: pytest.MonkeyPatch, mock: MagicMock | None ) -> None: """Replace TOSA checker with mock.""" monkeypatch.setattr( diff --git a/tests/test_nn_tensorflow_optimizations_clustering.py b/tests/test_nn_tensorflow_optimizations_clustering.py index c12a1e8..13dfb31 100644 --- a/tests/test_nn_tensorflow_optimizations_clustering.py +++ b/tests/test_nn_tensorflow_optimizations_clustering.py @@ -1,9 +1,9 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Test for module optimizations/clustering.""" +from __future__ import annotations + from pathlib import Path -from typing import List -from typing import Optional import pytest import tensorflow as tf @@ -21,7 +21,7 @@ from tests.utils.common import train_model def _prune_model( - model: tf.keras.Model, target_sparsity: float, layers_to_prune: Optional[List[str]] + model: tf.keras.Model, target_sparsity: float, layers_to_prune: list[str] | None ) -> tf.keras.Model: x_train, y_train = get_dataset() batch_size = 1 @@ -47,7 +47,7 @@ def _prune_model( def _test_num_unique_weights( metrics: TFLiteMetrics, target_num_clusters: int, - layers_to_cluster: Optional[List[str]], + layers_to_cluster: list[str] | None, ) -> None: clustered_uniqueness_dict = metrics.num_unique_weights( ReportClusterMode.NUM_CLUSTERS_PER_AXIS @@ -71,7 +71,7 @@ def _test_num_unique_weights( def _test_sparsity( metrics: TFLiteMetrics, target_sparsity: float, - layers_to_cluster: Optional[List[str]], + layers_to_cluster: list[str] | None, ) -> None: pruned_sparsity_dict = metrics.sparsity_per_layer() num_sparse_layers = 0 @@ -95,7 +95,7 @@ def _test_sparsity( def test_cluster_simple_model_fully( target_num_clusters: int, sparsity_aware: bool, - layers_to_cluster: Optional[List[str]], + layers_to_cluster: list[str] | None, tmp_path: Path, test_keras_model: Path, ) -> None: diff --git a/tests/test_nn_tensorflow_optimizations_pruning.py b/tests/test_nn_tensorflow_optimizations_pruning.py index 5d92f5e..d97b3d3 100644 --- a/tests/test_nn_tensorflow_optimizations_pruning.py +++ b/tests/test_nn_tensorflow_optimizations_pruning.py @@ -1,9 +1,9 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Test for module optimizations/pruning.""" +from __future__ import annotations + from pathlib import Path -from typing import List -from typing import Optional import pytest import tensorflow as tf @@ -21,7 +21,7 @@ from tests.utils.common import train_model def _test_sparsity( metrics: TFLiteMetrics, target_sparsity: float, - layers_to_prune: Optional[List[str]], + layers_to_prune: list[str] | None, ) -> None: pruned_sparsity_dict = metrics.sparsity_per_layer() num_sparse_layers = 0 @@ -62,7 +62,7 @@ def _get_tflite_metrics( def test_prune_simple_model_fully( target_sparsity: float, mock_data: bool, - layers_to_prune: Optional[List[str]], + layers_to_prune: list[str] | None, tmp_path: Path, test_keras_model: Path, ) -> None: diff --git a/tests/test_nn_tensorflow_optimizations_select.py b/tests/test_nn_tensorflow_optimizations_select.py index 5cac8ba..e22a9d8 100644 --- a/tests/test_nn_tensorflow_optimizations_select.py +++ b/tests/test_nn_tensorflow_optimizations_select.py @@ -1,11 +1,11 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for module select.""" +from __future__ import annotations + from contextlib import ExitStack as does_not_raise from pathlib import Path from typing import Any -from typing import List -from typing import Tuple import pytest import tensorflow as tf @@ -187,7 +187,7 @@ def test_get_optimizer( ], ) def test_optimization_settings_create_from( - params: List[Tuple[str, float]], expected_result: List[OptimizationSettings] + params: list[tuple[str, float]], expected_result: list[OptimizationSettings] ) -> None: """Test creating settings from parsed params.""" assert OptimizationSettings.create_from(params) == expected_result diff --git a/tests/test_nn_tensorflow_tflite_metrics.py b/tests/test_nn_tensorflow_tflite_metrics.py index 00eacef..a5e7736 100644 --- a/tests/test_nn_tensorflow_tflite_metrics.py +++ b/tests/test_nn_tensorflow_tflite_metrics.py @@ -1,12 +1,13 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Test for module utils/tflite_metrics.""" +from __future__ import annotations + import os import tempfile from math import isclose from pathlib import Path from typing import Generator -from typing import List import numpy as np import pytest @@ -31,7 +32,7 @@ def _dummy_keras_model() -> tf.keras.Model: def _sparse_binary_keras_model() -> tf.keras.Model: - def get_sparse_weights(shape: List[int]) -> np.ndarray: + def get_sparse_weights(shape: list[int]) -> np.ndarray: weights = np.zeros(shape) with np.nditer(weights, op_flags=["writeonly"]) as weight_iterator: for idx, value in enumerate(weight_iterator): diff --git a/tests/test_tools_metadata_common.py b/tests/test_tools_metadata_common.py index 7663b83..69bc3e5 100644 --- a/tests/test_tools_metadata_common.py +++ b/tests/test_tools_metadata_common.py @@ -1,10 +1,10 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for commmon installation related functions.""" +from __future__ import annotations + from pathlib import Path from typing import Any -from typing import List -from typing import Optional from unittest.mock import call from unittest.mock import MagicMock from unittest.mock import PropertyMock @@ -22,7 +22,7 @@ def get_installation_mock( name: str, already_installed: bool = False, could_be_installed: bool = False, - supported_install_type: Optional[type] = None, + supported_install_type: type | None = None, ) -> MagicMock: """Get mock instance for the installation.""" mock = MagicMock(spec=Installation) @@ -81,7 +81,7 @@ def _could_be_installed_from_mock() -> MagicMock: def get_installation_manager( noninteractive: bool, - installations: List[Any], + installations: list[Any], monkeypatch: pytest.MonkeyPatch, yes_response: bool = True, ) -> DefaultInstallationManager: @@ -146,7 +146,7 @@ def test_installation_manager_download_and_install( install_mock: MagicMock, noninteractive: bool, eula_agreement: bool, - backend_name: Optional[str], + backend_name: str | None, expected_call: Any, monkeypatch: pytest.MonkeyPatch, ) -> None: @@ -183,7 +183,7 @@ def test_installation_manager_download_and_install( def test_installation_manager_install_from( install_mock: MagicMock, noninteractive: bool, - backend_name: Optional[str], + backend_name: str | None, expected_call: Any, monkeypatch: pytest.MonkeyPatch, ) -> None: diff --git a/tests/test_tools_metadata_corstone.py b/tests/test_tools_metadata_corstone.py index 017d0c7..e2b2ae5 100644 --- a/tests/test_tools_metadata_corstone.py +++ b/tests/test_tools_metadata_corstone.py @@ -1,10 +1,10 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for Corstone related installation functions..""" +from __future__ import annotations + import tarfile from pathlib import Path -from typing import List -from typing import Optional from unittest.mock import MagicMock import pytest @@ -44,12 +44,12 @@ def get_backend_installation( # pylint: disable=too-many-arguments backend_runner_mock: MagicMock = MagicMock(), name: str = "test_name", description: str = "test_description", - download_artifact: Optional[MagicMock] = None, + download_artifact: MagicMock | None = None, path_checker: PathChecker = MagicMock(), - apps_resources: Optional[List[str]] = None, - system_config: Optional[str] = None, + apps_resources: list[str] | None = None, + system_config: str | None = None, backend_installer: BackendInstaller = MagicMock(), - supported_platforms: Optional[List[str]] = None, + supported_platforms: list[str] | None = None, ) -> BackendInstallation: """Get backend installation.""" return BackendInstallation( @@ -79,7 +79,7 @@ def get_backend_installation( # pylint: disable=too-many-arguments ) def test_could_be_installed_depends_on_platform( platform: str, - supported_platforms: Optional[List[str]], + supported_platforms: list[str] | None, expected_result: bool, monkeypatch: pytest.MonkeyPatch, ) -> None: @@ -309,7 +309,7 @@ def test_backend_installation_download_and_install( ], ) def test_corstone_path_checker_valid_path( - tmp_path: Path, dir_content: List[str], expected_result: Optional[str] + tmp_path: Path, dir_content: list[str], expected_result: str | None ) -> None: """Test Corstone path checker valid scenario.""" path_checker = PackagePathChecker(["file1.txt", "file2.txt"], "models") @@ -333,7 +333,7 @@ def test_corstone_path_checker_valid_path( @pytest.mark.parametrize("system_config", [None, "system_config"]) @pytest.mark.parametrize("copy_source", [True, False]) def test_static_path_checker( - tmp_path: Path, copy_source: bool, system_config: Optional[str] + tmp_path: Path, copy_source: bool, system_config: str | None ) -> None: """Test static path checker.""" static_checker = StaticPathChecker( @@ -404,7 +404,7 @@ def test_corstone_300_installer( tmp_path: Path, monkeypatch: pytest.MonkeyPatch, eula_agreement: bool, - expected_command: List[str], + expected_command: list[str], ) -> None: """Test Corstone-300 installer.""" command_mock = MagicMock() diff --git a/tests/test_utils_console.py b/tests/test_utils_console.py index 36975f8..5b01403 100644 --- a/tests/test_utils_console.py +++ b/tests/test_utils_console.py @@ -1,9 +1,9 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for console utility functions.""" +from __future__ import annotations + from typing import Iterable -from typing import List -from typing import Optional import pytest @@ -44,7 +44,7 @@ from mlia.utils.console import remove_ascii_codes ], ) def test_produce_table( - rows: Iterable, headers: Optional[List[str]], table_style: str, expected_result: str + rows: Iterable, headers: list[str] | None, table_style: str, expected_result: str ) -> None: """Test produce_table function.""" result = produce_table(rows, headers, table_style) diff --git a/tests/test_utils_download.py b/tests/test_utils_download.py index 4f8e2dc..28af74f 100644 --- a/tests/test_utils_download.py +++ b/tests/test_utils_download.py @@ -1,11 +1,12 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for download functionality.""" +from __future__ import annotations + from contextlib import ExitStack as does_not_raise from pathlib import Path from typing import Any from typing import Iterable -from typing import Optional from unittest.mock import MagicMock from unittest.mock import PropertyMock @@ -17,7 +18,7 @@ from mlia.utils.download import DownloadArtifact def response_mock( - content_length: Optional[str], content_chunks: Iterable[bytes] + content_length: str | None, content_chunks: Iterable[bytes] ) -> MagicMock: """Mock response object.""" mock = MagicMock(spec=requests.Response) @@ -59,9 +60,9 @@ def test_download( show_progress: bool, tmp_path: Path, monkeypatch: pytest.MonkeyPatch, - content_length: Optional[str], + content_length: str | None, content_chunks: Iterable[bytes], - label: Optional[str], + label: str | None, ) -> None: """Test function download.""" monkeypatch.setattr( @@ -97,7 +98,7 @@ def test_download( ) def test_download_artifact_download_to( monkeypatch: pytest.MonkeyPatch, - content_length: Optional[str], + content_length: str | None, content_chunks: Iterable[bytes], sha256_hash: str, expected_error: Any, diff --git a/tests/test_utils_logging.py b/tests/test_utils_logging.py index 75ebceb..1e212b2 100644 --- a/tests/test_utils_logging.py +++ b/tests/test_utils_logging.py @@ -1,12 +1,13 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the logging utility functions.""" +from __future__ import annotations + import logging import sys from contextlib import ExitStack as does_not_raise from pathlib import Path from typing import Any -from typing import Optional import pytest @@ -43,9 +44,9 @@ from mlia.cli.logging import create_log_handler ], ) def test_create_log_handler( - file_path: Optional[Path], - stream: Optional[Any], - log_filter: Optional[logging.Filter], + file_path: Path | None, + stream: Any | None, + log_filter: logging.Filter | None, delay: bool, expected_error: Any, expected_class: type, diff --git a/tests/test_utils_types.py b/tests/test_utils_types.py index 4909efe..f7e0de8 100644 --- a/tests/test_utils_types.py +++ b/tests/test_utils_types.py @@ -1,9 +1,10 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the types related utility functions.""" +from __future__ import annotations + from typing import Any from typing import Iterable -from typing import Optional import pytest @@ -42,7 +43,7 @@ def test_is_number(value: str, expected_result: bool) -> None: ], ) def test_is_list( - data: Any, cls: type, elem_num: Optional[int], expected_result: bool + data: Any, cls: type, elem_num: int | None, expected_result: bool ) -> None: """Test function is_list.""" assert is_list_of(data, cls, elem_num) == expected_result @@ -70,8 +71,6 @@ def test_only_one_selected(options: Iterable[bool], expected_result: bool) -> No [None, 11, 11], ], ) -def test_parse_int( - value: Any, default: Optional[int], expected_int: Optional[int] -) -> None: +def test_parse_int(value: Any, default: int | None, expected_int: int | None) -> None: """Test function parse_int.""" assert parse_int(value, default) == expected_int diff --git a/tests/utils/common.py b/tests/utils/common.py index 932343e..616a407 100644 --- a/tests/utils/common.py +++ b/tests/utils/common.py @@ -1,13 +1,13 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Common test utils module.""" -from typing import Tuple +from __future__ import annotations import numpy as np import tensorflow as tf -def get_dataset() -> Tuple[np.ndarray, np.ndarray]: +def get_dataset() -> tuple[np.ndarray, np.ndarray]: """Return sample dataset.""" mnist = tf.keras.datasets.mnist (x_train, y_train), _ = mnist.load_data() |