Diffstat (limited to 'tests/mlia')
-rw-r--r--  tests/mlia/__init__.py | 3
-rw-r--r--  tests/mlia/conftest.py | 20
-rw-r--r--  tests/mlia/test_api.py | 96
-rw-r--r--  tests/mlia/test_cli_commands.py | 204
-rw-r--r--  tests/mlia/test_cli_config.py | 49
-rw-r--r--  tests/mlia/test_cli_helpers.py | 165
-rw-r--r--  tests/mlia/test_cli_logging.py | 104
-rw-r--r--  tests/mlia/test_cli_main.py | 357
-rw-r--r--  tests/mlia/test_cli_options.py | 186
-rw-r--r--  tests/mlia/test_core_advice_generation.py | 71
-rw-r--r--  tests/mlia/test_core_advisor.py | 40
-rw-r--r--  tests/mlia/test_core_context.py | 62
-rw-r--r--  tests/mlia/test_core_data_analysis.py | 31
-rw-r--r--  tests/mlia/test_core_events.py | 155
-rw-r--r--  tests/mlia/test_core_helpers.py | 17
-rw-r--r--  tests/mlia/test_core_mixins.py | 99
-rw-r--r--  tests/mlia/test_core_performance.py | 29
-rw-r--r--  tests/mlia/test_core_reporting.py | 413
-rw-r--r--  tests/mlia/test_core_workflow.py | 164
-rw-r--r--  tests/mlia/test_devices_ethosu_advice_generation.py | 483
-rw-r--r--  tests/mlia/test_devices_ethosu_advisor.py | 9
-rw-r--r--  tests/mlia/test_devices_ethosu_config.py | 124
-rw-r--r--  tests/mlia/test_devices_ethosu_data_analysis.py | 147
-rw-r--r--  tests/mlia/test_devices_ethosu_data_collection.py | 151
-rw-r--r--  tests/mlia/test_devices_ethosu_performance.py | 28
-rw-r--r--  tests/mlia/test_devices_ethosu_reporters.py | 434
-rw-r--r--  tests/mlia/test_nn_tensorflow_config.py | 72
-rw-r--r--  tests/mlia/test_nn_tensorflow_optimizations_clustering.py | 131
-rw-r--r--  tests/mlia/test_nn_tensorflow_optimizations_pruning.py | 117
-rw-r--r--  tests/mlia/test_nn_tensorflow_optimizations_select.py | 240
-rw-r--r--  tests/mlia/test_nn_tensorflow_tflite_metrics.py | 137
-rw-r--r--  tests/mlia/test_nn_tensorflow_utils.py | 81
-rw-r--r--  tests/mlia/test_resources/vela/sample_vela.ini | 47
-rw-r--r--  tests/mlia/test_tools_aiet_wrapper.py | 760
-rw-r--r--  tests/mlia/test_tools_metadata_common.py | 196
-rw-r--r--  tests/mlia/test_tools_metadata_corstone.py | 419
-rw-r--r--  tests/mlia/test_tools_vela_wrapper.py | 285
-rw-r--r--  tests/mlia/test_utils_console.py | 100
-rw-r--r--  tests/mlia/test_utils_download.py | 147
-rw-r--r--  tests/mlia/test_utils_filesystem.py | 166
-rw-r--r--  tests/mlia/test_utils_logging.py | 63
-rw-r--r--  tests/mlia/test_utils_misc.py | 25
-rw-r--r--  tests/mlia/test_utils_proc.py | 149
-rw-r--r--  tests/mlia/test_utils_types.py | 77
-rw-r--r--  tests/mlia/utils/__init__.py | 3
-rw-r--r--  tests/mlia/utils/common.py | 32
-rw-r--r--  tests/mlia/utils/logging.py | 13
47 files changed, 6901 insertions, 0 deletions
diff --git a/tests/mlia/__init__.py b/tests/mlia/__init__.py
new file mode 100644
index 0000000..0687f14
--- /dev/null
+++ b/tests/mlia/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""MLIA tests module."""
diff --git a/tests/mlia/conftest.py b/tests/mlia/conftest.py
new file mode 100644
index 0000000..f683fca
--- /dev/null
+++ b/tests/mlia/conftest.py
@@ -0,0 +1,20 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Pytest conf module."""
+from pathlib import Path
+
+import pytest
+
+from mlia.core.context import ExecutionContext
+
+
+@pytest.fixture(scope="session", name="test_resources_path")
+def fixture_test_resources_path() -> Path:
+ """Return test resources path."""
+ return Path(__file__).parent / "test_resources"
+
+
+@pytest.fixture(name="dummy_context")
+def fixture_dummy_context(tmpdir: str) -> ExecutionContext:
+ """Return dummy context fixture."""
+ return ExecutionContext(working_dir=tmpdir)
diff --git a/tests/mlia/test_api.py b/tests/mlia/test_api.py
new file mode 100644
index 0000000..54d4796
--- /dev/null
+++ b/tests/mlia/test_api.py
@@ -0,0 +1,96 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the API functions."""
+from pathlib import Path
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.api import get_advice
+from mlia.core.common import AdviceCategory
+from mlia.core.context import Context
+from mlia.core.context import ExecutionContext
+
+
+def test_get_advice_no_target_provided(test_keras_model: Path) -> None:
+ """Test getting advice when no target provided."""
+ with pytest.raises(Exception, match="Target is not provided"):
+ get_advice(None, test_keras_model, "all") # type: ignore
+
+
+def test_get_advice_wrong_category(test_keras_model: Path) -> None:
+ """Test getting advice when wrong advice category provided."""
+ with pytest.raises(Exception, match="Invalid advice category unknown"):
+ get_advice("ethos-u55-256", test_keras_model, "unknown") # type: ignore
+
+
+@pytest.mark.parametrize(
+ "category, context, expected_category",
+ [
+ [
+ "all",
+ None,
+ AdviceCategory.ALL,
+ ],
+ [
+ "optimization",
+ None,
+ AdviceCategory.OPTIMIZATION,
+ ],
+ [
+ "operators",
+ None,
+ AdviceCategory.OPERATORS,
+ ],
+ [
+ "performance",
+ None,
+ AdviceCategory.PERFORMANCE,
+ ],
+ [
+ "all",
+ ExecutionContext(),
+ AdviceCategory.ALL,
+ ],
+ [
+ "all",
+ ExecutionContext(advice_category=AdviceCategory.PERFORMANCE),
+ AdviceCategory.PERFORMANCE,
+ ],
+ [
+ "all",
+ ExecutionContext(config_parameters={"param": "value"}),
+ AdviceCategory.ALL,
+ ],
+ [
+ "all",
+ ExecutionContext(event_handlers=[MagicMock()]),
+ AdviceCategory.ALL,
+ ],
+ ],
+)
+def test_get_advice(
+ monkeypatch: pytest.MonkeyPatch,
+ category: str,
+ context: ExecutionContext,
+ expected_category: AdviceCategory,
+ test_keras_model: Path,
+) -> None:
+ """Test getting advice with valid parameters."""
+ advisor_mock = MagicMock()
+ monkeypatch.setattr("mlia.api._get_advisor", MagicMock(return_value=advisor_mock))
+
+ get_advice(
+ "ethos-u55-256",
+ test_keras_model,
+ category, # type: ignore
+ context=context,
+ )
+
+ advisor_mock.run.assert_called_once()
+ context = advisor_mock.run.mock_calls[0].args[0]
+ assert isinstance(context, Context)
+ assert context.advice_category == expected_category
+
+ assert context.event_handlers is not None
+ assert context.config_parameters is not None
diff --git a/tests/mlia/test_cli_commands.py b/tests/mlia/test_cli_commands.py
new file mode 100644
index 0000000..bf17339
--- /dev/null
+++ b/tests/mlia/test_cli_commands.py
@@ -0,0 +1,204 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for cli.commands module."""
+from pathlib import Path
+from typing import Any
+from typing import Optional
+from unittest.mock import call
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.cli.commands import backend
+from mlia.cli.commands import operators
+from mlia.cli.commands import optimization
+from mlia.cli.commands import performance
+from mlia.core.context import ExecutionContext
+from mlia.devices.ethosu.config import EthosUConfiguration
+from mlia.devices.ethosu.performance import MemoryUsage
+from mlia.devices.ethosu.performance import NPUCycles
+from mlia.devices.ethosu.performance import PerformanceMetrics
+from mlia.tools.metadata.common import InstallationManager
+
+
+def test_operators_expected_parameters(dummy_context: ExecutionContext) -> None:
+ """Test operators command wrong parameters."""
+ with pytest.raises(Exception, match="Model is not provided"):
+ operators(dummy_context, "ethos-u55-256")
+
+
+def test_performance_unknown_target(
+ dummy_context: ExecutionContext, test_tflite_model: Path
+) -> None:
+ """Test that command should fail if unknown target passed."""
+ with pytest.raises(Exception, match="Unable to find target profile unknown"):
+ performance(
+ dummy_context, model=str(test_tflite_model), target_profile="unknown"
+ )
+
+
+@pytest.mark.parametrize(
+ "target_profile, optimization_type, optimization_target, expected_error",
+ [
+ [
+ "ethos-u55-256",
+ None,
+ "0.5",
+ pytest.raises(Exception, match="Optimization type is not provided"),
+ ],
+ [
+ "ethos-u65-512",
+ "unknown",
+ "16",
+ pytest.raises(Exception, match="Unsupported optimization type: unknown"),
+ ],
+ [
+ "ethos-u55-256",
+ "pruning",
+ None,
+ pytest.raises(Exception, match="Optimization target is not provided"),
+ ],
+ [
+ "ethos-u65-512",
+ "clustering",
+ None,
+ pytest.raises(Exception, match="Optimization target is not provided"),
+ ],
+ [
+ "unknown",
+ "clustering",
+ "16",
+ pytest.raises(Exception, match="Unable to find target profile unknown"),
+ ],
+ ],
+)
+def test_opt_expected_parameters(
+ dummy_context: ExecutionContext,
+ target_profile: str,
+ monkeypatch: pytest.MonkeyPatch,
+ optimization_type: str,
+ optimization_target: str,
+ expected_error: Any,
+ test_keras_model: Path,
+) -> None:
+ """Test that command should fail if no or unknown optimization type provided."""
+ mock_performance_estimation(monkeypatch)
+
+ with expected_error:
+ optimization(
+ ctx=dummy_context,
+ target_profile=target_profile,
+ model=str(test_keras_model),
+ optimization_type=optimization_type,
+ optimization_target=optimization_target,
+ )
+
+
+@pytest.mark.parametrize(
+ "target_profile, optimization_type, optimization_target",
+ [
+ ["ethos-u55-256", "pruning", "0.5"],
+ ["ethos-u65-512", "clustering", "32"],
+ ["ethos-u55-256", "pruning,clustering", "0.5,32"],
+ ],
+)
+def test_opt_valid_optimization_target(
+ target_profile: str,
+ dummy_context: ExecutionContext,
+ optimization_type: str,
+ optimization_target: str,
+ monkeypatch: pytest.MonkeyPatch,
+ test_keras_model: Path,
+) -> None:
+ """Test that command should not fail with valid optimization targets."""
+ mock_performance_estimation(monkeypatch)
+
+ optimization(
+ ctx=dummy_context,
+ target_profile=target_profile,
+ model=str(test_keras_model),
+ optimization_type=optimization_type,
+ optimization_target=optimization_target,
+ )
+
+
+def mock_performance_estimation(monkeypatch: pytest.MonkeyPatch) -> None:
+ """Mock performance estimation."""
+ metrics = PerformanceMetrics(
+ EthosUConfiguration("ethos-u55-256"),
+ NPUCycles(1, 2, 3, 4, 5, 6),
+ MemoryUsage(1, 2, 3, 4, 5),
+ )
+ monkeypatch.setattr(
+ "mlia.devices.ethosu.data_collection.EthosUPerformanceEstimator.estimate",
+ MagicMock(return_value=metrics),
+ )
+
+
+@pytest.fixture(name="installation_manager_mock")
+def fixture_mock_installation_manager(monkeypatch: pytest.MonkeyPatch) -> MagicMock:
+ """Mock installation manager."""
+ install_manager_mock = MagicMock(spec=InstallationManager)
+ monkeypatch.setattr(
+ "mlia.cli.commands.get_installation_manager",
+ MagicMock(return_value=install_manager_mock),
+ )
+ return install_manager_mock
+
+
+def test_backend_command_action_status(installation_manager_mock: MagicMock) -> None:
+ """Test backend command "status"."""
+ backend(backend_action="status")
+
+ installation_manager_mock.show_env_details.assert_called_once()
+
+
+@pytest.mark.parametrize(
+ "i_agree_to_the_contained_eula, backend_name, expected_calls",
+ [
+ [False, None, [call(None, True)]],
+ [True, None, [call(None, False)]],
+ [False, "backend_name", [call("backend_name", True)]],
+ [True, "backend_name", [call("backend_name", False)]],
+ ],
+)
+def test_backend_command_action_add_download(
+ installation_manager_mock: MagicMock,
+ i_agree_to_the_contained_eula: bool,
+ backend_name: Optional[str],
+ expected_calls: Any,
+) -> None:
+ """Test backend command "install" with download option."""
+ backend(
+ backend_action="install",
+ download=True,
+ name=backend_name,
+ i_agree_to_the_contained_eula=i_agree_to_the_contained_eula,
+ )
+
+ assert installation_manager_mock.download_and_install.mock_calls == expected_calls
+
+
+@pytest.mark.parametrize("backend_name", [None, "backend_name"])
+def test_backend_command_action_install_from_path(
+ installation_manager_mock: MagicMock,
+ tmp_path: Path,
+ backend_name: Optional[str],
+) -> None:
+ """Test backend command "install" with backend path."""
+ backend(backend_action="install", path=tmp_path, name=backend_name)
+
+ installation_manager_mock.install_from.assert_called_once_with(tmp_path, backend_name)
+
+
+def test_backend_command_action_install_only_one_action(
+ installation_manager_mock: MagicMock, # pylint: disable=unused-argument
+ tmp_path: Path,
+) -> None:
+ """Test that only one of action type allowed."""
+ with pytest.raises(
+ Exception,
+ match="Please select only one action: download or "
+ "provide path to the backend installation",
+ ):
+ backend(backend_action="install", download=True, path=tmp_path)
diff --git a/tests/mlia/test_cli_config.py b/tests/mlia/test_cli_config.py
new file mode 100644
index 0000000..6d19eec
--- /dev/null
+++ b/tests/mlia/test_cli_config.py
@@ -0,0 +1,49 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for cli.config module."""
+from typing import List
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.cli.config import get_default_backends
+from mlia.cli.config import is_corstone_backend
+
+
+@pytest.mark.parametrize(
+ "available_backends, expected_default_backends",
+ [
+ [["Vela"], ["Vela"]],
+ [["Corstone-300"], ["Corstone-300"]],
+ [["Corstone-310"], ["Corstone-310"]],
+ [["Corstone-300", "Corstone-310"], ["Corstone-310"]],
+ [["Vela", "Corstone-300", "Corstone-310"], ["Vela", "Corstone-310"]],
+ [
+ ["Vela", "Corstone-300", "Corstone-310", "New backend"],
+ ["Vela", "Corstone-310", "New backend"],
+ ],
+ [
+ ["Vela", "Corstone-300", "New backend"],
+ ["Vela", "Corstone-300", "New backend"],
+ ],
+ ],
+)
+def test_get_default_backends(
+ monkeypatch: pytest.MonkeyPatch,
+ available_backends: List[str],
+ expected_default_backends: List[str],
+) -> None:
+ """Test function get_default backends."""
+ monkeypatch.setattr(
+ "mlia.cli.config.get_available_backends",
+ MagicMock(return_value=available_backends),
+ )
+
+ assert get_default_backends() == expected_default_backends
+
+
+def test_is_corstone_backend() -> None:
+ """Test function is_corstone_backend."""
+ assert is_corstone_backend("Corstone-300") is True
+ assert is_corstone_backend("Corstone-310") is True
+ assert is_corstone_backend("New backend") is False
diff --git a/tests/mlia/test_cli_helpers.py b/tests/mlia/test_cli_helpers.py
new file mode 100644
index 0000000..2c52885
--- /dev/null
+++ b/tests/mlia/test_cli_helpers.py
@@ -0,0 +1,165 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the helper classes."""
+from typing import Any
+from typing import Dict
+from typing import List
+
+import pytest
+
+from mlia.cli.helpers import CLIActionResolver
+from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+
+
+class TestCliActionResolver:
+ """Test cli action resolver."""
+
+ @staticmethod
+ @pytest.mark.parametrize(
+ "args, params, expected_result",
+ [
+ [
+ {},
+ {"opt_settings": "some_setting"},
+ [],
+ ],
+ [
+ {},
+ {},
+ [
+ "Note: you will need a Keras model for that.",
+ "For example: mlia optimization --optimization-type "
+ "pruning,clustering --optimization-target 0.5,32 "
+ "/path/to/keras_model",
+ "For more info: mlia optimization --help",
+ ],
+ ],
+ [
+ {"model": "model.h5"},
+ {},
+ [
+ "For example: mlia optimization --optimization-type "
+ "pruning,clustering --optimization-target 0.5,32 model.h5",
+ "For more info: mlia optimization --help",
+ ],
+ ],
+ [
+ {"model": "model.h5"},
+ {"opt_settings": [OptimizationSettings("pruning", 0.5, None)]},
+ [
+ "For more info: mlia optimization --help",
+ "Optimization command: "
+ "mlia optimization --optimization-type pruning "
+ "--optimization-target 0.5 model.h5",
+ ],
+ ],
+ [
+ {"model": "model.h5", "target_profile": "target_profile"},
+ {"opt_settings": [OptimizationSettings("pruning", 0.5, None)]},
+ [
+ "For more info: mlia optimization --help",
+ "Optimization command: "
+ "mlia optimization --optimization-type pruning "
+ "--optimization-target 0.5 "
+ "--target-profile target_profile model.h5",
+ ],
+ ],
+ ],
+ )
+ def test_apply_optimizations(
+ args: Dict[str, Any],
+ params: Dict[str, Any],
+ expected_result: List[str],
+ ) -> None:
+ """Test action resolving for applying optimizations."""
+ resolver = CLIActionResolver(args)
+ assert resolver.apply_optimizations(**params) == expected_result
+
+ @staticmethod
+ def test_supported_operators_info() -> None:
+ """Test supported operators info."""
+ resolver = CLIActionResolver({})
+ assert resolver.supported_operators_info() == [
+ "For guidance on supported operators, run: mlia operators "
+ "--supported-ops-report",
+ ]
+
+ @staticmethod
+ def test_operator_compatibility_details() -> None:
+ """Test operator compatibility details info."""
+ resolver = CLIActionResolver({})
+ assert resolver.operator_compatibility_details() == [
+ "For more details, run: mlia operators --help"
+ ]
+
+ @staticmethod
+ def test_optimization_details() -> None:
+ """Test optimization details info."""
+ resolver = CLIActionResolver({})
+ assert resolver.optimization_details() == [
+ "For more info, see: mlia optimization --help"
+ ]
+
+ @staticmethod
+ @pytest.mark.parametrize(
+ "args, expected_result",
+ [
+ [
+ {},
+ [],
+ ],
+ [
+ {"model": "model.tflite"},
+ [
+ "Check the estimated performance by running the "
+ "following command: ",
+ "mlia performance model.tflite",
+ ],
+ ],
+ [
+ {"model": "model.tflite", "target_profile": "target_profile"},
+ [
+ "Check the estimated performance by running the "
+ "following command: ",
+ "mlia performance --target-profile target_profile model.tflite",
+ ],
+ ],
+ ],
+ )
+ def test_check_performance(
+ args: Dict[str, Any], expected_result: List[str]
+ ) -> None:
+ """Test check performance info."""
+ resolver = CLIActionResolver(args)
+ assert resolver.check_performance() == expected_result
+
+ @staticmethod
+ @pytest.mark.parametrize(
+ "args, expected_result",
+ [
+ [
+ {},
+ [],
+ ],
+ [
+ {"model": "model.tflite"},
+ [
+ "Try running the following command to verify that:",
+ "mlia operators model.tflite",
+ ],
+ ],
+ [
+ {"model": "model.tflite", "target_profile": "target_profile"},
+ [
+ "Try running the following command to verify that:",
+ "mlia operators --target-profile target_profile model.tflite",
+ ],
+ ],
+ ],
+ )
+ def test_check_operator_compatibility(
+ args: Dict[str, Any], expected_result: List[str]
+ ) -> None:
+ """Test checking operator compatibility info."""
+ resolver = CLIActionResolver(args)
+ assert resolver.check_operator_compatibility() == expected_result
diff --git a/tests/mlia/test_cli_logging.py b/tests/mlia/test_cli_logging.py
new file mode 100644
index 0000000..7c5f299
--- /dev/null
+++ b/tests/mlia/test_cli_logging.py
@@ -0,0 +1,104 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the module cli.logging."""
+import logging
+from pathlib import Path
+from typing import Optional
+
+import pytest
+
+from mlia.cli.logging import setup_logging
+from tests.mlia.utils.logging import clear_loggers
+
+
+def teardown_function() -> None:
+ """Perform action after test completion.
+
+ This function is launched automatically by pytest after each test
+ in this module.
+ """
+ clear_loggers()
+
+
+@pytest.mark.parametrize(
+ "logs_dir, verbose, expected_output, expected_log_file_content",
+ [
+ (
+ None,
+ None,
+ "cli info\n",
+ None,
+ ),
+ (
+ None,
+ True,
+ """mlia.tools.aiet_wrapper - aiet debug
+cli info
+mlia.cli - cli debug
+""",
+ None,
+ ),
+ (
+ "logs",
+ True,
+ """mlia.tools.aiet_wrapper - aiet debug
+cli info
+mlia.cli - cli debug
+""",
+ """mlia.tools.aiet_wrapper - DEBUG - aiet debug
+mlia.cli - DEBUG - cli debug
+""",
+ ),
+ ],
+)
+def test_setup_logging(
+ tmp_path: Path,
+ capfd: pytest.CaptureFixture,
+ logs_dir: str,
+ verbose: bool,
+ expected_output: str,
+ expected_log_file_content: str,
+) -> None:
+ """Test function setup_logging."""
+ logs_dir_path = tmp_path / logs_dir if logs_dir else None
+
+ setup_logging(logs_dir_path, verbose)
+
+ aiet_logger = logging.getLogger("mlia.tools.aiet_wrapper")
+ aiet_logger.debug("aiet debug")
+
+ cli_logger = logging.getLogger("mlia.cli")
+ cli_logger.info("cli info")
+ cli_logger.debug("cli debug")
+
+ stdout, _ = capfd.readouterr()
+ assert stdout == expected_output
+
+ check_log_assertions(logs_dir_path, expected_log_file_content)
+
+
+def check_log_assertions(
+ logs_dir_path: Optional[Path], expected_log_file_content: str
+) -> None:
+ """Test assertions for log file."""
+ if logs_dir_path is not None:
+ assert logs_dir_path.is_dir()
+
+ items = list(logs_dir_path.iterdir())
+ assert len(items) == 1
+
+ log_file_path = items[0]
+ assert log_file_path.is_file()
+
+ log_file_name = log_file_path.name
+ assert log_file_name == "mlia.log"
+
+ with open(log_file_path, encoding="utf-8") as log_file:
+ log_content = log_file.read()
+
+ expected_lines = expected_log_file_content.split("\n")
+ produced_lines = log_content.split("\n")
+
+ assert len(expected_lines) == len(produced_lines)
+ for expected, produced in zip(expected_lines, produced_lines):
+ assert expected in produced
diff --git a/tests/mlia/test_cli_main.py b/tests/mlia/test_cli_main.py
new file mode 100644
index 0000000..a0937d5
--- /dev/null
+++ b/tests/mlia/test_cli_main.py
@@ -0,0 +1,357 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for main module."""
+import argparse
+from functools import wraps
+from pathlib import Path
+from typing import Any
+from typing import Callable
+from typing import List
+from unittest.mock import ANY
+from unittest.mock import call
+from unittest.mock import MagicMock
+
+import pytest
+
+import mlia
+from mlia.cli.main import CommandInfo
+from mlia.cli.main import main
+from mlia.core.context import ExecutionContext
+from tests.mlia.utils.logging import clear_loggers
+
+
+def teardown_function() -> None:
+ """Perform action after test completion.
+
+ This function is launched automatically by pytest after each test
+ in this module.
+ """
+ clear_loggers()
+
+
+def test_option_version(capfd: pytest.CaptureFixture) -> None:
+ """Test --version."""
+ with pytest.raises(SystemExit) as ex:
+ main(["--version"])
+
+ assert ex.type == SystemExit
+ assert ex.value.code == 0
+
+ stdout, stderr = capfd.readouterr()
+ assert len(stdout.splitlines()) == 1
+ assert stderr == ""
+
+
+@pytest.mark.parametrize(
+ "is_default, expected_command_help",
+ [(True, "Test command [default]"), (False, "Test command")],
+)
+def test_command_info(is_default: bool, expected_command_help: str) -> None:
+ """Test properties of CommandInfo object."""
+
+ def test_command() -> None:
+ """Test command."""
+
+ command_info = CommandInfo(test_command, ["test"], [], is_default)
+ assert command_info.command_name == "test_command"
+ assert command_info.command_name_and_aliases == ["test_command", "test"]
+ assert command_info.command_help == expected_command_help
+
+
+def test_default_command(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
+ """Test adding default command."""
+
+ def mock_command(
+ func_mock: MagicMock, name: str, with_working_dir: bool
+ ) -> Callable[..., None]:
+ """Mock cli command."""
+
+ def sample_cmd_1(*args: Any, **kwargs: Any) -> None:
+ """Sample command."""
+ func_mock(*args, **kwargs)
+
+ def sample_cmd_2(ctx: ExecutionContext, **kwargs: Any) -> None:
+ """Another sample command."""
+ func_mock(ctx=ctx, **kwargs)
+
+ ret_func = sample_cmd_2 if with_working_dir else sample_cmd_1
+ ret_func.__name__ = name
+
+ return ret_func # type: ignore
+
+ default_command = MagicMock()
+ non_default_command = MagicMock()
+
+ def default_command_params(parser: argparse.ArgumentParser) -> None:
+ """Add parameters for default command."""
+ parser.add_argument("--sample")
+ parser.add_argument("--default_arg", default="123")
+
+ def non_default_command_params(parser: argparse.ArgumentParser) -> None:
+ """Add parameters for non default command."""
+ parser.add_argument("--param")
+
+ monkeypatch.setattr(
+ "mlia.cli.main.get_commands",
+ MagicMock(
+ return_value=[
+ CommandInfo(
+ func=mock_command(default_command, "default_command", True),
+ aliases=["command1"],
+ opt_groups=[default_command_params],
+ is_default=True,
+ ),
+ CommandInfo(
+ func=mock_command(
+ non_default_command, "non_default_command", False
+ ),
+ aliases=["command2"],
+ opt_groups=[non_default_command_params],
+ is_default=False,
+ ),
+ ]
+ ),
+ )
+
+ tmp_working_dir = str(tmp_path)
+ main(["--working-dir", tmp_working_dir, "--sample", "1"])
+ main(["command2", "--param", "test"])
+
+ default_command.assert_called_once_with(ctx=ANY, sample="1", default_arg="123")
+ non_default_command.assert_called_once_with(param="test")
+
+
+@pytest.mark.parametrize(
+ "params, expected_call",
+ [
+ [
+ ["operators", "sample_model.tflite"],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-256",
+ model="sample_model.tflite",
+ output=None,
+ supported_ops_report=False,
+ ),
+ ],
+ [
+ ["ops", "sample_model.tflite"],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-256",
+ model="sample_model.tflite",
+ output=None,
+ supported_ops_report=False,
+ ),
+ ],
+ [
+ ["operators", "sample_model.tflite", "--target-profile", "ethos-u55-128"],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-128",
+ model="sample_model.tflite",
+ output=None,
+ supported_ops_report=False,
+ ),
+ ],
+ [
+ ["operators"],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-256",
+ model=None,
+ output=None,
+ supported_ops_report=False,
+ ),
+ ],
+ [
+ ["operators", "--supported-ops-report"],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-256",
+ model=None,
+ output=None,
+ supported_ops_report=True,
+ ),
+ ],
+ [
+ [
+ "all_tests",
+ "sample_model.h5",
+ "--optimization-type",
+ "pruning",
+ "--optimization-target",
+ "0.5",
+ ],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-256",
+ model="sample_model.h5",
+ optimization_type="pruning",
+ optimization_target="0.5",
+ output=None,
+ evaluate_on=["Vela"],
+ ),
+ ],
+ [
+ ["sample_model.h5"],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-256",
+ model="sample_model.h5",
+ optimization_type="pruning,clustering",
+ optimization_target="0.5,32",
+ output=None,
+ evaluate_on=["Vela"],
+ ),
+ ],
+ [
+ ["performance", "sample_model.h5", "--output", "result.json"],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-256",
+ model="sample_model.h5",
+ output="result.json",
+ evaluate_on=["Vela"],
+ ),
+ ],
+ [
+ ["perf", "sample_model.h5", "--target-profile", "ethos-u55-128"],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-128",
+ model="sample_model.h5",
+ output=None,
+ evaluate_on=["Vela"],
+ ),
+ ],
+ [
+ ["optimization", "sample_model.h5"],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-256",
+ model="sample_model.h5",
+ optimization_type="pruning,clustering",
+ optimization_target="0.5,32",
+ output=None,
+ evaluate_on=["Vela"],
+ ),
+ ],
+ [
+ ["optimization", "sample_model.h5", "--evaluate-on", "some_backend"],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-256",
+ model="sample_model.h5",
+ optimization_type="pruning,clustering",
+ optimization_target="0.5,32",
+ output=None,
+ evaluate_on=["some_backend"],
+ ),
+ ],
+ ],
+)
+def test_commands_execution(
+ monkeypatch: pytest.MonkeyPatch, params: List[str], expected_call: Any
+) -> None:
+ """Test calling commands from the main function."""
+ mock = MagicMock()
+
+ def wrap_mock_command(command: Callable) -> Callable:
+ """Wrap the command with the mock."""
+
+ @wraps(command)
+ def mock_command(*args: Any, **kwargs: Any) -> Any:
+ """Mock the command."""
+ mock(*args, **kwargs)
+
+ return mock_command
+
+ monkeypatch.setattr(
+ "mlia.cli.options.get_default_backends", MagicMock(return_value=["Vela"])
+ )
+
+ monkeypatch.setattr(
+ "mlia.cli.options.get_available_backends",
+ MagicMock(return_value=["Vela", "some_backend"]),
+ )
+
+ for command in ["all_tests", "operators", "performance", "optimization"]:
+ monkeypatch.setattr(
+ f"mlia.cli.main.{command}",
+ wrap_mock_command(getattr(mlia.cli.main, command)),
+ )
+
+ main(params)
+
+ mock.assert_called_once_with(*expected_call.args, **expected_call.kwargs)
+
+
+@pytest.mark.parametrize(
+ "verbose, exc_mock, expected_output",
+ [
+ [
+ True,
+ MagicMock(side_effect=Exception("Error")),
+ [
+ "Execution finished with error: Error",
+ f"Please check the log files in the {Path.cwd()/'mlia_output/logs'} "
+ "for more details",
+ ],
+ ],
+ [
+ False,
+ MagicMock(side_effect=Exception("Error")),
+ [
+ "Execution finished with error: Error",
+ f"Please check the log files in the {Path.cwd()/'mlia_output/logs'} "
+ "for more details, or enable verbose mode",
+ ],
+ ],
+ [
+ False,
+ MagicMock(side_effect=KeyboardInterrupt()),
+ ["Execution has been interrupted"],
+ ],
+ ],
+)
+def test_verbose_output(
+ monkeypatch: pytest.MonkeyPatch,
+ capsys: pytest.CaptureFixture,
+ verbose: bool,
+ exc_mock: MagicMock,
+ expected_output: List[str],
+) -> None:
+ """Test flag --verbose."""
+
+ def command_params(parser: argparse.ArgumentParser) -> None:
+ """Add parameters for non default command."""
+ parser.add_argument("--verbose", action="store_true")
+
+ def command() -> None:
+ """Run test command."""
+ exc_mock()
+
+ monkeypatch.setattr(
+ "mlia.cli.main.get_commands",
+ MagicMock(
+ return_value=[
+ CommandInfo(
+ func=command,
+ aliases=["command"],
+ opt_groups=[command_params],
+ ),
+ ]
+ ),
+ )
+
+ params = ["command"]
+ if verbose:
+ params.append("--verbose")
+
+ exit_code = main(params)
+ assert exit_code == 1
+
+ stdout, _ = capsys.readouterr()
+ for expected_message in expected_output:
+ assert expected_message in stdout
diff --git a/tests/mlia/test_cli_options.py b/tests/mlia/test_cli_options.py
new file mode 100644
index 0000000..a441e58
--- /dev/null
+++ b/tests/mlia/test_cli_options.py
@@ -0,0 +1,186 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module options."""
+import argparse
+from contextlib import ExitStack as does_not_raise
+from pathlib import Path
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+
+import pytest
+
+from mlia.cli.options import add_output_options
+from mlia.cli.options import get_target_profile_opts
+from mlia.cli.options import parse_optimization_parameters
+
+
+@pytest.mark.parametrize(
+ "optimization_type, optimization_target, expected_error, expected_result",
+ [
+ (
+ "pruning",
+ "0.5",
+ does_not_raise(),
+ [
+ dict(
+ optimization_type="pruning",
+ optimization_target=0.5,
+ layers_to_optimize=None,
+ )
+ ],
+ ),
+ (
+ "clustering",
+ "32",
+ does_not_raise(),
+ [
+ dict(
+ optimization_type="clustering",
+ optimization_target=32.0,
+ layers_to_optimize=None,
+ )
+ ],
+ ),
+ (
+ "pruning,clustering",
+ "0.5,32",
+ does_not_raise(),
+ [
+ dict(
+ optimization_type="pruning",
+ optimization_target=0.5,
+ layers_to_optimize=None,
+ ),
+ dict(
+ optimization_type="clustering",
+ optimization_target=32.0,
+ layers_to_optimize=None,
+ ),
+ ],
+ ),
+ (
+ "pruning, clustering",
+ "0.5, 32",
+ does_not_raise(),
+ [
+ dict(
+ optimization_type="pruning",
+ optimization_target=0.5,
+ layers_to_optimize=None,
+ ),
+ dict(
+ optimization_type="clustering",
+ optimization_target=32.0,
+ layers_to_optimize=None,
+ ),
+ ],
+ ),
+ (
+ "pruning,clustering",
+ "0.5",
+ pytest.raises(
+ Exception, match="Wrong number of optimization targets and types"
+ ),
+ None,
+ ),
+ (
+ "",
+ "0.5",
+ pytest.raises(Exception, match="Optimization type is not provided"),
+ None,
+ ),
+ (
+ "pruning,clustering",
+ "",
+ pytest.raises(Exception, match="Optimization target is not provided"),
+ None,
+ ),
+ (
+ "pruning,",
+ "0.5,abc",
+ pytest.raises(
+ Exception, match="Non numeric value for the optimization target"
+ ),
+ None,
+ ),
+ ],
+)
+def test_parse_optimization_parameters(
+ optimization_type: str,
+ optimization_target: str,
+ expected_error: Any,
+ expected_result: Any,
+) -> None:
+ """Test function parse_optimization_parameters."""
+ with expected_error:
+ result = parse_optimization_parameters(optimization_type, optimization_target)
+ assert result == expected_result
+
+
+@pytest.mark.parametrize(
+ "args, expected_opts",
+ [
+ [
+ {},
+ [],
+ ],
+ [
+ {"target_profile": "profile"},
+ ["--target-profile", "profile"],
+ ],
+ [
+ # for the default profile an empty list should be returned
+ {"target": "ethos-u55-256"},
+ [],
+ ],
+ ],
+)
+def test_get_target_opts(args: Optional[Dict], expected_opts: List[str]) -> None:
+ """Test getting target options."""
+ assert get_target_profile_opts(args) == expected_opts
+
+
+@pytest.mark.parametrize(
+ "output_parameters, expected_path",
+ [
+ [["--output", "report.json"], "report.json"],
+ [["--output", "REPORT.JSON"], "REPORT.JSON"],
+ [["--output", "some_folder/report.json"], "some_folder/report.json"],
+ [["--output", "report.csv"], "report.csv"],
+ [["--output", "REPORT.CSV"], "REPORT.CSV"],
+ [["--output", "some_folder/report.csv"], "some_folder/report.csv"],
+ ],
+)
+def test_output_options(output_parameters: List[str], expected_path: str) -> None:
+ """Test output options resolving."""
+ parser = argparse.ArgumentParser()
+ add_output_options(parser)
+
+ args = parser.parse_args(output_parameters)
+ assert args.output == expected_path
+
+
+@pytest.mark.parametrize(
+ "output_filename",
+ [
+ "report.txt",
+ "report.TXT",
+ "report",
+ "report.pdf",
+ ],
+)
+def test_output_options_bad_parameters(
+ output_filename: str, capsys: pytest.CaptureFixture
+) -> None:
+ """Test that args parsing should fail if format is not supported."""
+ parser = argparse.ArgumentParser()
+ add_output_options(parser)
+
+ with pytest.raises(SystemExit):
+ parser.parse_args(["--output", output_filename])
+
+ err_output = capsys.readouterr().err
+ suffix = Path(output_filename).suffix[1:]
+ assert f"Unsupported format '{suffix}'" in err_output
diff --git a/tests/mlia/test_core_advice_generation.py b/tests/mlia/test_core_advice_generation.py
new file mode 100644
index 0000000..05db698
--- /dev/null
+++ b/tests/mlia/test_core_advice_generation.py
@@ -0,0 +1,71 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module advice_generation."""
+from typing import List
+
+import pytest
+
+from mlia.core.advice_generation import Advice
+from mlia.core.advice_generation import advice_category
+from mlia.core.advice_generation import FactBasedAdviceProducer
+from mlia.core.common import AdviceCategory
+from mlia.core.common import DataItem
+from mlia.core.context import Context
+
+
+def test_advice_generation() -> None:
+ """Test advice generation."""
+
+ class SampleProducer(FactBasedAdviceProducer):
+ """Sample producer."""
+
+ def produce_advice(self, data_item: DataItem) -> None:
+ """Process data."""
+ self.add_advice([f"Advice for {data_item}"])
+
+ producer = SampleProducer()
+ producer.produce_advice(123)
+ producer.produce_advice("hello")
+
+ advice = producer.get_advice()
+ assert advice == [Advice(["Advice for 123"]), Advice(["Advice for hello"])]
+
+
+@pytest.mark.parametrize(
+ "category, expected_advice",
+ [
+ [
+ AdviceCategory.OPERATORS,
+ [Advice(["Good advice!"])],
+ ],
+ [
+ AdviceCategory.PERFORMANCE,
+ [],
+ ],
+ ],
+)
+def test_advice_category_decorator(
+ category: AdviceCategory,
+ expected_advice: List[Advice],
+ dummy_context: Context,
+) -> None:
+ """Test for advice_category decorator."""
+
+ class SampleAdviceProducer(FactBasedAdviceProducer):
+ """Sample advice producer."""
+
+ @advice_category(AdviceCategory.OPERATORS)
+ def produce_advice(self, data_item: DataItem) -> None:
+ """Produce the advice."""
+ self.add_advice(["Good advice!"])
+
+ producer = SampleAdviceProducer()
+ dummy_context.update(
+ advice_category=category, event_handlers=[], config_parameters={}
+ )
+ producer.set_context(dummy_context)
+
+ producer.produce_advice("some_data")
+ advice = producer.get_advice()
+
+ assert advice == expected_advice
diff --git a/tests/mlia/test_core_advisor.py b/tests/mlia/test_core_advisor.py
new file mode 100644
index 0000000..375ff62
--- /dev/null
+++ b/tests/mlia/test_core_advisor.py
@@ -0,0 +1,40 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module advisor."""
+from unittest.mock import MagicMock
+
+from mlia.core.advisor import InferenceAdvisor
+from mlia.core.context import Context
+from mlia.core.workflow import WorkflowExecutor
+
+
+def test_inference_advisor_run() -> None:
+ """Test running sample inference advisor."""
+ executor_mock = MagicMock(spec=WorkflowExecutor)
+ context_mock = MagicMock(spec=Context)
+
+ class SampleAdvisor(InferenceAdvisor):
+ """Sample inference advisor."""
+
+ @classmethod
+ def name(cls) -> str:
+ """Return name of the advisor."""
+ return "sample_advisor"
+
+ @classmethod
+ def description(cls) -> str:
+ """Return description of the advisor."""
+ return "Sample advisor"
+
+ @classmethod
+ def info(cls) -> None:
+ """Print advisor info."""
+
+ def configure(self, context: Context) -> WorkflowExecutor:
+ """Configure advisor."""
+ return executor_mock
+
+ advisor = SampleAdvisor()
+ advisor.run(context_mock)
+
+ executor_mock.run.assert_called_once()
diff --git a/tests/mlia/test_core_context.py b/tests/mlia/test_core_context.py
new file mode 100644
index 0000000..10015aa
--- /dev/null
+++ b/tests/mlia/test_core_context.py
@@ -0,0 +1,62 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the module context."""
+from pathlib import Path
+
+from mlia.core.common import AdviceCategory
+from mlia.core.context import ExecutionContext
+from mlia.core.events import DefaultEventPublisher
+
+
+def test_execution_context(tmpdir: str) -> None:
+ """Test execution context."""
+ publisher = DefaultEventPublisher()
+ category = AdviceCategory.OPERATORS
+
+ context = ExecutionContext(
+ advice_category=category,
+ config_parameters={"param": "value"},
+ working_dir=tmpdir,
+ event_handlers=[],
+ event_publisher=publisher,
+ verbose=True,
+ logs_dir="logs_directory",
+ models_dir="models_directory",
+ )
+
+ assert context.advice_category == category
+ assert context.config_parameters == {"param": "value"}
+ assert context.event_handlers == []
+ assert context.event_publisher == publisher
+ assert context.logs_path == Path(tmpdir) / "logs_directory"
+ expected_model_path = Path(tmpdir) / "models_directory/sample.model"
+ assert context.get_model_path("sample.model") == expected_model_path
+ assert context.verbose is True
+ assert str(context) == (
+ f"ExecutionContext: "
+ f"working_dir={tmpdir}, "
+ "advice_category=OPERATORS, "
+ "config_parameters={'param': 'value'}, "
+ "verbose=True"
+ )
+
+ context_with_default_params = ExecutionContext(working_dir=tmpdir)
+ assert context_with_default_params.advice_category is None
+ assert context_with_default_params.config_parameters is None
+ assert context_with_default_params.event_handlers is None
+ assert isinstance(
+ context_with_default_params.event_publisher, DefaultEventPublisher
+ )
+ assert context_with_default_params.logs_path == Path(tmpdir) / "logs"
+
+ default_model_path = context_with_default_params.get_model_path("sample.model")
+ expected_default_model_path = Path(tmpdir) / "models/sample.model"
+ assert default_model_path == expected_default_model_path
+
+ expected_str = (
+ f"ExecutionContext: working_dir={tmpdir}, "
+ "advice_category=<not set>, "
+ "config_parameters=None, "
+ "verbose=False"
+ )
+ assert str(context_with_default_params) == expected_str
diff --git a/tests/mlia/test_core_data_analysis.py b/tests/mlia/test_core_data_analysis.py
new file mode 100644
index 0000000..a782159
--- /dev/null
+++ b/tests/mlia/test_core_data_analysis.py
@@ -0,0 +1,31 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module data_analysis."""
+from dataclasses import dataclass
+
+from mlia.core.common import DataItem
+from mlia.core.data_analysis import Fact
+from mlia.core.data_analysis import FactExtractor
+
+
+def test_fact_extractor() -> None:
+ """Test fact extractor."""
+
+ @dataclass
+ class SampleFact(Fact):
+ """Sample fact."""
+
+ msg: str
+
+ class SampleExtractor(FactExtractor):
+ """Sample extractor."""
+
+ def analyze_data(self, data_item: DataItem) -> None:
+ self.add_fact(SampleFact(f"Fact for {data_item}"))
+
+ extractor = SampleExtractor()
+ extractor.analyze_data(42)
+ extractor.analyze_data("some data")
+
+ facts = extractor.get_analyzed_data()
+ assert facts == [SampleFact("Fact for 42"), SampleFact("Fact for some data")]
diff --git a/tests/mlia/test_core_events.py b/tests/mlia/test_core_events.py
new file mode 100644
index 0000000..faaab7c
--- /dev/null
+++ b/tests/mlia/test_core_events.py
@@ -0,0 +1,155 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the module events."""
+from dataclasses import dataclass
+from unittest.mock import call
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.core.events import action
+from mlia.core.events import ActionFinishedEvent
+from mlia.core.events import ActionStartedEvent
+from mlia.core.events import DebugEventHandler
+from mlia.core.events import DefaultEventPublisher
+from mlia.core.events import Event
+from mlia.core.events import EventDispatcher
+from mlia.core.events import EventHandler
+from mlia.core.events import ExecutionFinishedEvent
+from mlia.core.events import ExecutionStartedEvent
+from mlia.core.events import stage
+from mlia.core.events import SystemEventsHandler
+
+
+@dataclass
+class SampleEvent(Event):
+ """Sample event."""
+
+ msg: str
+
+
+def test_event_publisher() -> None:
+ """Test event publishing."""
+ publisher = DefaultEventPublisher()
+ handler_mock1 = MagicMock(spec=EventHandler)
+ handler_mock2 = MagicMock(spec=EventHandler)
+
+ publisher.register_event_handlers([handler_mock1, handler_mock2])
+
+ event = SampleEvent("hello, event!")
+ publisher.publish_event(event)
+
+ handler_mock1.handle_event.assert_called_once_with(event)
+ handler_mock2.handle_event.assert_called_once_with(event)
+
+
+def test_stage_context_manager() -> None:
+ """Test stage context manager."""
+ publisher = DefaultEventPublisher()
+
+ handler_mock = MagicMock(spec=EventHandler)
+ publisher.register_event_handler(handler_mock)
+
+ events = (SampleEvent("hello"), SampleEvent("goodbye"))
+ with stage(publisher, events):
+ print("perform actions")
+
+ assert handler_mock.handle_event.call_count == 2
+ calls = [call(event) for event in events]
+ handler_mock.handle_event.assert_has_calls(calls)
+
+
+def test_action_context_manager() -> None:
+ """Test action stage context manager."""
+ publisher = DefaultEventPublisher()
+
+ handler_mock = MagicMock(spec=EventHandler)
+ publisher.register_event_handler(handler_mock)
+
+ with action(publisher, "Sample action"):
+ print("perform actions")
+
+ assert handler_mock.handle_event.call_count == 2
+ calls = handler_mock.handle_event.mock_calls
+
+ action_started = calls[0].args[0]
+ action_finished = calls[1].args[0]
+
+ assert isinstance(action_started, ActionStartedEvent)
+ assert isinstance(action_finished, ActionFinishedEvent)
+
+ assert action_finished.parent_event_id == action_started.event_id
+
+
+def test_debug_event_handler(capsys: pytest.CaptureFixture) -> None:
+ """Test debugging event handler."""
+ publisher = DefaultEventPublisher()
+
+ publisher.register_event_handler(DebugEventHandler())
+ publisher.register_event_handler(DebugEventHandler(with_stacktrace=True))
+
+ messages = ["Sample event 1", "Sample event 2"]
+ for message in messages:
+ publisher.publish_event(SampleEvent(message))
+
+ captured = capsys.readouterr()
+ for message in messages:
+ assert message in captured.out
+
+ assert "traceback.print_stack" in captured.err
+
+
+def test_event_dispatcher(capsys: pytest.CaptureFixture) -> None:
+ """Test event dispatcher."""
+
+ class SampleEventHandler(EventDispatcher):
+ """Sample event handler."""
+
+ def on_sample_event( # pylint: disable=no-self-use
+ self, _event: SampleEvent
+ ) -> None:
+ """Event handler for SampleEvent."""
+ print("Got sample event")
+
+ publisher = DefaultEventPublisher()
+ publisher.register_event_handler(SampleEventHandler())
+ publisher.publish_event(SampleEvent("Sample event"))
+
+ captured = capsys.readouterr()
+ assert captured.out.strip() == "Got sample event"
+
+
+def test_system_events_handler(capsys: pytest.CaptureFixture) -> None:
+ """Test system events handler."""
+
+ class CustomSystemEventHandler(SystemEventsHandler):
+ """Custom system event handler."""
+
+ def on_execution_started(self, event: ExecutionStartedEvent) -> None:
+ """Handle ExecutionStarted event."""
+ print("Execution started")
+
+ def on_execution_finished(self, event: ExecutionFinishedEvent) -> None:
+ """Handle ExecutionFinished event."""
+ print("Execution finished")
+
+ publisher = DefaultEventPublisher()
+ publisher.register_event_handler(CustomSystemEventHandler())
+
+ publisher.publish_event(ExecutionStartedEvent())
+ publisher.publish_event(SampleEvent("Hello world!"))
+ publisher.publish_event(ExecutionFinishedEvent())
+
+ captured = capsys.readouterr()
+ assert captured.out.strip() == "Execution started\nExecution finished"
+
+
+def test_compare_without_id() -> None:
+ """Test event comparison without event_id."""
+ event1 = SampleEvent("message")
+ event2 = SampleEvent("message")
+
+ assert event1 != event2
+ assert event1.compare_without_id(event2)
+
+ assert not event1.compare_without_id("message") # type: ignore
diff --git a/tests/mlia/test_core_helpers.py b/tests/mlia/test_core_helpers.py
new file mode 100644
index 0000000..8577617
--- /dev/null
+++ b/tests/mlia/test_core_helpers.py
@@ -0,0 +1,17 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the helper classes."""
+from mlia.core.helpers import APIActionResolver
+
+
+def test_api_action_resolver() -> None:
+ """Test APIActionResolver class."""
+ helper = APIActionResolver()
+
+ # pylint: disable=use-implicit-booleaness-not-comparison
+ assert helper.apply_optimizations() == []
+ assert helper.supported_operators_info() == []
+ assert helper.check_performance() == []
+ assert helper.check_operator_compatibility() == []
+ assert helper.operator_compatibility_details() == []
+ assert helper.optimization_details() == []
diff --git a/tests/mlia/test_core_mixins.py b/tests/mlia/test_core_mixins.py
new file mode 100644
index 0000000..d66213d
--- /dev/null
+++ b/tests/mlia/test_core_mixins.py
@@ -0,0 +1,99 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the module mixins."""
+import pytest
+
+from mlia.core.common import AdviceCategory
+from mlia.core.context import Context
+from mlia.core.context import ExecutionContext
+from mlia.core.mixins import ContextMixin
+from mlia.core.mixins import ParameterResolverMixin
+
+
+def test_context_mixin(dummy_context: Context) -> None:
+ """Test ContextMixin."""
+
+ class SampleClass(ContextMixin):
+ """Sample class."""
+
+ sample_object = SampleClass()
+ sample_object.set_context(dummy_context)
+ assert sample_object.context == dummy_context
+
+
+class TestParameterResolverMixin:
+ """Tests for parameter resolver mixin."""
+
+ @staticmethod
+ def test_parameter_resolver_mixin(dummy_context: ExecutionContext) -> None:
+ """Test ParameterResolverMixin."""
+
+ class SampleClass(ParameterResolverMixin):
+ """Sample class."""
+
+ def __init__(self) -> None:
+ """Init sample object."""
+ self.context = dummy_context
+
+ self.context.update(
+ advice_category=AdviceCategory.OPERATORS,
+ event_handlers=[],
+ config_parameters={"section": {"param": 123}},
+ )
+
+ sample_object = SampleClass()
+ value = sample_object.get_parameter("section", "param")
+ assert value == 123
+
+ with pytest.raises(
+ Exception, match="Parameter param expected to have type <class 'str'>"
+ ):
+ value = sample_object.get_parameter("section", "param", expected_type=str)
+
+ with pytest.raises(Exception, match="Parameter no_param is not set"):
+ value = sample_object.get_parameter("section", "no_param")
+
+ @staticmethod
+ def test_parameter_resolver_mixin_no_config(
+ dummy_context: ExecutionContext,
+ ) -> None:
+ """Test ParameterResolverMixin without config params."""
+
+ class SampleClassNoConfig(ParameterResolverMixin):
+ """Sample context without config params."""
+
+ def __init__(self) -> None:
+ """Init sample object."""
+ self.context = dummy_context
+
+ with pytest.raises(Exception, match="Configuration parameters are not set"):
+ sample_object_no_config = SampleClassNoConfig()
+ sample_object_no_config.get_parameter("section", "param", expected_type=str)
+
+ @staticmethod
+ def test_parameter_resolver_mixin_bad_section(
+ dummy_context: ExecutionContext,
+ ) -> None:
+ """Test ParameterResolverMixin without config params."""
+
+ class SampleClassBadSection(ParameterResolverMixin):
+ """Sample context with bad section in config."""
+
+ def __init__(self) -> None:
+ """Init sample object."""
+ self.context = dummy_context
+ self.context.update(
+ advice_category=AdviceCategory.OPERATORS,
+ event_handlers=[],
+ config_parameters={"section": ["param"]},
+ )
+
+ with pytest.raises(
+ Exception,
+ match="Parameter section section has wrong format, "
+ "expected to be a dictionary",
+ ):
+ sample_object_bad_section = SampleClassBadSection()
+ sample_object_bad_section.get_parameter(
+ "section", "param", expected_type=str
+ )
diff --git a/tests/mlia/test_core_performance.py b/tests/mlia/test_core_performance.py
new file mode 100644
index 0000000..0d28fe8
--- /dev/null
+++ b/tests/mlia/test_core_performance.py
@@ -0,0 +1,29 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the module performance."""
+from pathlib import Path
+
+from mlia.core.performance import estimate_performance
+from mlia.core.performance import PerformanceEstimator
+
+
+def test_estimate_performance(tmp_path: Path) -> None:
+ """Test function estimate_performance."""
+ model_path = tmp_path / "original.tflite"
+
+ class SampleEstimator(PerformanceEstimator[Path, int]):
+ """Sample estimator."""
+
+ def estimate(self, model: Path) -> int:
+ """Estimate performance."""
+ if model.name == "original.tflite":
+ return 1
+
+ return 2
+
+ def optimized_model(_original: Path) -> Path:
+ """Return path to the 'optimized' model."""
+ return tmp_path / "optimized.tflite"
+
+ results = estimate_performance(model_path, SampleEstimator(), [optimized_model])
+ assert results == [1, 2]
diff --git a/tests/mlia/test_core_reporting.py b/tests/mlia/test_core_reporting.py
new file mode 100644
index 0000000..2f7ec22
--- /dev/null
+++ b/tests/mlia/test_core_reporting.py
@@ -0,0 +1,413 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for reporting module."""
+from typing import List
+
+import pytest
+
+from mlia.core.reporting import BytesCell
+from mlia.core.reporting import Cell
+from mlia.core.reporting import ClockCell
+from mlia.core.reporting import Column
+from mlia.core.reporting import CyclesCell
+from mlia.core.reporting import Format
+from mlia.core.reporting import NestedReport
+from mlia.core.reporting import ReportItem
+from mlia.core.reporting import SingleRow
+from mlia.core.reporting import Table
+from mlia.utils.console import remove_ascii_codes
+
+
+@pytest.mark.parametrize(
+ "cell, expected_repr",
+ [
+ (BytesCell(None), ""),
+ (BytesCell(0), "0 bytes"),
+ (BytesCell(1), "1 byte"),
+ (BytesCell(100000), "100,000 bytes"),
+ (ClockCell(None), ""),
+ (ClockCell(0), "0 Hz"),
+ (ClockCell(1), "1 Hz"),
+ (ClockCell(100000), "100,000 Hz"),
+ (CyclesCell(None), ""),
+ (CyclesCell(0), "0 cycles"),
+ (CyclesCell(1), "1 cycle"),
+ (CyclesCell(100000), "100,000 cycles"),
+ ],
+)
+def test_predefined_cell_types(cell: Cell, expected_repr: str) -> None:
+ """Test predefined cell types."""
+ assert str(cell) == expected_repr
+
+
+@pytest.mark.parametrize(
+ "with_notes, expected_text_report",
+ [
+ [
+ True,
+ """
+Sample table:
+┌──────────┬──────────┬──────────┐
+│ Header 1 │ Header 2 │ Header 3 │
+╞══════════╪══════════╪══════════╡
+│ 1 │ 2 │ 3 │
+├──────────┼──────────┼──────────┤
+│ 4 │ 5 │ 123,123 │
+└──────────┴──────────┴──────────┘
+Sample notes
+ """.strip(),
+ ],
+ [
+ False,
+ """
+Sample table:
+┌──────────┬──────────┬──────────┐
+│ Header 1 │ Header 2 │ Header 3 │
+╞══════════╪══════════╪══════════╡
+│ 1 │ 2 │ 3 │
+├──────────┼──────────┼──────────┤
+│ 4 │ 5 │ 123,123 │
+└──────────┴──────────┴──────────┘
+ """.strip(),
+ ],
+ ],
+)
+def test_table_representation(with_notes: bool, expected_text_report: str) -> None:
+ """Test table report representation."""
+
+ def sample_table(with_notes: bool) -> Table:
+ columns = [
+ Column("Header 1", alias="header1", only_for=["plain_text"]),
+ Column("Header 2", alias="header2", fmt=Format(wrap_width=5)),
+ Column("Header 3", alias="header3"),
+ ]
+ rows = [(1, 2, 3), (4, 5, Cell(123123, fmt=Format(str_fmt="0,d")))]
+
+ return Table(
+ columns,
+ rows,
+ name="Sample table",
+ alias="sample_table",
+ notes="Sample notes" if with_notes else None,
+ )
+
+ table = sample_table(with_notes)
+ csv_repr = table.to_csv()
+ assert csv_repr == [["Header 2", "Header 3"], [2, 3], [5, 123123]]
+
+ json_repr = table.to_json()
+ assert json_repr == {
+ "sample_table": [
+ {"header2": 2, "header3": 3},
+ {"header2": 5, "header3": 123123},
+ ]
+ }
+
+ text_report = remove_ascii_codes(table.to_plain_text())
+ assert text_report == expected_text_report
+
+
+def test_csv_nested_table_representation() -> None:
+ """Test representation of the nested tables in csv format."""
+
+ def sample_table(num_of_cols: int) -> Table:
+ columns = [
+ Column("Header 1", alias="header1"),
+ Column("Header 2", alias="header2"),
+ ]
+
+ rows = [
+ (
+ 1,
+ Table(
+ columns=[
+ Column(f"Nested column {i+1}") for i in range(num_of_cols)
+ ],
+ rows=[[f"value{i+1}" for i in range(num_of_cols)]],
+ name="Nested table",
+ ),
+ )
+ ]
+
+ return Table(columns, rows, name="Sample table", alias="sample_table")
+
+ assert sample_table(num_of_cols=2).to_csv() == [
+ ["Header 1", "Header 2"],
+ [1, "value1;value2"],
+ ]
+
+ assert sample_table(num_of_cols=1).to_csv() == [
+ ["Header 1", "Header 2"],
+ [1, "value1"],
+ ]
+
+ assert sample_table(num_of_cols=0).to_csv() == [
+ ["Header 1", "Header 2"],
+ [1, ""],
+ ]
+
+
+@pytest.mark.parametrize(
+ "report, expected_plain_text, expected_json_data, expected_csv_data",
+ [
+ (
+ NestedReport(
+ "Sample report",
+ "sample_report",
+ [
+ ReportItem("Item", "item", "item_value"),
+ ],
+ ),
+ """
+Sample report:
+ Item item_value
+""".strip(),
+ {
+ "sample_report": {"item": "item_value"},
+ },
+ [
+ ("item",),
+ ("item_value",),
+ ],
+ ),
+ (
+ NestedReport(
+ "Sample report",
+ "sample_report",
+ [
+ ReportItem(
+ "Item",
+ "item",
+ "item_value",
+ [ReportItem("Nested item", "nested_item", "nested_item_value")],
+ ),
+ ],
+ ),
+ """
+Sample report:
+ Item item_value
+ Nested item nested_item_value
+""".strip(),
+ {
+ "sample_report": {
+ "item": {"nested_item": "nested_item_value"},
+ },
+ },
+ [
+ ("item", "nested_item"),
+ ("item_value", "nested_item_value"),
+ ],
+ ),
+ (
+ NestedReport(
+ "Sample report",
+ "sample_report",
+ [
+ ReportItem(
+ "Item",
+ "item",
+ "item_value",
+ [ReportItem("Nested item", "nested_item", BytesCell(10))],
+ ),
+ ],
+ ),
+ """
+Sample report:
+ Item item_value
+ Nested item 10 bytes
+""".strip(),
+ {
+ "sample_report": {
+ "item": {"nested_item": {"unit": "bytes", "value": 10}},
+ },
+ },
+ [
+ ("item", "nested_item_value", "nested_item_unit"),
+ ("item_value", 10, "bytes"),
+ ],
+ ),
+ (
+ NestedReport(
+ "Sample report",
+ "sample_report",
+ [
+ ReportItem(
+ "Item",
+ "item",
+ "item_value",
+ [
+ ReportItem(
+ "Nested item",
+ "nested_item",
+ Cell(
+ 10, fmt=Format(str_fmt=lambda x: f"{x} cell value")
+ ),
+ )
+ ],
+ ),
+ ],
+ ),
+ """
+Sample report:
+ Item item_value
+ Nested item 10 cell value
+""".strip(),
+ {
+ "sample_report": {
+ "item": {"nested_item": 10},
+ },
+ },
+ [
+ ("item", "nested_item"),
+ ("item_value", 10),
+ ],
+ ),
+ (
+ NestedReport(
+ "Sample report",
+ "sample_report",
+ [
+ ReportItem(
+ "Item",
+ "item",
+ "item_value",
+ [
+ ReportItem(
+ "Nested item",
+ "nested_item",
+ Cell(
+ 10, fmt=Format(str_fmt=lambda x: f"{x} cell value")
+ ),
+ )
+ ],
+ ),
+ ],
+ ),
+ """
+Sample report:
+ Item item_value
+ Nested item 10 cell value
+""".strip(),
+ {
+ "sample_report": {
+ "item": {"nested_item": 10},
+ },
+ },
+ [
+ ("item", "nested_item"),
+ ("item_value", 10),
+ ],
+ ),
+ (
+ NestedReport(
+ "Sample report",
+ "sample_report",
+ [
+ ReportItem(
+ "Item",
+ "item",
+ "item_value",
+ [
+ ReportItem("Nested item", "nested_item", Cell(10)),
+ ],
+ ),
+ ],
+ ),
+ """
+Sample report:
+ Item item_value
+ Nested item 10
+""".strip(),
+ {
+ "sample_report": {
+ "item": {"nested_item": 10},
+ },
+ },
+ [
+ ("item", "nested_item"),
+ ("item_value", 10),
+ ],
+ ),
+ (
+ NestedReport(
+ "Sample report",
+ "sample_report",
+ [
+ ReportItem(
+ "Item",
+ "item",
+ "item_value",
+ [
+ ReportItem(
+ "Nested item", "nested_item", Cell(10, fmt=Format())
+ ),
+ ],
+ ),
+ ],
+ ),
+ """
+Sample report:
+ Item item_value
+ Nested item 10
+""".strip(),
+ {
+ "sample_report": {
+ "item": {"nested_item": 10},
+ },
+ },
+ [
+ ("item", "nested_item"),
+ ("item_value", 10),
+ ],
+ ),
+ ],
+)
+def test_nested_report_representation(
+ report: NestedReport,
+ expected_plain_text: str,
+ expected_json_data: dict,
+ expected_csv_data: List,
+) -> None:
+ """Test representation of the NestedReport."""
+ plain_text = report.to_plain_text()
+ assert plain_text == expected_plain_text
+
+ json_data = report.to_json()
+ assert json_data == expected_json_data
+
+ csv_data = report.to_csv()
+ assert csv_data == expected_csv_data
+
+
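+# Illustrative sketch for the BytesCell case above: a cell carrying a unit is
+# assumed to expand into two CSV columns, "<alias>_value" and "<alias>_unit",
+# which is why the expected header row is ("item", "nested_item_value",
+# "nested_item_unit"). The helper below is hypothetical.
+def _expand_unit_cell(alias: str, value: int, unit: str) -> dict:
+ """Expand a unit-carrying cell into the assumed value/unit CSV columns."""
+ return {f"{alias}_value": value, f"{alias}_unit": unit}
+
+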
+def test_single_row_representation() -> None:
+ """Test representation of the SingleRow."""
+ single_row = SingleRow(
+ columns=[
+ Column("column1", "column1"),
+ ],
+ rows=[("value1", "value2")],
+ name="Single row example",
+ alias="simple_row_example",
+ )
+
+ expected_text = """
+Single row example:
+ column1 value1
+""".strip()
+ assert single_row.to_plain_text() == expected_text
+ assert single_row.to_csv() == [["column1"], ["value1"]]
+ assert single_row.to_json() == {"simple_row_example": [{"column1": "value1"}]}
+
+ with pytest.raises(Exception, match="Table should have only one row"):
+ wrong_single_row = SingleRow(
+ columns=[
+ Column("column1", "column1"),
+ ],
+ rows=[
+ ("value1", "value2"),
+ ("value1", "value2"),
+ ],
+ name="Single row example",
+ alias="simple_row_example",
+ )
+ wrong_single_row.to_plain_text()
diff --git a/tests/mlia/test_core_workflow.py b/tests/mlia/test_core_workflow.py
new file mode 100644
index 0000000..470e572
--- /dev/null
+++ b/tests/mlia/test_core_workflow.py
@@ -0,0 +1,164 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module workflow."""
+from dataclasses import dataclass
+from unittest.mock import call
+from unittest.mock import MagicMock
+
+from mlia.core.advice_generation import Advice
+from mlia.core.advice_generation import AdviceEvent
+from mlia.core.advice_generation import ContextAwareAdviceProducer
+from mlia.core.context import ExecutionContext
+from mlia.core.data_analysis import ContextAwareDataAnalyzer
+from mlia.core.data_collection import ContextAwareDataCollector
+from mlia.core.errors import FunctionalityNotSupportedError
+from mlia.core.events import AdviceStageFinishedEvent
+from mlia.core.events import AdviceStageStartedEvent
+from mlia.core.events import AnalyzedDataEvent
+from mlia.core.events import CollectedDataEvent
+from mlia.core.events import DataAnalysisStageFinishedEvent
+from mlia.core.events import DataAnalysisStageStartedEvent
+from mlia.core.events import DataCollectionStageFinishedEvent
+from mlia.core.events import DataCollectionStageStartedEvent
+from mlia.core.events import DataCollectorSkippedEvent
+from mlia.core.events import DefaultEventPublisher
+from mlia.core.events import Event
+from mlia.core.events import EventHandler
+from mlia.core.events import ExecutionFailedEvent
+from mlia.core.events import ExecutionFinishedEvent
+from mlia.core.events import ExecutionStartedEvent
+from mlia.core.workflow import DefaultWorkflowExecutor
+
+
+@dataclass
+class SampleEvent(Event):
+ """Sample event."""
+
+ msg: str
+
+
+def test_workflow_executor(tmpdir: str) -> None:
+ """Test workflow executor."""
+ handler_mock = MagicMock(spec=EventHandler)
+ data_collector_mock = MagicMock(spec=ContextAwareDataCollector)
+ data_collector_mock.collect_data.return_value = 42
+
+ data_collector_mock_no_value = MagicMock(spec=ContextAwareDataCollector)
+ data_collector_mock_no_value.collect_data.return_value = None
+
+ data_collector_mock_skipped = MagicMock(spec=ContextAwareDataCollector)
+ data_collector_mock_skipped.name.return_value = "skipped_collector"
+ data_collector_mock_skipped.collect_data.side_effect = (
+ FunctionalityNotSupportedError("Error!", "Error!")
+ )
+
+ data_analyzer_mock = MagicMock(spec=ContextAwareDataAnalyzer)
+ data_analyzer_mock.get_analyzed_data.return_value = ["Really good number!"]
+
+ advice_producer_mock1 = MagicMock(spec=ContextAwareAdviceProducer)
+ advice_producer_mock1.get_advice.return_value = Advice(["All good!"])
+
+ advice_producer_mock2 = MagicMock(spec=ContextAwareAdviceProducer)
+ advice_producer_mock2.get_advice.return_value = [Advice(["Good advice!"])]
+
+ context = ExecutionContext(
+ working_dir=tmpdir,
+ event_handlers=[handler_mock],
+ event_publisher=DefaultEventPublisher(),
+ )
+
+ executor = DefaultWorkflowExecutor(
+ context,
+ [
+ data_collector_mock,
+ data_collector_mock_no_value,
+ data_collector_mock_skipped,
+ ],
+ [data_analyzer_mock],
+ [
+ advice_producer_mock1,
+ advice_producer_mock2,
+ ],
+ [SampleEvent("Hello from advisor!")],
+ )
+
+ executor.run()
+
+ data_collector_mock.collect_data.assert_called_once()
+ data_collector_mock_no_value.collect_data.assert_called_once()
+ data_collector_mock_skipped.collect_data.assert_called_once()
+
+ data_analyzer_mock.analyze_data.assert_called_once_with(42)
+
+ advice_producer_mock1.produce_advice.assert_called_once_with("Really good number!")
+ advice_producer_mock1.get_advice.assert_called_once()
+
+ advice_producer_mock2.produce_advice.assert_called_once_with("Really good number!")
+ advice_producer_mock2.get_advice.assert_called_once()
+
+ expected_mock_calls = [
+ call(ExecutionStartedEvent()),
+ call(SampleEvent("Hello from advisor!")),
+ call(DataCollectionStageStartedEvent()),
+ call(CollectedDataEvent(data_item=42)),
+ call(DataCollectorSkippedEvent("skipped_collector", "Error!: Error!")),
+ call(DataCollectionStageFinishedEvent()),
+ call(DataAnalysisStageStartedEvent()),
+ call(AnalyzedDataEvent(data_item="Really good number!")),
+ call(DataAnalysisStageFinishedEvent()),
+ call(AdviceStageStartedEvent()),
+ call(AdviceEvent(advice=Advice(messages=["All good!"]))),
+ call(AdviceEvent(advice=Advice(messages=["Good advice!"]))),
+ call(AdviceStageFinishedEvent()),
+ call(ExecutionFinishedEvent()),
+ ]
+
+ for expected_call, actual_call in zip(
+ expected_mock_calls, handler_mock.handle_event.mock_calls
+ ):
+ expected_event = expected_call.args[0]
+ actual_event = actual_call.args[0]
+
+ assert actual_event.compare_without_id(expected_event)
+
+
+def test_workflow_executor_failed(tmpdir: str) -> None:
+ """Test scenario when one of the components raises exception."""
+ handler_mock = MagicMock(spec=EventHandler)
+
+ context = ExecutionContext(
+ working_dir=tmpdir,
+ event_handlers=[handler_mock],
+ event_publisher=DefaultEventPublisher(),
+ )
+
+ collection_exception = Exception("Collection failed")
+
+ data_collector_mock = MagicMock(spec=ContextAwareDataCollector)
+ data_collector_mock.collect_data.side_effect = collection_exception
+
+ executor = DefaultWorkflowExecutor(context, [data_collector_mock], [], [])
+ executor.run()
+
+ expected_mock_calls = [
+ call(ExecutionStartedEvent()),
+ call(DataCollectionStageStartedEvent()),
+ call(ExecutionFailedEvent(collection_exception)),
+ ]
+
+ for expected_call, actual_call in zip(
+ expected_mock_calls, handler_mock.handle_event.mock_calls
+ ):
+ expected_event = expected_call.args[0]
+ actual_event = actual_call.args[0]
+
+ if isinstance(actual_event, ExecutionFailedEvent):
+ # dataclass equality does not compare exception instances reliably,
+ # so compare the exception objects directly
+ actual_exception = actual_event.err
+ expected_exception = expected_event.err
+
+ assert actual_exception == expected_exception
+ continue
+
+ assert actual_event.compare_without_id(expected_event)
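+
+
+# Illustrative sketch: a minimal recording handler that could be used instead
+# of MagicMock to capture the event ordering asserted above. EventHandler and
+# Event are the classes imported at the top of this module; the recorder
+# itself is hypothetical.
+class _RecordingHandler(EventHandler):
+ """Collect every published event in the order it was received."""
+
+ def __init__(self) -> None:
+ self.events: list = []
+
+ def handle_event(self, event: Event) -> None:
+ """Store the received event for later inspection."""
+ self.events.append(event)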
diff --git a/tests/mlia/test_devices_ethosu_advice_generation.py b/tests/mlia/test_devices_ethosu_advice_generation.py
new file mode 100644
index 0000000..98c8a57
--- /dev/null
+++ b/tests/mlia/test_devices_ethosu_advice_generation.py
@@ -0,0 +1,483 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for Ethos-U advice generation."""
+from typing import List
+from typing import Optional
+
+import pytest
+
+from mlia.cli.helpers import CLIActionResolver
+from mlia.core.advice_generation import Advice
+from mlia.core.common import AdviceCategory
+from mlia.core.common import DataItem
+from mlia.core.context import ExecutionContext
+from mlia.core.helpers import ActionResolver
+from mlia.core.helpers import APIActionResolver
+from mlia.devices.ethosu.advice_generation import EthosUAdviceProducer
+from mlia.devices.ethosu.advice_generation import EthosUStaticAdviceProducer
+from mlia.devices.ethosu.data_analysis import AllOperatorsSupportedOnNPU
+from mlia.devices.ethosu.data_analysis import HasCPUOnlyOperators
+from mlia.devices.ethosu.data_analysis import HasUnsupportedOnNPUOperators
+from mlia.devices.ethosu.data_analysis import OptimizationDiff
+from mlia.devices.ethosu.data_analysis import OptimizationResults
+from mlia.devices.ethosu.data_analysis import PerfMetricDiff
+from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+
+
+@pytest.mark.parametrize(
+ "input_data, advice_category, action_resolver, expected_advice",
+ [
+ [
+ AllOperatorsSupportedOnNPU(),
+ AdviceCategory.OPERATORS,
+ APIActionResolver(),
+ [
+ Advice(
+ [
+ "You don't have any unsupported operators, your model will "
+ "run completely on NPU."
+ ]
+ )
+ ],
+ ],
+ [
+ AllOperatorsSupportedOnNPU(),
+ AdviceCategory.OPERATORS,
+ CLIActionResolver(
+ {
+ "target_profile": "sample_target",
+ "model": "sample_model.tflite",
+ }
+ ),
+ [
+ Advice(
+ [
+ "You don't have any unsupported operators, your model will "
+ "run completely on NPU.",
+ "Check the estimated performance by running the "
+ "following command: ",
+ "mlia performance --target-profile sample_target "
+ "sample_model.tflite",
+ ]
+ )
+ ],
+ ],
+ [
+ HasCPUOnlyOperators(cpu_only_ops=["OP1", "OP2", "OP3"]),
+ AdviceCategory.OPERATORS,
+ APIActionResolver(),
+ [
+ Advice(
+ [
+ "You have at least 3 operators that is CPU only: "
+ "OP1,OP2,OP3.",
+ "Using operators that are supported by the NPU will "
+ "improve performance.",
+ ]
+ )
+ ],
+ ],
+ [
+ HasCPUOnlyOperators(cpu_only_ops=["OP1", "OP2", "OP3"]),
+ AdviceCategory.OPERATORS,
+ CLIActionResolver({}),
+ [
+ Advice(
+ [
+ "You have at least 3 operators that is CPU only: "
+ "OP1,OP2,OP3.",
+ "Using operators that are supported by the NPU will "
+ "improve performance.",
+ "For guidance on supported operators, run: mlia operators "
+ "--supported-ops-report",
+ ]
+ )
+ ],
+ ],
+ [
+ HasUnsupportedOnNPUOperators(npu_unsupported_ratio=0.4),
+ AdviceCategory.OPERATORS,
+ APIActionResolver(),
+ [
+ Advice(
+ [
+ "You have 40% of operators that cannot be placed on the NPU.",
+ "For better performance, please review the reasons reported "
+ "in the table, and adjust the model accordingly "
+ "where possible.",
+ ]
+ )
+ ],
+ ],
+ [
+ HasUnsupportedOnNPUOperators(npu_unsupported_ratio=0.4),
+ AdviceCategory.OPERATORS,
+ CLIActionResolver({}),
+ [
+ Advice(
+ [
+ "You have 40% of operators that cannot be placed on the NPU.",
+ "For better performance, please review the reasons reported "
+ "in the table, and adjust the model accordingly "
+ "where possible.",
+ ]
+ )
+ ],
+ ],
+ [
+ OptimizationResults(
+ [
+ OptimizationDiff(
+ opt_type=[OptimizationSettings("pruning", 0.5, None)],
+ opt_diffs={
+ "sram": PerfMetricDiff(100, 150),
+ "dram": PerfMetricDiff(100, 50),
+ "on_chip_flash": PerfMetricDiff(100, 100),
+ "off_chip_flash": PerfMetricDiff(100, 100),
+ "npu_total_cycles": PerfMetricDiff(10, 5),
+ },
+ ),
+ ]
+ ),
+ AdviceCategory.OPTIMIZATION,
+ APIActionResolver(),
+ [
+ Advice(
+ [
+ "With the selected optimization (pruning: 0.5)",
+ "- You have achieved 50.00% performance improvement in "
+ "DRAM used (KB)",
+ "- You have achieved 50.00% performance improvement in "
+ "NPU total cycles",
+ "- SRAM used (KB) have degraded by 50.00%",
+ "You can try to push the optimization target higher "
+ "(e.g. pruning: 0.6) "
+ "to check if those results can be further improved.",
+ ]
+ ),
+ Advice(
+ [
+ "The applied tooling techniques have an impact "
+ "on accuracy. Additional hyperparameter tuning may be required "
+ "after any optimization."
+ ]
+ ),
+ ],
+ ],
+ [
+ OptimizationResults(
+ [
+ OptimizationDiff(
+ opt_type=[OptimizationSettings("pruning", 0.5, None)],
+ opt_diffs={
+ "sram": PerfMetricDiff(100, 150),
+ "dram": PerfMetricDiff(100, 50),
+ "on_chip_flash": PerfMetricDiff(100, 100),
+ "off_chip_flash": PerfMetricDiff(100, 100),
+ "npu_total_cycles": PerfMetricDiff(10, 5),
+ },
+ ),
+ ]
+ ),
+ AdviceCategory.OPTIMIZATION,
+ CLIActionResolver({"model": "sample_model.h5"}),
+ [
+ Advice(
+ [
+ "With the selected optimization (pruning: 0.5)",
+ "- You have achieved 50.00% performance improvement in "
+ "DRAM used (KB)",
+ "- You have achieved 50.00% performance improvement in "
+ "NPU total cycles",
+ "- SRAM used (KB) have degraded by 50.00%",
+ "You can try to push the optimization target higher "
+ "(e.g. pruning: 0.6) "
+ "to check if those results can be further improved.",
+ "For more info: mlia optimization --help",
+ "Optimization command: "
+ "mlia optimization --optimization-type pruning "
+ "--optimization-target 0.6 sample_model.h5",
+ ]
+ ),
+ Advice(
+ [
+ "The applied tooling techniques have an impact "
+ "on accuracy. Additional hyperparameter tuning may be required "
+ "after any optimization."
+ ]
+ ),
+ ],
+ ],
+ [
+ OptimizationResults(
+ [
+ OptimizationDiff(
+ opt_type=[
+ OptimizationSettings("pruning", 0.5, None),
+ OptimizationSettings("clustering", 32, None),
+ ],
+ opt_diffs={
+ "sram": PerfMetricDiff(100, 150),
+ "dram": PerfMetricDiff(100, 50),
+ "on_chip_flash": PerfMetricDiff(100, 100),
+ "off_chip_flash": PerfMetricDiff(100, 100),
+ "npu_total_cycles": PerfMetricDiff(10, 5),
+ },
+ ),
+ ]
+ ),
+ AdviceCategory.OPTIMIZATION,
+ APIActionResolver(),
+ [
+ Advice(
+ [
+ "With the selected optimization (pruning: 0.5, clustering: 32)",
+ "- You have achieved 50.00% performance improvement in "
+ "DRAM used (KB)",
+ "- You have achieved 50.00% performance improvement in "
+ "NPU total cycles",
+ "- SRAM used (KB) have degraded by 50.00%",
+ "You can try to push the optimization target higher "
+ "(e.g. pruning: 0.6 and/or clustering: 16) "
+ "to check if those results can be further improved.",
+ ]
+ ),
+ Advice(
+ [
+ "The applied tooling techniques have an impact "
+ "on accuracy. Additional hyperparameter tuning may be required "
+ "after any optimization."
+ ]
+ ),
+ ],
+ ],
+ [
+ OptimizationResults(
+ [
+ OptimizationDiff(
+ opt_type=[
+ OptimizationSettings("clustering", 2, None),
+ ],
+ opt_diffs={
+ "sram": PerfMetricDiff(100, 150),
+ "dram": PerfMetricDiff(100, 50),
+ "on_chip_flash": PerfMetricDiff(100, 100),
+ "off_chip_flash": PerfMetricDiff(100, 100),
+ "npu_total_cycles": PerfMetricDiff(10, 5),
+ },
+ ),
+ ]
+ ),
+ AdviceCategory.OPTIMIZATION,
+ APIActionResolver(),
+ [
+ Advice(
+ [
+ "With the selected optimization (clustering: 2)",
+ "- You have achieved 50.00% performance improvement in "
+ "DRAM used (KB)",
+ "- You have achieved 50.00% performance improvement in "
+ "NPU total cycles",
+ "- SRAM used (KB) have degraded by 50.00%",
+ ]
+ ),
+ Advice(
+ [
+ "The applied tooling techniques have an impact "
+ "on accuracy. Additional hyperparameter tuning may be required "
+ "after any optimization."
+ ]
+ ),
+ ],
+ ],
+ [
+ OptimizationResults(
+ [
+ OptimizationDiff(
+ opt_type=[OptimizationSettings("pruning", 0.5, None)],
+ opt_diffs={
+ "sram": PerfMetricDiff(100, 150),
+ "dram": PerfMetricDiff(100, 150),
+ "on_chip_flash": PerfMetricDiff(100, 150),
+ "off_chip_flash": PerfMetricDiff(100, 150),
+ "npu_total_cycles": PerfMetricDiff(10, 100),
+ },
+ ),
+ ]
+ ),
+ AdviceCategory.OPTIMIZATION,
+ APIActionResolver(),
+ [
+ Advice(
+ [
+ "With the selected optimization (pruning: 0.5)",
+ "- DRAM used (KB) have degraded by 50.00%",
+ "- SRAM used (KB) have degraded by 50.00%",
+ "- On chip flash used (KB) have degraded by 50.00%",
+ "- Off chip flash used (KB) have degraded by 50.00%",
+ "- NPU total cycles have degraded by 900.00%",
+ "The performance seems to have degraded after "
+ "applying the selected optimizations, "
+ "try exploring different optimization types/targets.",
+ ]
+ ),
+ Advice(
+ [
+ "The applied tooling techniques have an impact "
+ "on accuracy. Additional hyperparameter tuning may be required "
+ "after any optimization."
+ ]
+ ),
+ ],
+ ],
+ [
+ OptimizationResults(
+ [
+ OptimizationDiff(
+ opt_type=[OptimizationSettings("pruning", 0.5, None)],
+ opt_diffs={
+ "sram": PerfMetricDiff(100, 150),
+ "dram": PerfMetricDiff(100, 150),
+ "on_chip_flash": PerfMetricDiff(100, 150),
+ "off_chip_flash": PerfMetricDiff(100, 150),
+ "npu_total_cycles": PerfMetricDiff(10, 100),
+ },
+ ),
+ OptimizationDiff(
+ opt_type=[OptimizationSettings("pruning", 0.6, None)],
+ opt_diffs={
+ "sram": PerfMetricDiff(100, 150),
+ "dram": PerfMetricDiff(100, 150),
+ "on_chip_flash": PerfMetricDiff(100, 150),
+ "off_chip_flash": PerfMetricDiff(100, 150),
+ "npu_total_cycles": PerfMetricDiff(10, 100),
+ },
+ ),
+ ]
+ ),
+ AdviceCategory.OPTIMIZATION,
+ APIActionResolver(),
+ [], # no advice for more than one optimization result
+ ],
+ ],
+)
+def test_ethosu_advice_producer(
+ tmpdir: str,
+ input_data: DataItem,
+ expected_advice: List[Advice],
+ advice_category: AdviceCategory,
+ action_resolver: ActionResolver,
+) -> None:
+ """Test Ethos-U Advice producer."""
+ producer = EthosUAdviceProducer()
+
+ context = ExecutionContext(
+ advice_category=advice_category,
+ working_dir=tmpdir,
+ action_resolver=action_resolver,
+ )
+
+ producer.set_context(context)
+ producer.produce_advice(input_data)
+
+ assert producer.get_advice() == expected_advice
+
+
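+# Illustrative arithmetic for the percentages asserted above: changes are
+# assumed to be reported relative to the original metric value.
+_IMPROVEMENT_10_TO_5 = (10 - 5) / 10 * 100 # 50.0 -> "50.00% performance improvement"
+_DEGRADATION_10_TO_100 = (100 - 10) / 10 * 100 # 900.0 -> "degraded by 900.00%"
+
+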
+@pytest.mark.parametrize(
+ "advice_category, action_resolver, expected_advice",
+ [
+ [
+ None,
+ None,
+ [],
+ ],
+ [
+ AdviceCategory.OPERATORS,
+ None,
+ [],
+ ],
+ [
+ AdviceCategory.PERFORMANCE,
+ APIActionResolver(),
+ [
+ Advice(
+ [
+ "You can improve the inference time by using only operators "
+ "that are supported by the NPU.",
+ ]
+ ),
+ Advice(
+ [
+ "Check if you can improve the performance by applying "
+ "tooling techniques to your model."
+ ]
+ ),
+ ],
+ ],
+ [
+ AdviceCategory.PERFORMANCE,
+ CLIActionResolver({"model": "test_model.h5"}),
+ [
+ Advice(
+ [
+ "You can improve the inference time by using only operators "
+ "that are supported by the NPU.",
+ "Try running the following command to verify that:",
+ "mlia operators test_model.h5",
+ ]
+ ),
+ Advice(
+ [
+ "Check if you can improve the performance by applying "
+ "tooling techniques to your model.",
+ "For example: mlia optimization --optimization-type "
+ "pruning,clustering --optimization-target 0.5,32 "
+ "test_model.h5",
+ "For more info: mlia optimization --help",
+ ]
+ ),
+ ],
+ ],
+ [
+ AdviceCategory.OPTIMIZATION,
+ APIActionResolver(),
+ [
+ Advice(
+ [
+ "For better performance, make sure that all the operators "
+ "of your final TFLite model are supported by the NPU.",
+ ]
+ )
+ ],
+ ],
+ [
+ AdviceCategory.OPTIMIZATION,
+ CLIActionResolver({"model": "test_model.h5"}),
+ [
+ Advice(
+ [
+ "For better performance, make sure that all the operators "
+ "of your final TFLite model are supported by the NPU.",
+ "For more details, run: mlia operators --help",
+ ]
+ )
+ ],
+ ],
+ ],
+)
+def test_ethosu_static_advice_producer(
+ tmpdir: str,
+ advice_category: Optional[AdviceCategory],
+ action_resolver: ActionResolver,
+ expected_advice: List[Advice],
+) -> None:
+ """Test static advice generation."""
+ producer = EthosUStaticAdviceProducer()
+
+ context = ExecutionContext(
+ advice_category=advice_category,
+ working_dir=tmpdir,
+ action_resolver=action_resolver,
+ )
+ producer.set_context(context)
+ assert producer.get_advice() == expected_advice
diff --git a/tests/mlia/test_devices_ethosu_advisor.py b/tests/mlia/test_devices_ethosu_advisor.py
new file mode 100644
index 0000000..74d2408
--- /dev/null
+++ b/tests/mlia/test_devices_ethosu_advisor.py
@@ -0,0 +1,9 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for Ethos-U MLIA module."""
+from mlia.devices.ethosu.advisor import EthosUInferenceAdvisor
+
+
+def test_advisor_metadata() -> None:
+ """Test advisor metadata."""
+ assert EthosUInferenceAdvisor.name() == "ethos_u_inference_advisor"
diff --git a/tests/mlia/test_devices_ethosu_config.py b/tests/mlia/test_devices_ethosu_config.py
new file mode 100644
index 0000000..49c999a
--- /dev/null
+++ b/tests/mlia/test_devices_ethosu_config.py
@@ -0,0 +1,124 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for config module."""
+from contextlib import ExitStack as does_not_raise
+from typing import Any
+from typing import Dict
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.devices.ethosu.config import EthosUConfiguration
+from mlia.devices.ethosu.config import get_target
+from mlia.tools.vela_wrapper import VelaCompilerOptions
+from mlia.utils.filesystem import get_vela_config
+
+
+def test_compiler_options_default_init() -> None:
+ """Test compiler options default init."""
+ opts = VelaCompilerOptions()
+
+ assert opts.config_files is None
+ assert opts.system_config == "internal-default"
+ assert opts.memory_mode == "internal-default"
+ assert opts.accelerator_config is None
+ assert opts.max_block_dependency == 3
+ assert opts.arena_cache_size is None
+ assert opts.tensor_allocator == "HillClimb"
+ assert opts.cpu_tensor_alignment == 16
+ assert opts.optimization_strategy == "Performance"
+ assert opts.output_dir is None
+
+
+def test_ethosu_target() -> None:
+ """Test Ethos-U target configuration init."""
+ default_config = EthosUConfiguration("ethos-u55-256")
+
+ assert default_config.target == "ethos-u55"
+ assert default_config.mac == 256
+ assert default_config.compiler_options is not None
+
+
+def test_get_target() -> None:
+ """Test function get_target."""
+ with pytest.raises(Exception, match="No target profile given"):
+ get_target(None) # type: ignore
+
+ with pytest.raises(Exception, match="Unable to find target profile unknown"):
+ get_target("unknown")
+
+ u65_device = get_target("ethos-u65-512")
+
+ assert isinstance(u65_device, EthosUConfiguration)
+ assert u65_device.target == "ethos-u65"
+ assert u65_device.mac == 512
+ assert u65_device.compiler_options.accelerator_config == "ethos-u65-512"
+ assert u65_device.compiler_options.memory_mode == "Dedicated_Sram"
+ assert u65_device.compiler_options.config_files == str(get_vela_config())
+
+
+@pytest.mark.parametrize(
+ "profile_data, expected_error",
+ [
+ [
+ {},
+ pytest.raises(
+ Exception,
+ match="Mandatory fields missing from target profile: "
+ r"\['mac', 'memory_mode', 'system_config', 'target'\]",
+ ),
+ ],
+ [
+ {"target": "ethos-u65", "mac": 512},
+ pytest.raises(
+ Exception,
+ match="Mandatory fields missing from target profile: "
+ r"\['memory_mode', 'system_config'\]",
+ ),
+ ],
+ [
+ {
+ "target": "ethos-u65",
+ "mac": 2,
+ "system_config": "Ethos_U65_Embedded",
+ "memory_mode": "Shared_Sram",
+ },
+ pytest.raises(
+ Exception,
+ match=r"Mac value for selected device should be in \[256, 512\]",
+ ),
+ ],
+ [
+ {
+ "target": "ethos-u55",
+ "mac": 1,
+ "system_config": "Ethos_U55_High_End_Embedded",
+ "memory_mode": "Shared_Sram",
+ },
+ pytest.raises(
+ Exception,
+ match="Mac value for selected device should be "
+ r"in \[32, 64, 128, 256\]",
+ ),
+ ],
+ [
+ {
+ "target": "ethos-u65",
+ "mac": 512,
+ "system_config": "Ethos_U65_Embedded",
+ "memory_mode": "Shared_Sram",
+ },
+ does_not_raise(),
+ ],
+ ],
+)
+def test_ethosu_configuration(
+ monkeypatch: pytest.MonkeyPatch, profile_data: Dict[str, Any], expected_error: Any
+) -> None:
+ """Test creating Ethos-U configuration."""
+ monkeypatch.setattr(
+ "mlia.devices.ethosu.config.get_profile", MagicMock(return_value=profile_data)
+ )
+
+ with expected_error:
+ EthosUConfiguration("target")
diff --git a/tests/mlia/test_devices_ethosu_data_analysis.py b/tests/mlia/test_devices_ethosu_data_analysis.py
new file mode 100644
index 0000000..4b1d38b
--- /dev/null
+++ b/tests/mlia/test_devices_ethosu_data_analysis.py
@@ -0,0 +1,147 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for Ethos-U data analysis module."""
+from typing import List
+
+import pytest
+
+from mlia.core.common import DataItem
+from mlia.core.data_analysis import Fact
+from mlia.devices.ethosu.config import EthosUConfiguration
+from mlia.devices.ethosu.data_analysis import AllOperatorsSupportedOnNPU
+from mlia.devices.ethosu.data_analysis import EthosUDataAnalyzer
+from mlia.devices.ethosu.data_analysis import HasCPUOnlyOperators
+from mlia.devices.ethosu.data_analysis import HasUnsupportedOnNPUOperators
+from mlia.devices.ethosu.data_analysis import OptimizationDiff
+from mlia.devices.ethosu.data_analysis import OptimizationResults
+from mlia.devices.ethosu.data_analysis import PerfMetricDiff
+from mlia.devices.ethosu.performance import MemoryUsage
+from mlia.devices.ethosu.performance import NPUCycles
+from mlia.devices.ethosu.performance import OptimizationPerformanceMetrics
+from mlia.devices.ethosu.performance import PerformanceMetrics
+from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+from mlia.tools.vela_wrapper import NpuSupported
+from mlia.tools.vela_wrapper import Operator
+from mlia.tools.vela_wrapper import Operators
+
+
+def test_perf_metrics_diff() -> None:
+ """Test PerfMetricsDiff class."""
+ diff_same = PerfMetricDiff(1, 1)
+ assert diff_same.same is True
+ assert diff_same.improved is False
+ assert diff_same.degraded is False
+ assert diff_same.diff == 0
+
+ diff_improved = PerfMetricDiff(10, 5)
+ assert diff_improved.same is False
+ assert diff_improved.improved is True
+ assert diff_improved.degraded is False
+ assert diff_improved.diff == 50.0
+
+ diff_degraded = PerfMetricDiff(5, 10)
+ assert diff_degraded.same is False
+ assert diff_degraded.improved is False
+ assert diff_degraded.degraded is True
+ assert diff_degraded.diff == -100.0
+
+ diff_original_zero = PerfMetricDiff(0, 1)
+ assert diff_original_zero.diff == 0
+
+
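+# Illustrative sketch: the assertions above are consistent with the diff being
+# a percentage of the original value, guarded for a zero original. This is an
+# assumed formula, not necessarily the real PerfMetricDiff implementation.
+def _expected_diff_pct(original: float, optimized: float) -> float:
+ """Return the assumed percentage change; positive means improvement."""
+ if original == 0:
+ return 0.0
+ return (original - optimized) / original * 100
+
+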
+@pytest.mark.parametrize(
+ "input_data, expected_facts",
+ [
+ [
+ Operators(
+ [
+ Operator(
+ "CPU operator",
+ "CPU operator type",
+ NpuSupported(False, [("CPU only operator", "")]),
+ )
+ ]
+ ),
+ [
+ HasCPUOnlyOperators(["CPU operator type"]),
+ HasUnsupportedOnNPUOperators(1.0),
+ ],
+ ],
+ [
+ Operators(
+ [
+ Operator(
+ "NPU operator",
+ "NPU operator type",
+ NpuSupported(True, []),
+ )
+ ]
+ ),
+ [
+ AllOperatorsSupportedOnNPU(),
+ ],
+ ],
+ [
+ OptimizationPerformanceMetrics(
+ PerformanceMetrics(
+ EthosUConfiguration("ethos-u55-256"),
+ NPUCycles(1, 2, 3, 4, 5, 6),
+ # memory metrics are in kilobytes
+ MemoryUsage(*[i * 1024 for i in range(1, 6)]), # type: ignore
+ ),
+ [
+ [
+ [
+ OptimizationSettings("pruning", 0.5, None),
+ ],
+ PerformanceMetrics(
+ EthosUConfiguration("ethos-u55-256"),
+ NPUCycles(1, 2, 3, 4, 5, 6),
+ # memory metrics are in kilobytes
+ MemoryUsage(
+ *[i * 1024 for i in range(1, 6)] # type: ignore
+ ),
+ ),
+ ],
+ ],
+ ),
+ [
+ OptimizationResults(
+ [
+ OptimizationDiff(
+ opt_type=[
+ OptimizationSettings("pruning", 0.5, None),
+ ],
+ opt_diffs={
+ "sram": PerfMetricDiff(1.0, 1.0),
+ "dram": PerfMetricDiff(2.0, 2.0),
+ "on_chip_flash": PerfMetricDiff(4.0, 4.0),
+ "off_chip_flash": PerfMetricDiff(5.0, 5.0),
+ "npu_total_cycles": PerfMetricDiff(3, 3),
+ },
+ )
+ ]
+ )
+ ],
+ ],
+ [
+ OptimizationPerformanceMetrics(
+ PerformanceMetrics(
+ EthosUConfiguration("ethos-u55-256"),
+ NPUCycles(1, 2, 3, 4, 5, 6),
+ # memory metrics are in kilobytes
+ MemoryUsage(*[i * 1024 for i in range(1, 6)]), # type: ignore
+ ),
+ [],
+ ),
+ [],
+ ],
+ ],
+)
+def test_ethos_u_data_analyzer(
+ input_data: DataItem, expected_facts: List[Fact]
+) -> None:
+ """Test Ethos-U data analyzer."""
+ analyzer = EthosUDataAnalyzer()
+ analyzer.analyze_data(input_data)
+ assert analyzer.get_analyzed_data() == expected_facts
diff --git a/tests/mlia/test_devices_ethosu_data_collection.py b/tests/mlia/test_devices_ethosu_data_collection.py
new file mode 100644
index 0000000..897cf41
--- /dev/null
+++ b/tests/mlia/test_devices_ethosu_data_collection.py
@@ -0,0 +1,151 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the data collection module for Ethos-U."""
+from pathlib import Path
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.core.context import Context
+from mlia.core.data_collection import DataCollector
+from mlia.core.errors import FunctionalityNotSupportedError
+from mlia.devices.ethosu.config import EthosUConfiguration
+from mlia.devices.ethosu.data_collection import EthosUOperatorCompatibility
+from mlia.devices.ethosu.data_collection import EthosUOptimizationPerformance
+from mlia.devices.ethosu.data_collection import EthosUPerformance
+from mlia.devices.ethosu.performance import MemoryUsage
+from mlia.devices.ethosu.performance import NPUCycles
+from mlia.devices.ethosu.performance import OptimizationPerformanceMetrics
+from mlia.devices.ethosu.performance import PerformanceMetrics
+from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+from mlia.tools.vela_wrapper import Operators
+
+
+@pytest.mark.parametrize(
+ "collector, expected_name",
+ [
+ (
+ EthosUOperatorCompatibility,
+ "ethos_u_operator_compatibility",
+ ),
+ (
+ EthosUPerformance,
+ "ethos_u_performance",
+ ),
+ (
+ EthosUOptimizationPerformance,
+ "ethos_u_model_optimizations",
+ ),
+ ],
+)
+def test_collectors_metadata(
+ collector: DataCollector,
+ expected_name: str,
+) -> None:
+ """Test collectors metadata."""
+ assert collector.name() == expected_name
+
+
+def test_operator_compatibility_collector(
+ dummy_context: Context, test_tflite_model: Path
+) -> None:
+ """Test operator compatibility data collector."""
+ device = EthosUConfiguration("ethos-u55-256")
+
+ collector = EthosUOperatorCompatibility(test_tflite_model, device)
+ collector.set_context(dummy_context)
+
+ result = collector.collect_data()
+ assert isinstance(result, Operators)
+
+
+def test_performance_collector(
+ monkeypatch: pytest.MonkeyPatch, dummy_context: Context, test_tflite_model: Path
+) -> None:
+ """Test performance data collector."""
+ device = EthosUConfiguration("ethos-u55-256")
+
+ mock_performance_estimation(monkeypatch, device)
+
+ collector = EthosUPerformance(test_tflite_model, device)
+ collector.set_context(dummy_context)
+
+ result = collector.collect_data()
+ assert isinstance(result, PerformanceMetrics)
+
+
+def test_optimization_performance_collector(
+ monkeypatch: pytest.MonkeyPatch,
+ dummy_context: Context,
+ test_keras_model: Path,
+ test_tflite_model: Path,
+) -> None:
+ """Test optimization performance data collector."""
+ device = EthosUConfiguration("ethos-u55-256")
+
+ mock_performance_estimation(monkeypatch, device)
+ collector = EthosUOptimizationPerformance(
+ test_keras_model,
+ device,
+ [
+ [
+ {"optimization_type": "pruning", "optimization_target": 0.5},
+ ]
+ ],
+ )
+ collector.set_context(dummy_context)
+ result = collector.collect_data()
+
+ assert isinstance(result, OptimizationPerformanceMetrics)
+ assert isinstance(result.original_perf_metrics, PerformanceMetrics)
+ assert isinstance(result.optimizations_perf_metrics, list)
+ assert len(result.optimizations_perf_metrics) == 1
+
+ opt, metrics = result.optimizations_perf_metrics[0]
+ assert opt == [OptimizationSettings("pruning", 0.5, None)]
+ assert isinstance(metrics, PerformanceMetrics)
+
+ collector_no_optimizations = EthosUOptimizationPerformance(
+ test_keras_model,
+ device,
+ [],
+ )
+ with pytest.raises(FunctionalityNotSupportedError):
+ collector_no_optimizations.collect_data()
+
+ collector_tflite = EthosUOptimizationPerformance(
+ test_tflite_model,
+ device,
+ [
+ [
+ {"optimization_type": "pruning", "optimization_target": 0.5},
+ ]
+ ],
+ )
+ collector_tflite.set_context(dummy_context)
+ with pytest.raises(FunctionalityNotSupportedError):
+ collector_tflite.collect_data()
+
+ with pytest.raises(
+ Exception, match="Optimization parameters expected to be a list"
+ ):
+ collector_bad_config = EthosUOptimizationPerformance(
+ test_keras_model, device, {"optimization_type": "pruning"} # type: ignore
+ )
+ collector_bad_config.set_context(dummy_context)
+ collector_bad_config.collect_data()
+
+
+def mock_performance_estimation(
+ monkeypatch: pytest.MonkeyPatch, device: EthosUConfiguration
+) -> None:
+ """Mock performance estimation."""
+ metrics = PerformanceMetrics(
+ device,
+ NPUCycles(1, 2, 3, 4, 5, 6),
+ MemoryUsage(1, 2, 3, 4, 5),
+ )
+ monkeypatch.setattr(
+ "mlia.devices.ethosu.data_collection.EthosUPerformanceEstimator.estimate",
+ MagicMock(return_value=metrics),
+ )
diff --git a/tests/mlia/test_devices_ethosu_performance.py b/tests/mlia/test_devices_ethosu_performance.py
new file mode 100644
index 0000000..e27efa0
--- /dev/null
+++ b/tests/mlia/test_devices_ethosu_performance.py
@@ -0,0 +1,28 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Performance estimation tests."""
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.devices.ethosu.performance import MemorySizeType
+from mlia.devices.ethosu.performance import MemoryUsage
+
+
+def test_memory_usage_conversion() -> None:
+ """Test MemoryUsage objects conversion."""
+ memory_usage_in_kb = MemoryUsage(1, 2, 3, 4, 5, MemorySizeType.KILOBYTES)
+ assert memory_usage_in_kb.in_kilobytes() == memory_usage_in_kb
+
+ memory_usage_in_bytes = MemoryUsage(
+ 1 * 1024, 2 * 1024, 3 * 1024, 4 * 1024, 5 * 1024
+ )
+ assert memory_usage_in_bytes.in_kilobytes() == memory_usage_in_kb
+
+
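+# Illustrative conversion note: the test above assumes 1 KB == 1024 bytes, so
+# byte values are divided by 1024 by in_kilobytes().
+_ASSUMED_BYTES_PER_KILOBYTE = 1024 # e.g. 5 * 1024 bytes -> 5 KB
+
+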
+def mock_performance_estimation(monkeypatch: pytest.MonkeyPatch) -> None:
+ """Mock performance estimation."""
+ monkeypatch.setattr(
+ "mlia.tools.aiet_wrapper.estimate_performance",
+ MagicMock(return_value=MagicMock()),
+ )
diff --git a/tests/mlia/test_devices_ethosu_reporters.py b/tests/mlia/test_devices_ethosu_reporters.py
new file mode 100644
index 0000000..2d5905c
--- /dev/null
+++ b/tests/mlia/test_devices_ethosu_reporters.py
@@ -0,0 +1,434 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for reports module."""
+import json
+import sys
+from contextlib import ExitStack as doesnt_raise
+from pathlib import Path
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Literal
+
+import pytest
+
+from mlia.core.reporting import get_reporter
+from mlia.core.reporting import produce_report
+from mlia.core.reporting import Report
+from mlia.core.reporting import Reporter
+from mlia.core.reporting import Table
+from mlia.devices.ethosu.config import EthosUConfiguration
+from mlia.devices.ethosu.performance import MemoryUsage
+from mlia.devices.ethosu.performance import NPUCycles
+from mlia.devices.ethosu.performance import PerformanceMetrics
+from mlia.devices.ethosu.reporters import find_appropriate_formatter
+from mlia.devices.ethosu.reporters import report_device_details
+from mlia.devices.ethosu.reporters import report_operators
+from mlia.devices.ethosu.reporters import report_perf_metrics
+from mlia.tools.vela_wrapper import NpuSupported
+from mlia.tools.vela_wrapper import Operator
+from mlia.tools.vela_wrapper import Operators
+from mlia.utils.console import remove_ascii_codes
+
+
+@pytest.mark.parametrize(
+ "data, formatters",
+ [
+ (
+ [Operator("test_operator", "test_type", NpuSupported(False, []))],
+ [report_operators],
+ ),
+ (
+ PerformanceMetrics(
+ EthosUConfiguration("ethos-u55-256"),
+ NPUCycles(0, 0, 0, 0, 0, 0),
+ MemoryUsage(0, 0, 0, 0, 0),
+ ),
+ [report_perf_metrics],
+ ),
+ ],
+)
+@pytest.mark.parametrize(
+ "fmt, output, expected_error",
+ [
+ [
+ "unknown_format",
+ sys.stdout,
+ pytest.raises(Exception, match="Unknown format unknown_format"),
+ ],
+ [
+ "plain_text",
+ sys.stdout,
+ doesnt_raise(),
+ ],
+ [
+ "json",
+ sys.stdout,
+ doesnt_raise(),
+ ],
+ [
+ "csv",
+ sys.stdout,
+ doesnt_raise(),
+ ],
+ [
+ "plain_text",
+ "report.txt",
+ doesnt_raise(),
+ ],
+ [
+ "json",
+ "report.json",
+ doesnt_raise(),
+ ],
+ [
+ "csv",
+ "report.csv",
+ doesnt_raise(),
+ ],
+ ],
+)
+def test_report(
+ data: Any,
+ formatters: List[Callable],
+ fmt: Literal["plain_text", "json", "csv"],
+ output: Any,
+ expected_error: Any,
+ tmp_path: Path,
+) -> None:
+ """Test report function."""
+ if is_file := isinstance(output, str):
+ output = tmp_path / output
+
+ for formatter in formatters:
+ with expected_error:
+ produce_report(data, formatter, fmt, output)
+
+ if is_file:
+ assert output.is_file()
+ assert output.stat().st_size > 0
+
+
+@pytest.mark.parametrize(
+ "ops, expected_plain_text, expected_json_dict, expected_csv_list",
+ [
+ (
+ [
+ Operator(
+ "npu_supported",
+ "test_type",
+ NpuSupported(True, []),
+ ),
+ Operator(
+ "cpu_only",
+ "test_type",
+ NpuSupported(
+ False,
+ [
+ (
+ "CPU only operator",
+ "",
+ ),
+ ],
+ ),
+ ),
+ Operator(
+ "npu_unsupported",
+ "test_type",
+ NpuSupported(
+ False,
+ [
+ (
+ "Not supported operator",
+ "Reason why operator is not supported",
+ )
+ ],
+ ),
+ ),
+ ],
+ """
+Operators:
+┌───┬─────────────────┬───────────────┬───────────┬───────────────────────────────┐
+│ # │ Operator name │ Operator type │ Placement │ Notes │
+╞═══╪═════════════════╪═══════════════╪═══════════╪═══════════════════════════════╡
+│ 1 │ npu_supported │ test_type │ NPU │ │
+├───┼─────────────────┼───────────────┼───────────┼───────────────────────────────┤
+│ 2 │ cpu_only │ test_type │ CPU │ * CPU only operator │
+├───┼─────────────────┼───────────────┼───────────┼───────────────────────────────┤
+│ 3 │ npu_unsupported │ test_type │ CPU │ * Not supported operator │
+│ │ │ │ │ │
+│ │ │ │ │ * Reason why operator is not │
+│ │ │ │ │ supported │
+└───┴─────────────────┴───────────────┴───────────┴───────────────────────────────┘
+""".strip(),
+ {
+ "operators": [
+ {
+ "operator_name": "npu_supported",
+ "operator_type": "test_type",
+ "placement": "NPU",
+ "notes": [],
+ },
+ {
+ "operator_name": "cpu_only",
+ "operator_type": "test_type",
+ "placement": "CPU",
+ "notes": [{"note": "CPU only operator"}],
+ },
+ {
+ "operator_name": "npu_unsupported",
+ "operator_type": "test_type",
+ "placement": "CPU",
+ "notes": [
+ {"note": "Not supported operator"},
+ {"note": "Reason why operator is not supported"},
+ ],
+ },
+ ]
+ },
+ [
+ ["Operator name", "Operator type", "Placement", "Notes"],
+ ["npu_supported", "test_type", "NPU", ""],
+ ["cpu_only", "test_type", "CPU", "CPU only operator"],
+ [
+ "npu_unsupported",
+ "test_type",
+ "CPU",
+ "Not supported operator;Reason why operator is not supported",
+ ],
+ ],
+ ),
+ ],
+)
+def test_report_operators(
+ ops: List[Operator],
+ expected_plain_text: str,
+ expected_json_dict: Dict,
+ expected_csv_list: List,
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """Test report_operatos formatter."""
+ # make terminal wide enough to print whole table
+ monkeypatch.setenv("COLUMNS", "100")
+
+ report = report_operators(ops)
+ assert isinstance(report, Table)
+
+ plain_text = remove_ascii_codes(report.to_plain_text())
+ assert plain_text == expected_plain_text
+
+ json_dict = report.to_json()
+ assert json_dict == expected_json_dict
+
+ csv_list = report.to_csv()
+ assert csv_list == expected_csv_list
+
+
+@pytest.mark.parametrize(
+ "device, expected_plain_text, expected_json_dict, expected_csv_list",
+ [
+ [
+ EthosUConfiguration("ethos-u55-256"),
+ """Device information:
+ Target ethos-u55
+ MAC 256
+
+ Memory mode Shared_Sram
+ Const mem area Axi1
+ Arena mem area Axi0
+ Cache mem area Axi0
+ Arena cache size 4,294,967,296 bytes
+
+ System config Ethos_U55_High_End_Embedded
+ Accelerator clock 500,000,000 Hz
+ AXI0 port Sram
+ AXI1 port OffChipFlash
+
+ Memory area settings:
+ Sram:
+ Clock scales 1.0
+ Burst length 32 bytes
+ Read latency 32 cycles
+ Write latency 32 cycles
+
+ Dram:
+ Clock scales 1.0
+ Burst length 1 byte
+ Read latency 0 cycles
+ Write latency 0 cycles
+
+ OnChipFlash:
+ Clock scales 1.0
+ Burst length 1 byte
+ Read latency 0 cycles
+ Write latency 0 cycles
+
+ OffChipFlash:
+ Clock scales 0.125
+ Burst length 128 bytes
+ Read latency 64 cycles
+ Write latency 64 cycles
+
+ Architecture settings:
+ Permanent storage mem area OffChipFlash
+ Feature map storage mem area Sram
+ Fast storage mem area Sram""",
+ {
+ "device": {
+ "target": "ethos-u55",
+ "mac": 256,
+ "memory_mode": {
+ "const_mem_area": "Axi1",
+ "arena_mem_area": "Axi0",
+ "cache_mem_area": "Axi0",
+ "arena_cache_size": {"value": 4294967296, "unit": "bytes"},
+ },
+ "system_config": {
+ "accelerator_clock": {"value": 500000000.0, "unit": "Hz"},
+ "axi0_port": "Sram",
+ "axi1_port": "OffChipFlash",
+ "memory_area": {
+ "Sram": {
+ "clock_scales": 1.0,
+ "burst_length": {"value": 32, "unit": "bytes"},
+ "read_latency": {"value": 32, "unit": "cycles"},
+ "write_latency": {"value": 32, "unit": "cycles"},
+ },
+ "Dram": {
+ "clock_scales": 1.0,
+ "burst_length": {"value": 1, "unit": "byte"},
+ "read_latency": {"value": 0, "unit": "cycles"},
+ "write_latency": {"value": 0, "unit": "cycles"},
+ },
+ "OnChipFlash": {
+ "clock_scales": 1.0,
+ "burst_length": {"value": 1, "unit": "byte"},
+ "read_latency": {"value": 0, "unit": "cycles"},
+ "write_latency": {"value": 0, "unit": "cycles"},
+ },
+ "OffChipFlash": {
+ "clock_scales": 0.125,
+ "burst_length": {"value": 128, "unit": "bytes"},
+ "read_latency": {"value": 64, "unit": "cycles"},
+ "write_latency": {"value": 64, "unit": "cycles"},
+ },
+ },
+ },
+ "arch_settings": {
+ "permanent_storage_mem_area": "OffChipFlash",
+ "feature_map_storage_mem_area": "Sram",
+ "fast_storage_mem_area": "Sram",
+ },
+ }
+ },
+ [
+ (
+ "target",
+ "mac",
+ "memory_mode",
+ "const_mem_area",
+ "arena_mem_area",
+ "cache_mem_area",
+ "arena_cache_size_value",
+ "arena_cache_size_unit",
+ "system_config",
+ "accelerator_clock_value",
+ "accelerator_clock_unit",
+ "axi0_port",
+ "axi1_port",
+ "clock_scales",
+ "burst_length_value",
+ "burst_length_unit",
+ "read_latency_value",
+ "read_latency_unit",
+ "write_latency_value",
+ "write_latency_unit",
+ "permanent_storage_mem_area",
+ "feature_map_storage_mem_area",
+ "fast_storage_mem_area",
+ ),
+ (
+ "ethos-u55",
+ 256,
+ "Shared_Sram",
+ "Axi1",
+ "Axi0",
+ "Axi0",
+ 4294967296,
+ "bytes",
+ "Ethos_U55_High_End_Embedded",
+ 500000000.0,
+ "Hz",
+ "Sram",
+ "OffChipFlash",
+ 0.125,
+ 128,
+ "bytes",
+ 64,
+ "cycles",
+ 64,
+ "cycles",
+ "OffChipFlash",
+ "Sram",
+ "Sram",
+ ),
+ ],
+ ],
+ ],
+)
+def test_report_device_details(
+ device: EthosUConfiguration,
+ expected_plain_text: str,
+ expected_json_dict: Dict,
+ expected_csv_list: List,
+) -> None:
+ """Test report_operatos formatter."""
+ report = report_device_details(device)
+ assert isinstance(report, Report)
+
+ plain_text = report.to_plain_text()
+ assert plain_text == expected_plain_text
+
+ json_dict = report.to_json()
+ assert json_dict == expected_json_dict
+
+ csv_list = report.to_csv()
+ assert csv_list == expected_csv_list
+
+
+def test_get_reporter(tmp_path: Path) -> None:
+ """Test reporter functionality."""
+ ops = Operators(
+ [
+ Operator(
+ "npu_supported",
+ "op_type",
+ NpuSupported(True, []),
+ ),
+ ]
+ )
+
+ output = tmp_path / "output.json"
+ with get_reporter("json", output, find_appropriate_formatter) as reporter:
+ assert isinstance(reporter, Reporter)
+
+ with pytest.raises(
+ Exception, match="Unable to find appropriate formatter for some_data"
+ ):
+ reporter.submit("some_data")
+
+ reporter.submit(ops)
+
+ with open(output, encoding="utf-8") as file:
+ json_data = json.load(file)
+
+ assert json_data == {
+ "operators_stats": [
+ {
+ "npu_unsupported_ratio": 0.0,
+ "num_of_npu_supported_operators": 1,
+ "num_of_operators": 1,
+ }
+ ]
+ }
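+
+
+# Illustrative sketch of how the submitted stats above could be derived from
+# an Operators instance. The attribute names used here (ops, run_on_npu,
+# supported) are assumptions for illustration only:
+#
+# num_of_operators = len(ops.ops) # 1
+# supported = [op for op in ops.ops if op.run_on_npu.supported] # 1 operator
+# npu_unsupported_ratio = 1 - len(supported) / len(ops.ops) # 0.0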
diff --git a/tests/mlia/test_nn_tensorflow_config.py b/tests/mlia/test_nn_tensorflow_config.py
new file mode 100644
index 0000000..1ac9f97
--- /dev/null
+++ b/tests/mlia/test_nn_tensorflow_config.py
@@ -0,0 +1,72 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for config module."""
+from contextlib import ExitStack as does_not_raise
+from pathlib import Path
+from typing import Any
+
+import pytest
+
+from mlia.nn.tensorflow.config import get_model
+from mlia.nn.tensorflow.config import KerasModel
+from mlia.nn.tensorflow.config import TFLiteModel
+from mlia.nn.tensorflow.config import TfModel
+
+
+def test_convert_keras_to_tflite(tmp_path: Path, test_keras_model: Path) -> None:
+ """Test Keras to TFLite conversion."""
+ keras_model = KerasModel(test_keras_model)
+
+ tflite_model_path = tmp_path / "test.tflite"
+ keras_model.convert_to_tflite(tflite_model_path)
+
+ assert tflite_model_path.is_file()
+ assert tflite_model_path.stat().st_size > 0
+
+
+def test_convert_tf_to_tflite(tmp_path: Path, test_tf_model: Path) -> None:
+ """Test TensorFlow saved model to TFLite conversion."""
+ tf_model = TfModel(test_tf_model)
+
+ tflite_model_path = tmp_path / "test.tflite"
+ tf_model.convert_to_tflite(tflite_model_path)
+
+ assert tflite_model_path.is_file()
+ assert tflite_model_path.stat().st_size > 0
+
+
+@pytest.mark.parametrize(
+ "model_path, expected_type, expected_error",
+ [
+ ("test.tflite", TFLiteModel, does_not_raise()),
+ ("test.h5", KerasModel, does_not_raise()),
+ ("test.hdf5", KerasModel, does_not_raise()),
+ (
+ "test.model",
+ None,
+ pytest.raises(
+ Exception,
+ match="The input model format is not supported"
+ r"\(supported formats: TFLite, Keras, TensorFlow saved model\)!",
+ ),
+ ),
+ ],
+)
+def test_get_model_file(
+ model_path: str, expected_type: type, expected_error: Any
+) -> None:
+ """Test TFLite model type."""
+ with expected_error:
+ model = get_model(model_path)
+ assert isinstance(model, expected_type)
+
+
+@pytest.mark.parametrize(
+ "model_path, expected_type", [("tf_model_test_model", TfModel)]
+)
+def test_get_model_dir(
+ test_models_path: Path, model_path: str, expected_type: type
+) -> None:
+ """Test TFLite model type."""
+ model = get_model(str(test_models_path / model_path))
+ assert isinstance(model, expected_type)
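+
+
+# Illustrative mapping: the cases above assume the model type is selected from
+# the path suffix roughly as below, while a directory path is treated as a
+# TensorFlow saved model (TfModel). The mapping itself is hypothetical.
+_ASSUMED_SUFFIX_MAP = {
+ ".tflite": TFLiteModel,
+ ".h5": KerasModel,
+ ".hdf5": KerasModel,
+}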
diff --git a/tests/mlia/test_nn_tensorflow_optimizations_clustering.py b/tests/mlia/test_nn_tensorflow_optimizations_clustering.py
new file mode 100644
index 0000000..9bcf918
--- /dev/null
+++ b/tests/mlia/test_nn_tensorflow_optimizations_clustering.py
@@ -0,0 +1,131 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Test for module optimizations/clustering."""
+from pathlib import Path
+from typing import List
+from typing import Optional
+
+import pytest
+import tensorflow as tf
+
+from mlia.nn.tensorflow.optimizations.clustering import Clusterer
+from mlia.nn.tensorflow.optimizations.clustering import ClusteringConfiguration
+from mlia.nn.tensorflow.optimizations.pruning import Pruner
+from mlia.nn.tensorflow.optimizations.pruning import PruningConfiguration
+from mlia.nn.tensorflow.tflite_metrics import ReportClusterMode
+from mlia.nn.tensorflow.tflite_metrics import TFLiteMetrics
+from mlia.nn.tensorflow.utils import convert_to_tflite
+from mlia.nn.tensorflow.utils import save_tflite_model
+from tests.mlia.utils.common import get_dataset
+from tests.mlia.utils.common import train_model
+
+
+def _prune_model(
+ model: tf.keras.Model, target_sparsity: float, layers_to_prune: Optional[List[str]]
+) -> tf.keras.Model:
+ x_train, y_train = get_dataset()
+ batch_size = 1
+ num_epochs = 1
+
+ pruner = Pruner(
+ model,
+ PruningConfiguration(
+ target_sparsity,
+ layers_to_prune,
+ x_train,
+ y_train,
+ batch_size,
+ num_epochs,
+ ),
+ )
+ pruner.apply_optimization()
+ pruned_model = pruner.get_model()
+
+ return pruned_model
+
+
+def _test_num_unique_weights(
+ metrics: TFLiteMetrics,
+ target_num_clusters: int,
+ layers_to_cluster: Optional[List[str]],
+) -> None:
+ clustered_uniqueness_dict = metrics.num_unique_weights(
+ ReportClusterMode.NUM_CLUSTERS_PER_AXIS
+ )
+ num_clustered_layers = 0
+ num_optimizable_layers = len(clustered_uniqueness_dict)
+ if layers_to_cluster:
+ expected_num_clustered_layers = len(layers_to_cluster)
+ else:
+ expected_num_clustered_layers = num_optimizable_layers
+ for layer_name in clustered_uniqueness_dict:
+ # The +1 is a temporary workaround for a bug that has already been
+ # fixed but whose fix has not been merged yet; remove it once merged.
+ if clustered_uniqueness_dict[layer_name][0] <= (target_num_clusters + 1):
+ num_clustered_layers = num_clustered_layers + 1
+ # make sure exactly the expected number of layers were clustered
+ assert num_clustered_layers == expected_num_clustered_layers
+
+
+def _test_sparsity(
+ metrics: TFLiteMetrics,
+ target_sparsity: float,
+ layers_to_cluster: Optional[List[str]],
+) -> None:
+ pruned_sparsity_dict = metrics.sparsity_per_layer()
+ num_sparse_layers = 0
+ num_optimizable_layers = len(pruned_sparsity_dict)
+ error_margin = 0.03
+ if layers_to_cluster:
+ expected_num_sparse_layers = len(layers_to_cluster)
+ else:
+ expected_num_sparse_layers = num_optimizable_layers
+ for layer_name in pruned_sparsity_dict:
+ if abs(pruned_sparsity_dict[layer_name] - target_sparsity) < error_margin:
+ num_sparse_layers = num_sparse_layers + 1
+ # make sure exactly the expected number of layers are sparse
+ assert num_sparse_layers == expected_num_sparse_layers
+
+
+@pytest.mark.skip(reason="Test fails randomly, further investigation is needed")
+@pytest.mark.parametrize("target_num_clusters", (32, 4))
+@pytest.mark.parametrize("sparsity_aware", (False, True))
+@pytest.mark.parametrize("layers_to_cluster", (["conv1"], ["conv1", "conv2"], None))
+def test_cluster_simple_model_fully(
+ target_num_clusters: int,
+ sparsity_aware: bool,
+ layers_to_cluster: Optional[List[str]],
+ tmp_path: Path,
+ test_keras_model: Path,
+) -> None:
+ """Simple MNIST test to see if clustering works correctly."""
+ target_sparsity = 0.5
+
+ base_model = tf.keras.models.load_model(str(test_keras_model))
+ train_model(base_model)
+
+ if sparsity_aware:
+ base_model = _prune_model(base_model, target_sparsity, layers_to_cluster)
+
+ clusterer = Clusterer(
+ base_model,
+ ClusteringConfiguration(
+ target_num_clusters,
+ layers_to_cluster,
+ ),
+ )
+ clusterer.apply_optimization()
+ clustered_model = clusterer.get_model()
+
+ temp_file = tmp_path / "test_cluster_simple_model_fully_after.tflite"
+ tflite_clustered_model = convert_to_tflite(clustered_model)
+ save_tflite_model(tflite_clustered_model, temp_file)
+ clustered_tflite_metrics = TFLiteMetrics(str(temp_file))
+
+ _test_num_unique_weights(
+ clustered_tflite_metrics, target_num_clusters, layers_to_cluster
+ )
+
+ if sparsity_aware:
+ _test_sparsity(clustered_tflite_metrics, target_sparsity, layers_to_cluster)
diff --git a/tests/mlia/test_nn_tensorflow_optimizations_pruning.py b/tests/mlia/test_nn_tensorflow_optimizations_pruning.py
new file mode 100644
index 0000000..64030a6
--- /dev/null
+++ b/tests/mlia/test_nn_tensorflow_optimizations_pruning.py
@@ -0,0 +1,117 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Test for module optimizations/pruning."""
+from pathlib import Path
+from typing import List
+from typing import Optional
+
+import pytest
+import tensorflow as tf
+from numpy.core.numeric import isclose
+
+from mlia.nn.tensorflow.optimizations.pruning import Pruner
+from mlia.nn.tensorflow.optimizations.pruning import PruningConfiguration
+from mlia.nn.tensorflow.tflite_metrics import TFLiteMetrics
+from mlia.nn.tensorflow.utils import convert_to_tflite
+from mlia.nn.tensorflow.utils import save_tflite_model
+from tests.mlia.utils.common import get_dataset
+from tests.mlia.utils.common import train_model
+
+
+def _test_sparsity(
+ metrics: TFLiteMetrics,
+ target_sparsity: float,
+ layers_to_prune: Optional[List[str]],
+) -> None:
+ pruned_sparsity_dict = metrics.sparsity_per_layer()
+ num_sparse_layers = 0
+ num_optimizable_layers = len(pruned_sparsity_dict)
+ error_margin = 0.03
+ if layers_to_prune:
+ expected_num_sparse_layers = len(layers_to_prune)
+ else:
+ expected_num_sparse_layers = num_optimizable_layers
+ for layer_name in pruned_sparsity_dict:
+ if abs(pruned_sparsity_dict[layer_name] - target_sparsity) < error_margin:
+ num_sparse_layers = num_sparse_layers + 1
+ # make sure exactly the expected number of layers are sparse
+ assert num_sparse_layers == expected_num_sparse_layers
+
+
+def _test_check_sparsity(base_tflite_metrics: TFLiteMetrics) -> None:
+ """Assert the sparsity of a model is zero."""
+ base_sparsity_dict = base_tflite_metrics.sparsity_per_layer()
+ for layer_name, sparsity in base_sparsity_dict.items():
+ assert isclose(
+ sparsity, 0, atol=1e-2
+ ), f"Sparsity for layer '{layer_name}' is {sparsity}, but should be zero."
+
+
+def _get_tflite_metrics(
+ path: Path, tflite_fn: str, model: tf.keras.Model
+) -> TFLiteMetrics:
+ """Save model as TFLiteModel and return metrics."""
+ temp_file = path / tflite_fn
+ save_tflite_model(convert_to_tflite(model), temp_file)
+ return TFLiteMetrics(str(temp_file))
+
+
+@pytest.mark.parametrize("target_sparsity", (0.5, 0.9))
+@pytest.mark.parametrize("mock_data", (False, True))
+@pytest.mark.parametrize("layers_to_prune", (["conv1"], ["conv1", "conv2"], None))
+def test_prune_simple_model_fully(
+ target_sparsity: float,
+ mock_data: bool,
+ layers_to_prune: Optional[List[str]],
+ tmp_path: Path,
+ test_keras_model: Path,
+) -> None:
+ """Simple MNIST test to see if pruning works correctly."""
+ x_train, y_train = get_dataset()
+ batch_size = 1
+ num_epochs = 1
+
+ base_model = tf.keras.models.load_model(str(test_keras_model))
+ train_model(base_model)
+
+ base_tflite_metrics = _get_tflite_metrics(
+ path=tmp_path,
+ tflite_fn="test_prune_simple_model_fully_before.tflite",
+ model=base_model,
+ )
+
+ # Make sure sparsity is zero before pruning
+ _test_check_sparsity(base_tflite_metrics)
+
+ if mock_data:
+ pruner = Pruner(
+ base_model,
+ PruningConfiguration(
+ target_sparsity,
+ layers_to_prune,
+ ),
+ )
+
+ else:
+ pruner = Pruner(
+ base_model,
+ PruningConfiguration(
+ target_sparsity,
+ layers_to_prune,
+ x_train,
+ y_train,
+ batch_size,
+ num_epochs,
+ ),
+ )
+
+ pruner.apply_optimization()
+ pruned_model = pruner.get_model()
+
+ pruned_tflite_metrics = _get_tflite_metrics(
+ path=tmp_path,
+ tflite_fn="test_prune_simple_model_fully_after.tflite",
+ model=pruned_model,
+ )
+
+ _test_sparsity(pruned_tflite_metrics, target_sparsity, layers_to_prune)
diff --git a/tests/mlia/test_nn_tensorflow_optimizations_select.py b/tests/mlia/test_nn_tensorflow_optimizations_select.py
new file mode 100644
index 0000000..5cac8ba
--- /dev/null
+++ b/tests/mlia/test_nn_tensorflow_optimizations_select.py
@@ -0,0 +1,240 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module select."""
+from contextlib import ExitStack as does_not_raise
+from pathlib import Path
+from typing import Any
+from typing import List
+from typing import Tuple
+
+import pytest
+import tensorflow as tf
+
+from mlia.nn.tensorflow.optimizations.clustering import Clusterer
+from mlia.nn.tensorflow.optimizations.clustering import ClusteringConfiguration
+from mlia.nn.tensorflow.optimizations.pruning import Pruner
+from mlia.nn.tensorflow.optimizations.pruning import PruningConfiguration
+from mlia.nn.tensorflow.optimizations.select import get_optimizer
+from mlia.nn.tensorflow.optimizations.select import MultiStageOptimizer
+from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+
+
+@pytest.mark.parametrize(
+ "config, expected_error, expected_type, expected_config",
+ [
+ (
+ OptimizationSettings(
+ optimization_type="pruning",
+ optimization_target=0.5,
+ layers_to_optimize=None,
+ ),
+ does_not_raise(),
+ Pruner,
+ "pruning: 0.5",
+ ),
+ (
+ PruningConfiguration(0.5),
+ does_not_raise(),
+ Pruner,
+ "pruning: 0.5",
+ ),
+ (
+ OptimizationSettings(
+ optimization_type="clustering",
+ optimization_target=32,
+ layers_to_optimize=None,
+ ),
+ does_not_raise(),
+ Clusterer,
+ "clustering: 32",
+ ),
+ (
+ OptimizationSettings(
+ optimization_type="clustering",
+ optimization_target=0.5,
+ layers_to_optimize=None,
+ ),
+ pytest.raises(
+ Exception,
+ match="Optimization target should be a "
+ "positive integer. "
+ "Optimization target provided: 0.5",
+ ),
+ None,
+ None,
+ ),
+ (
+ ClusteringConfiguration(32),
+ does_not_raise(),
+ Clusterer,
+ "clustering: 32",
+ ),
+ (
+ OptimizationSettings(
+ optimization_type="superoptimization",
+ optimization_target="supertarget", # type: ignore
+ layers_to_optimize="all", # type: ignore
+ ),
+ pytest.raises(
+ Exception,
+ match="Unsupported optimization type: superoptimization",
+ ),
+ None,
+ None,
+ ),
+ (
+ OptimizationSettings(
+ optimization_type="",
+ optimization_target=0.5,
+ layers_to_optimize=None,
+ ),
+ pytest.raises(
+ Exception,
+ match="Optimization type is not provided",
+ ),
+ None,
+ None,
+ ),
+ (
+ "wrong_config",
+ pytest.raises(
+ Exception,
+ match="Unknown optimization configuration wrong_config",
+ ),
+ None,
+ None,
+ ),
+ (
+ OptimizationSettings(
+ optimization_type="pruning",
+ optimization_target=None, # type: ignore
+ layers_to_optimize=None,
+ ),
+ pytest.raises(
+ Exception,
+ match="Optimization target is not provided",
+ ),
+ None,
+ None,
+ ),
+ (
+ [
+ OptimizationSettings(
+ optimization_type="pruning",
+ optimization_target=0.5,
+ layers_to_optimize=None,
+ ),
+ OptimizationSettings(
+ optimization_type="clustering",
+ optimization_target=32,
+ layers_to_optimize=None,
+ ),
+ ],
+ does_not_raise(),
+ MultiStageOptimizer,
+ "pruning: 0.5 - clustering: 32",
+ ),
+ ],
+)
+def test_get_optimizer(
+ config: Any,
+ expected_error: Any,
+ expected_type: type,
+ expected_config: str,
+ test_keras_model: Path,
+) -> None:
+ """Test function get_optimzer."""
+ model = tf.keras.models.load_model(str(test_keras_model))
+
+ with expected_error:
+ optimizer = get_optimizer(model, config)
+ assert isinstance(optimizer, expected_type)
+ assert optimizer.optimization_config() == expected_config
+
+
+@pytest.mark.parametrize(
+ "params, expected_result",
+ [
+ (
+ [],
+ [],
+ ),
+ (
+ [("pruning", 0.5)],
+ [
+ OptimizationSettings(
+ optimization_type="pruning",
+ optimization_target=0.5,
+ layers_to_optimize=None,
+ )
+ ],
+ ),
+ (
+ [("pruning", 0.5), ("clustering", 32)],
+ [
+ OptimizationSettings(
+ optimization_type="pruning",
+ optimization_target=0.5,
+ layers_to_optimize=None,
+ ),
+ OptimizationSettings(
+ optimization_type="clustering",
+ optimization_target=32,
+ layers_to_optimize=None,
+ ),
+ ],
+ ),
+ ],
+)
+def test_optimization_settings_create_from(
+ params: List[Tuple[str, float]], expected_result: List[OptimizationSettings]
+) -> None:
+ """Test creating settings from parsed params."""
+ assert OptimizationSettings.create_from(params) == expected_result
+
+
+@pytest.mark.parametrize(
+ "settings, expected_next_target, expected_error",
+ [
+ [
+ OptimizationSettings("clustering", 32, None),
+ OptimizationSettings("clustering", 16, None),
+ does_not_raise(),
+ ],
+ [
+ OptimizationSettings("clustering", 4, None),
+ OptimizationSettings("clustering", 4, None),
+ does_not_raise(),
+ ],
+ [
+ OptimizationSettings("clustering", 10, None),
+ OptimizationSettings("clustering", 8, None),
+ does_not_raise(),
+ ],
+ [
+ OptimizationSettings("pruning", 0.5, None),
+ OptimizationSettings("pruning", 0.6, None),
+ does_not_raise(),
+ ],
+ [
+ OptimizationSettings("pruning", 0.9, None),
+ OptimizationSettings("pruning", 0.9, None),
+ does_not_raise(),
+ ],
+ [
+ OptimizationSettings("super_optimization", 42, None),
+ None,
+ pytest.raises(
+ Exception, match="Unknown optimization type super_optimization"
+ ),
+ ],
+ ],
+)
+def test_optimization_settings_next_target(
+ settings: OptimizationSettings,
+ expected_next_target: OptimizationSettings,
+ expected_error: Any,
+) -> None:
+ """Test getting next optimization target."""
+ with expected_error:
+ assert settings.next_target() == expected_next_target
diff --git a/tests/mlia/test_nn_tensorflow_tflite_metrics.py b/tests/mlia/test_nn_tensorflow_tflite_metrics.py
new file mode 100644
index 0000000..805f7d1
--- /dev/null
+++ b/tests/mlia/test_nn_tensorflow_tflite_metrics.py
@@ -0,0 +1,137 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Test for module utils/tflite_metrics."""
+import os
+import tempfile
+from math import isclose
+from pathlib import Path
+from typing import Generator
+from typing import List
+
+import numpy as np
+import pytest
+import tensorflow as tf
+
+from mlia.nn.tensorflow.tflite_metrics import ReportClusterMode
+from mlia.nn.tensorflow.tflite_metrics import TFLiteMetrics
+
+
+def _dummy_keras_model() -> tf.keras.Model:
+ # Create a dummy model
+ keras_model = tf.keras.Sequential(
+ [
+ tf.keras.Input(shape=(8, 8, 3)),
+ tf.keras.layers.Conv2D(4, 3),
+ tf.keras.layers.DepthwiseConv2D(3),
+ tf.keras.layers.Flatten(),
+ tf.keras.layers.Dense(8),
+ ]
+ )
+ return keras_model
+
+
+def _sparse_binary_keras_model() -> tf.keras.Model:
+ def get_sparse_weights(shape: List[int]) -> np.array:
+ weights = np.zeros(shape)
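+        # Set every other element to 1.0 so exactly half of the weights stay zero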
+ with np.nditer(weights, op_flags=["writeonly"]) as weight_iterator:
+ for idx, value in enumerate(weight_iterator):
+ if idx % 2 == 0:
+ value[...] = 1.0
+ return weights
+
+ keras_model = _dummy_keras_model()
+ # Assign weights to have 0.5 sparsity
+ for layer in keras_model.layers:
+ if not isinstance(layer, tf.keras.layers.Flatten):
+ weight = layer.weights[0]
+ weight.assign(get_sparse_weights(weight.shape))
+ print(layer)
+ print(weight.numpy())
+ return keras_model
+
+
+@pytest.fixture(scope="class", name="tflite_file")
+def fixture_tflite_file() -> Generator:
+ """Generate temporary TFLite file for tests."""
+ converter = tf.lite.TFLiteConverter.from_keras_model(_sparse_binary_keras_model())
+ tflite_model = converter.convert()
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ file = os.path.join(tmp_dir, "test.tflite")
+ Path(file).write_bytes(tflite_model)
+ yield file
+
+
+@pytest.fixture(scope="function", name="metrics")
+def fixture_metrics(tflite_file: str) -> TFLiteMetrics:
+ """Generate metrics file for a given TFLite model."""
+ return TFLiteMetrics(tflite_file)
+
+
+class TestTFLiteMetrics:
+ """Tests for module TFLite_metrics."""
+
+ @staticmethod
+ def test_sparsity(metrics: TFLiteMetrics) -> None:
+ """Test sparsity."""
+ # Create new instance with a dummy TFLite file
+ # Check sparsity calculation
+ sparsity_per_layer = metrics.sparsity_per_layer()
+ for name, sparsity in sparsity_per_layer.items():
+ assert isclose(sparsity, 0.5), "Layer '{}' has incorrect sparsity.".format(
+ name
+ )
+ assert isclose(metrics.sparsity_overall(), 0.5)
+
+ @staticmethod
+ def test_clusters(metrics: TFLiteMetrics) -> None:
+ """Test clusters."""
+ # NUM_CLUSTERS_PER_AXIS and NUM_CLUSTERS_MIN_MAX can be handled together
+ for mode in [
+ ReportClusterMode.NUM_CLUSTERS_PER_AXIS,
+ ReportClusterMode.NUM_CLUSTERS_MIN_MAX,
+ ]:
+ num_unique_weights = metrics.num_unique_weights(mode)
+ for name, num_unique_per_axis in num_unique_weights.items():
+ for num_unique in num_unique_per_axis:
+ assert (
+ num_unique == 2
+ ), "Layer '{}' has incorrect number of clusters.".format(name)
+ # NUM_CLUSTERS_HISTOGRAM
+ hists = metrics.num_unique_weights(ReportClusterMode.NUM_CLUSTERS_HISTOGRAM)
+ assert hists
+ for name, hist in hists.items():
+ assert hist
+ for idx, num_axes in enumerate(hist):
+                # The histogram starts with the bin for num_clusters == 1
+ num_clusters = idx + 1
+ msg = (
+ "Histogram of layer '{}': There are {} axes with {} "
+ "clusters".format(name, num_axes, num_clusters)
+ )
+ if num_clusters == 2:
+ assert num_axes > 0, "{}, but there should be at least one.".format(
+ msg
+ )
+ else:
+ assert num_axes == 0, "{}, but there should be none.".format(msg)
+
+ @staticmethod
+ @pytest.mark.parametrize("report_sparsity", (False, True))
+ @pytest.mark.parametrize("report_cluster_mode", ReportClusterMode)
+ @pytest.mark.parametrize("max_num_clusters", (-1, 8))
+ @pytest.mark.parametrize("verbose", (False, True))
+ def test_summary(
+ tflite_file: str,
+ report_sparsity: bool,
+ report_cluster_mode: ReportClusterMode,
+ max_num_clusters: int,
+ verbose: bool,
+ ) -> None:
+ """Test the summary function."""
+ for metrics in [TFLiteMetrics(tflite_file), TFLiteMetrics(tflite_file, [])]:
+ metrics.summary(
+ report_sparsity=report_sparsity,
+ report_cluster_mode=report_cluster_mode,
+ max_num_clusters=max_num_clusters,
+ verbose=verbose,
+ )
diff --git a/tests/mlia/test_nn_tensorflow_utils.py b/tests/mlia/test_nn_tensorflow_utils.py
new file mode 100644
index 0000000..6d27299
--- /dev/null
+++ b/tests/mlia/test_nn_tensorflow_utils.py
@@ -0,0 +1,81 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Test for module utils/test_utils."""
+from pathlib import Path
+
+import pytest
+import tensorflow as tf
+
+from mlia.nn.tensorflow.utils import convert_to_tflite
+from mlia.nn.tensorflow.utils import get_tf_tensor_shape
+from mlia.nn.tensorflow.utils import is_keras_model
+from mlia.nn.tensorflow.utils import is_tflite_model
+from mlia.nn.tensorflow.utils import save_keras_model
+from mlia.nn.tensorflow.utils import save_tflite_model
+
+
+def test_convert_to_tflite(test_keras_model: Path) -> None:
+ """Test converting Keras model to TFLite."""
+ keras_model = tf.keras.models.load_model(str(test_keras_model))
+ tflite_model = convert_to_tflite(keras_model)
+
+ assert tflite_model
+
+
+def test_save_keras_model(tmp_path: Path, test_keras_model: Path) -> None:
+ """Test saving Keras model."""
+ keras_model = tf.keras.models.load_model(str(test_keras_model))
+
+ temp_file = tmp_path / "test_model_saving.h5"
+ save_keras_model(keras_model, temp_file)
+ loaded_model = tf.keras.models.load_model(temp_file)
+
+ assert loaded_model.summary() == keras_model.summary()
+
+
+def test_save_tflite_model(tmp_path: Path, test_keras_model: Path) -> None:
+ """Test saving TFLite model."""
+ keras_model = tf.keras.models.load_model(str(test_keras_model))
+
+ tflite_model = convert_to_tflite(keras_model)
+
+ temp_file = tmp_path / "test_model_saving.tflite"
+ save_tflite_model(tflite_model, temp_file)
+
+ interpreter = tf.lite.Interpreter(model_path=str(temp_file))
+ assert interpreter
+
+
+@pytest.mark.parametrize(
+ "model_path, expected_result",
+ [
+ [Path("sample_model.tflite"), True],
+ [Path("strange_model.tflite.tfl"), False],
+ [Path("sample_model.h5"), False],
+ [Path("sample_model"), False],
+ ],
+)
+def test_is_tflite_model(model_path: Path, expected_result: bool) -> None:
+ """Test function is_tflite_model."""
+ result = is_tflite_model(model_path)
+ assert result == expected_result
+
+
+@pytest.mark.parametrize(
+ "model_path, expected_result",
+ [
+ [Path("sample_model.h5"), True],
+ [Path("strange_model.h5.keras"), False],
+ [Path("sample_model.tflite"), False],
+ [Path("sample_model"), False],
+ ],
+)
+def test_is_keras_model(model_path: Path, expected_result: bool) -> None:
+ """Test function is_keras_model."""
+ result = is_keras_model(model_path)
+ assert result == expected_result
+
+
+def test_get_tf_tensor_shape(test_tf_model: Path) -> None:
+ """Test get_tf_tensor_shape with test model."""
+ assert get_tf_tensor_shape(str(test_tf_model)) == [1, 28, 28, 1]
diff --git a/tests/mlia/test_resources/vela/sample_vela.ini b/tests/mlia/test_resources/vela/sample_vela.ini
new file mode 100644
index 0000000..c992458
--- /dev/null
+++ b/tests/mlia/test_resources/vela/sample_vela.ini
@@ -0,0 +1,47 @@
+; SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+; SPDX-License-Identifier: Apache-2.0
+; Ethos-U55 High-End Embedded: SRAM (4 GB/s) and Flash (0.5 GB/s)
+[System_Config.Ethos_U55_High_End_Embedded]
+core_clock=500e6
+axi0_port=Sram
+axi1_port=OffChipFlash
+Sram_clock_scale=1.0
+Sram_burst_length=32
+Sram_read_latency=32
+Sram_write_latency=32
+OffChipFlash_clock_scale=0.125
+OffChipFlash_burst_length=128
+OffChipFlash_read_latency=64
+OffChipFlash_write_latency=64
+
+; Ethos-U65 High-End: SRAM (16 GB/s) and DRAM (3.75 GB/s)
+[System_Config.Ethos_U65_High_End]
+core_clock=1e9
+axi0_port=Sram
+axi1_port=Dram
+Sram_clock_scale=1.0
+Sram_burst_length=32
+Sram_read_latency=32
+Sram_write_latency=32
+Dram_clock_scale=0.234375
+Dram_burst_length=128
+Dram_read_latency=500
+Dram_write_latency=250
+
+; -----------------------------------------------------------------------------
+; Memory Mode
+
+; Shared SRAM: the SRAM is shared between the Ethos-U and the Cortex-M software
+; The non-SRAM memory is assumed to be read-only
+[Memory_Mode.Shared_Sram]
+const_mem_area=Axi1
+arena_mem_area=Axi0
+cache_mem_area=Axi0
+
+; The SRAM (384KB) is only for use by the Ethos-U
+; The non-SRAM memory is assumed to be read-writeable
+[Memory_Mode.Dedicated_Sram]
+const_mem_area=Axi1
+arena_mem_area=Axi1
+cache_mem_area=Axi0
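+; arena_cache_size is given in bytes: 384 * 1024 = 393216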
+arena_cache_size=393216
diff --git a/tests/mlia/test_tools_aiet_wrapper.py b/tests/mlia/test_tools_aiet_wrapper.py
new file mode 100644
index 0000000..ab55b71
--- /dev/null
+++ b/tests/mlia/test_tools_aiet_wrapper.py
@@ -0,0 +1,760 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module tools/aiet_wrapper."""
+from contextlib import ExitStack as does_not_raise
+from pathlib import Path
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from unittest.mock import MagicMock
+from unittest.mock import PropertyMock
+
+import pytest
+
+from mlia.tools.aiet_wrapper import AIETRunner
+from mlia.tools.aiet_wrapper import DeviceInfo
+from mlia.tools.aiet_wrapper import estimate_performance
+from mlia.tools.aiet_wrapper import ExecutionParams
+from mlia.tools.aiet_wrapper import GenericInferenceOutputParser
+from mlia.tools.aiet_wrapper import GenericInferenceRunnerEthosU
+from mlia.tools.aiet_wrapper import get_aiet_runner
+from mlia.tools.aiet_wrapper import get_generic_runner
+from mlia.tools.aiet_wrapper import get_system_name
+from mlia.tools.aiet_wrapper import is_supported
+from mlia.tools.aiet_wrapper import ModelInfo
+from mlia.tools.aiet_wrapper import PerformanceMetrics
+from mlia.tools.aiet_wrapper import supported_backends
+from mlia.utils.proc import RunningCommand
+
+
+@pytest.mark.parametrize(
+ "data, is_ready, result, missed_keys",
+ [
+ (
+ [],
+ False,
+ {},
+ [
+ "npu_active_cycles",
+ "npu_axi0_rd_data_beat_received",
+ "npu_axi0_wr_data_beat_written",
+ "npu_axi1_rd_data_beat_received",
+ "npu_idle_cycles",
+ "npu_total_cycles",
+ ],
+ ),
+ (
+ ["sample text"],
+ False,
+ {},
+ [
+ "npu_active_cycles",
+ "npu_axi0_rd_data_beat_received",
+ "npu_axi0_wr_data_beat_written",
+ "npu_axi1_rd_data_beat_received",
+ "npu_idle_cycles",
+ "npu_total_cycles",
+ ],
+ ),
+        (
+            ["NPU AXI0_RD_DATA_BEAT_RECEIVED beats: 123"],
+            False,
+            {"npu_axi0_rd_data_beat_received": 123},
+            [
+                "npu_active_cycles",
+                "npu_axi0_wr_data_beat_written",
+                "npu_axi1_rd_data_beat_received",
+                "npu_idle_cycles",
+                "npu_total_cycles",
+            ],
+        ),
+ (
+ [
+ "NPU AXI0_RD_DATA_BEAT_RECEIVED beats: 1",
+ "NPU AXI0_WR_DATA_BEAT_WRITTEN beats: 2",
+ "NPU AXI1_RD_DATA_BEAT_RECEIVED beats: 3",
+ "NPU ACTIVE cycles: 4",
+ "NPU IDLE cycles: 5",
+ "NPU TOTAL cycles: 6",
+ ],
+ True,
+ {
+ "npu_axi0_rd_data_beat_received": 1,
+ "npu_axi0_wr_data_beat_written": 2,
+ "npu_axi1_rd_data_beat_received": 3,
+ "npu_active_cycles": 4,
+ "npu_idle_cycles": 5,
+ "npu_total_cycles": 6,
+ },
+ [],
+ ),
+ ],
+)
+def test_generic_inference_output_parser(
+ data: List[str], is_ready: bool, result: Dict, missed_keys: List[str]
+) -> None:
+ """Test generic runner output parser."""
+ parser = GenericInferenceOutputParser()
+
+ for line in data:
+ parser.feed(line)
+
+ assert parser.is_ready() == is_ready
+ assert parser.result == result
+ assert parser.missed_keys() == missed_keys
+
+
+class TestAIETRunner:
+ """Tests for AIETRunner class."""
+
+ @staticmethod
+ def _setup_aiet(
+ monkeypatch: pytest.MonkeyPatch,
+ available_systems: Optional[List[str]] = None,
+ available_apps: Optional[List[str]] = None,
+ ) -> None:
+ """Set up AIET metadata."""
+
+ def mock_system(system: str) -> MagicMock:
+ """Mock the System instance."""
+ mock = MagicMock()
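+            # "name" is special for the MagicMock constructor, so set it via PropertyMock on the type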
+ type(mock).name = PropertyMock(return_value=system)
+ return mock
+
+ def mock_app(app: str) -> MagicMock:
+ """Mock the Application instance."""
+ mock = MagicMock()
+ type(mock).name = PropertyMock(return_value=app)
+ mock.can_run_on.return_value = True
+ return mock
+
+ system_mocks = [mock_system(name) for name in (available_systems or [])]
+ monkeypatch.setattr(
+ "mlia.tools.aiet_wrapper.get_available_systems",
+ MagicMock(return_value=system_mocks),
+ )
+
+ apps_mock = [mock_app(name) for name in (available_apps or [])]
+ monkeypatch.setattr(
+ "mlia.tools.aiet_wrapper.get_available_applications",
+ MagicMock(return_value=apps_mock),
+ )
+
+ @pytest.mark.parametrize(
+ "available_systems, system, installed",
+ [
+ ([], "system1", False),
+ (["system1", "system2"], "system1", True),
+ ],
+ )
+ def test_is_system_installed(
+ self,
+ available_systems: List,
+ system: str,
+ installed: bool,
+ monkeypatch: pytest.MonkeyPatch,
+ ) -> None:
+ """Test method is_system_installed."""
+ mock_executor = MagicMock()
+ aiet_runner = AIETRunner(mock_executor)
+
+ self._setup_aiet(monkeypatch, available_systems)
+
+ assert aiet_runner.is_system_installed(system) == installed
+ mock_executor.assert_not_called()
+
+ @pytest.mark.parametrize(
+ "available_systems, systems",
+ [
+ ([], []),
+ (["system1"], ["system1"]),
+ ],
+ )
+ def test_installed_systems(
+ self,
+ available_systems: List[str],
+ systems: List[str],
+ monkeypatch: pytest.MonkeyPatch,
+ ) -> None:
+ """Test method installed_systems."""
+ mock_executor = MagicMock()
+ aiet_runner = AIETRunner(mock_executor)
+
+ self._setup_aiet(monkeypatch, available_systems)
+ assert aiet_runner.get_installed_systems() == systems
+
+ mock_executor.assert_not_called()
+
+ @staticmethod
+ def test_install_system(monkeypatch: pytest.MonkeyPatch) -> None:
+ """Test system installation."""
+ install_system_mock = MagicMock()
+ monkeypatch.setattr(
+ "mlia.tools.aiet_wrapper.install_system", install_system_mock
+ )
+
+ mock_executor = MagicMock()
+ aiet_runner = AIETRunner(mock_executor)
+ aiet_runner.install_system(Path("test_system_path"))
+
+ install_system_mock.assert_called_once_with(Path("test_system_path"))
+ mock_executor.assert_not_called()
+
+ @pytest.mark.parametrize(
+ "available_systems, systems, expected_result",
+ [
+ ([], [], False),
+ (["system1"], [], False),
+ (["system1"], ["system1"], True),
+ (["system1", "system2"], ["system1", "system3"], False),
+ (["system1", "system2"], ["system1", "system2"], True),
+ ],
+ )
+ def test_systems_installed(
+ self,
+ available_systems: List[str],
+ systems: List[str],
+ expected_result: bool,
+ monkeypatch: pytest.MonkeyPatch,
+ ) -> None:
+ """Test method systems_installed."""
+ self._setup_aiet(monkeypatch, available_systems)
+
+ mock_executor = MagicMock()
+ aiet_runner = AIETRunner(mock_executor)
+
+ assert aiet_runner.systems_installed(systems) is expected_result
+
+ mock_executor.assert_not_called()
+
+ @pytest.mark.parametrize(
+ "available_apps, applications, expected_result",
+ [
+ ([], [], False),
+ (["app1"], [], False),
+ (["app1"], ["app1"], True),
+ (["app1", "app2"], ["app1", "app3"], False),
+ (["app1", "app2"], ["app1", "app2"], True),
+ ],
+ )
+ def test_applications_installed(
+ self,
+ available_apps: List[str],
+ applications: List[str],
+ expected_result: bool,
+ monkeypatch: pytest.MonkeyPatch,
+ ) -> None:
+ """Test method applications_installed."""
+ self._setup_aiet(monkeypatch, [], available_apps)
+ mock_executor = MagicMock()
+ aiet_runner = AIETRunner(mock_executor)
+
+ assert aiet_runner.applications_installed(applications) is expected_result
+ mock_executor.assert_not_called()
+
+ @pytest.mark.parametrize(
+ "available_apps, applications",
+ [
+ ([], []),
+ (
+ ["application1", "application2"],
+ ["application1", "application2"],
+ ),
+ ],
+ )
+ def test_get_installed_applications(
+ self,
+ available_apps: List[str],
+ applications: List[str],
+ monkeypatch: pytest.MonkeyPatch,
+ ) -> None:
+ """Test method get_installed_applications."""
+ mock_executor = MagicMock()
+ self._setup_aiet(monkeypatch, [], available_apps)
+
+ aiet_runner = AIETRunner(mock_executor)
+ assert applications == aiet_runner.get_installed_applications()
+
+ mock_executor.assert_not_called()
+
+ @staticmethod
+ def test_install_application(monkeypatch: pytest.MonkeyPatch) -> None:
+ """Test application installation."""
+ mock_install_application = MagicMock()
+ monkeypatch.setattr(
+ "mlia.tools.aiet_wrapper.install_application", mock_install_application
+ )
+
+ mock_executor = MagicMock()
+
+ aiet_runner = AIETRunner(mock_executor)
+ aiet_runner.install_application(Path("test_application_path"))
+ mock_install_application.assert_called_once_with(Path("test_application_path"))
+
+ mock_executor.assert_not_called()
+
+ @pytest.mark.parametrize(
+ "available_apps, application, installed",
+ [
+ ([], "system1", False),
+ (
+ ["application1", "application2"],
+ "application1",
+ True,
+ ),
+ (
+ [],
+ "application1",
+ False,
+ ),
+ ],
+ )
+ def test_is_application_installed(
+ self,
+ available_apps: List[str],
+ application: str,
+ installed: bool,
+ monkeypatch: pytest.MonkeyPatch,
+ ) -> None:
+ """Test method is_application_installed."""
+ self._setup_aiet(monkeypatch, [], available_apps)
+
+ mock_executor = MagicMock()
+ aiet_runner = AIETRunner(mock_executor)
+ assert installed == aiet_runner.is_application_installed(application, "system1")
+
+ mock_executor.assert_not_called()
+
+ @staticmethod
+ @pytest.mark.parametrize(
+ "execution_params, expected_command",
+ [
+ (
+ ExecutionParams("application1", "system1", [], [], []),
+ ["aiet", "application", "run", "-n", "application1", "-s", "system1"],
+ ),
+ (
+ ExecutionParams(
+ "application1",
+ "system1",
+ ["input_file=123.txt", "size=777"],
+ ["param1=456", "param2=789"],
+ ["source1.txt:dest1.txt", "source2.txt:dest2.txt"],
+ ),
+ [
+ "aiet",
+ "application",
+ "run",
+ "-n",
+ "application1",
+ "-s",
+ "system1",
+ "-p",
+ "input_file=123.txt",
+ "-p",
+ "size=777",
+ "--system-param",
+ "param1=456",
+ "--system-param",
+ "param2=789",
+ "--deploy",
+ "source1.txt:dest1.txt",
+ "--deploy",
+ "source2.txt:dest2.txt",
+ ],
+ ),
+ ],
+ )
+ def test_run_application(
+ execution_params: ExecutionParams, expected_command: List[str]
+ ) -> None:
+ """Test method run_application."""
+ mock_executor = MagicMock()
+ mock_running_command = MagicMock()
+ mock_executor.submit.return_value = mock_running_command
+
+ aiet_runner = AIETRunner(mock_executor)
+ aiet_runner.run_application(execution_params)
+
+ mock_executor.submit.assert_called_once_with(expected_command)
+
+
+@pytest.mark.parametrize(
+ "device, system, application, backend, expected_error",
+ [
+ (
+ DeviceInfo(device_type="ethos-u55", mac=32, memory_mode="Shared_Sram"),
+ ("Corstone-300: Cortex-M55+Ethos-U55", True),
+ ("Generic Inference Runner: Ethos-U55 SRAM", True),
+ "Corstone-300",
+ does_not_raise(),
+ ),
+ (
+ DeviceInfo(device_type="ethos-u55", mac=32, memory_mode="Shared_Sram"),
+ ("Corstone-300: Cortex-M55+Ethos-U55", False),
+ ("Generic Inference Runner: Ethos-U55 SRAM", False),
+ "Corstone-300",
+ pytest.raises(
+ Exception,
+ match=r"System Corstone-300: Cortex-M55\+Ethos-U55 is not installed",
+ ),
+ ),
+ (
+ DeviceInfo(device_type="ethos-u55", mac=32, memory_mode="Shared_Sram"),
+ ("Corstone-300: Cortex-M55+Ethos-U55", True),
+ ("Generic Inference Runner: Ethos-U55 SRAM", False),
+ "Corstone-300",
+ pytest.raises(
+ Exception,
+ match=r"Application Generic Inference Runner: Ethos-U55/65 Shared SRAM "
+ r"for the system Corstone-300: Cortex-M55\+Ethos-U55 is not installed",
+ ),
+ ),
+ (
+ DeviceInfo(device_type="ethos-u55", mac=32, memory_mode="Shared_Sram"),
+ ("Corstone-310: Cortex-M85+Ethos-U55", True),
+ ("Generic Inference Runner: Ethos-U55 SRAM", True),
+ "Corstone-310",
+ does_not_raise(),
+ ),
+ (
+ DeviceInfo(device_type="ethos-u55", mac=32, memory_mode="Shared_Sram"),
+ ("Corstone-310: Cortex-M85+Ethos-U55", False),
+ ("Generic Inference Runner: Ethos-U55 SRAM", False),
+ "Corstone-310",
+ pytest.raises(
+ Exception,
+ match=r"System Corstone-310: Cortex-M85\+Ethos-U55 is not installed",
+ ),
+ ),
+ (
+ DeviceInfo(device_type="ethos-u55", mac=32, memory_mode="Shared_Sram"),
+ ("Corstone-310: Cortex-M85+Ethos-U55", True),
+ ("Generic Inference Runner: Ethos-U55 SRAM", False),
+ "Corstone-310",
+ pytest.raises(
+ Exception,
+ match=r"Application Generic Inference Runner: Ethos-U55/65 Shared SRAM "
+ r"for the system Corstone-310: Cortex-M85\+Ethos-U55 is not installed",
+ ),
+ ),
+ (
+ DeviceInfo(device_type="ethos-u65", mac=512, memory_mode="Shared_Sram"),
+ ("Corstone-300: Cortex-M55+Ethos-U65", True),
+ ("Generic Inference Runner: Ethos-U65 Dedicated SRAM", True),
+ "Corstone-300",
+ does_not_raise(),
+ ),
+ (
+ DeviceInfo(device_type="ethos-u65", mac=512, memory_mode="Shared_Sram"),
+ ("Corstone-300: Cortex-M55+Ethos-U65", False),
+ ("Generic Inference Runner: Ethos-U65 Dedicated SRAM", False),
+ "Corstone-300",
+ pytest.raises(
+ Exception,
+ match=r"System Corstone-300: Cortex-M55\+Ethos-U65 is not installed",
+ ),
+ ),
+ (
+ DeviceInfo(device_type="ethos-u65", mac=512, memory_mode="Shared_Sram"),
+ ("Corstone-300: Cortex-M55+Ethos-U65", True),
+ ("Generic Inference Runner: Ethos-U65 Dedicated SRAM", False),
+ "Corstone-300",
+ pytest.raises(
+ Exception,
+ match=r"Application Generic Inference Runner: Ethos-U55/65 Shared SRAM "
+ r"for the system Corstone-300: Cortex-M55\+Ethos-U65 is not installed",
+ ),
+ ),
+ (
+ DeviceInfo(
+ device_type="unknown_device", # type: ignore
+ mac=None, # type: ignore
+ memory_mode="Shared_Sram",
+ ),
+ ("some_system", False),
+ ("some_application", False),
+ "some backend",
+ pytest.raises(Exception, match="Unsupported device unknown_device"),
+ ),
+ ],
+)
+def test_estimate_performance(
+ device: DeviceInfo,
+ system: Tuple[str, bool],
+ application: Tuple[str, bool],
+ backend: str,
+ expected_error: Any,
+ test_tflite_model: Path,
+ aiet_runner: MagicMock,
+) -> None:
+ """Test getting performance estimations."""
+ system_name, system_installed = system
+ application_name, application_installed = application
+
+ aiet_runner.is_system_installed.return_value = system_installed
+ aiet_runner.is_application_installed.return_value = application_installed
+
+ mock_process = create_mock_process(
+ [
+ "NPU AXI0_RD_DATA_BEAT_RECEIVED beats: 1",
+ "NPU AXI0_WR_DATA_BEAT_WRITTEN beats: 2",
+ "NPU AXI1_RD_DATA_BEAT_RECEIVED beats: 3",
+ "NPU ACTIVE cycles: 4",
+ "NPU IDLE cycles: 5",
+ "NPU TOTAL cycles: 6",
+ ],
+ [],
+ )
+
+ mock_generic_inference_run = RunningCommand(mock_process)
+ aiet_runner.run_application.return_value = mock_generic_inference_run
+
+ with expected_error:
+ perf_metrics = estimate_performance(
+ ModelInfo(test_tflite_model), device, backend
+ )
+
+ assert isinstance(perf_metrics, PerformanceMetrics)
+ assert perf_metrics == PerformanceMetrics(
+ npu_axi0_rd_data_beat_received=1,
+ npu_axi0_wr_data_beat_written=2,
+ npu_axi1_rd_data_beat_received=3,
+ npu_active_cycles=4,
+ npu_idle_cycles=5,
+ npu_total_cycles=6,
+ )
+
+ assert aiet_runner.is_system_installed.called_once_with(system_name)
+ assert aiet_runner.is_application_installed.called_once_with(
+ application_name, system_name
+ )
+
+
+@pytest.mark.parametrize("backend", ("Corstone-300", "Corstone-310"))
+def test_estimate_performance_insufficient_data(
+ aiet_runner: MagicMock, test_tflite_model: Path, backend: str
+) -> None:
+ """Test that performance could not be estimated when not all data presented."""
+ aiet_runner.is_system_installed.return_value = True
+ aiet_runner.is_application_installed.return_value = True
+
+ no_total_cycles_output = [
+ "NPU AXI0_RD_DATA_BEAT_RECEIVED beats: 1",
+ "NPU AXI0_WR_DATA_BEAT_WRITTEN beats: 2",
+ "NPU AXI1_RD_DATA_BEAT_RECEIVED beats: 3",
+ "NPU ACTIVE cycles: 4",
+ "NPU IDLE cycles: 5",
+ ]
+ mock_process = create_mock_process(
+ no_total_cycles_output,
+ [],
+ )
+
+ mock_generic_inference_run = RunningCommand(mock_process)
+ aiet_runner.run_application.return_value = mock_generic_inference_run
+
+ with pytest.raises(
+ Exception, match="Unable to get performance metrics, insufficient data"
+ ):
+ device = DeviceInfo(device_type="ethos-u55", mac=32, memory_mode="Shared_Sram")
+ estimate_performance(ModelInfo(test_tflite_model), device, backend)
+
+
+@pytest.mark.parametrize("backend", ("Corstone-300", "Corstone-310"))
+def test_estimate_performance_invalid_output(
+ test_tflite_model: Path, aiet_runner: MagicMock, backend: str
+) -> None:
+ """Test estimation could not be done if inference produces unexpected output."""
+ aiet_runner.is_system_installed.return_value = True
+ aiet_runner.is_application_installed.return_value = True
+
+ mock_process = create_mock_process(
+ ["Something", "is", "wrong"], ["What a nice error!"]
+ )
+ aiet_runner.run_application.return_value = RunningCommand(mock_process)
+
+ with pytest.raises(Exception, match="Unable to get performance metrics"):
+ estimate_performance(
+ ModelInfo(test_tflite_model),
+ DeviceInfo(device_type="ethos-u55", mac=256, memory_mode="Shared_Sram"),
+ backend=backend,
+ )
+
+
+def test_get_aiet_runner() -> None:
+ """Test getting aiet runner."""
+ aiet_runner = get_aiet_runner()
+ assert isinstance(aiet_runner, AIETRunner)
+
+
+def create_mock_process(stdout: List[str], stderr: List[str]) -> MagicMock:
+ """Mock underlying process."""
+ mock_process = MagicMock()
+ mock_process.poll.return_value = 0
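+    # Expose stdout/stderr as iterators to emulate streaming output of a finished process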
+ type(mock_process).stdout = PropertyMock(return_value=iter(stdout))
+ type(mock_process).stderr = PropertyMock(return_value=iter(stderr))
+ return mock_process
+
+
+@pytest.mark.parametrize("backend", ("Corstone-300", "Corstone-310"))
+def test_get_generic_runner(backend: str) -> None:
+ """Test function get_generic_runner()."""
+ device_info = DeviceInfo("ethos-u55", 256, memory_mode="Shared_Sram")
+
+ runner = get_generic_runner(device_info=device_info, backend=backend)
+ assert isinstance(runner, GenericInferenceRunnerEthosU)
+
+ with pytest.raises(RuntimeError):
+ get_generic_runner(device_info=device_info, backend="UNKNOWN_BACKEND")
+
+
+@pytest.mark.parametrize(
+ ("backend", "device_type"),
+ (
+ ("Corstone-300", "ethos-u55"),
+ ("Corstone-300", "ethos-u65"),
+ ("Corstone-310", "ethos-u55"),
+ ),
+)
+def test_aiet_backend_support(backend: str, device_type: str) -> None:
+ """Test AIET backend & device support."""
+ assert is_supported(backend)
+ assert is_supported(backend, device_type)
+
+ assert get_system_name(backend, device_type)
+
+ assert backend in supported_backends()
+
+
+class TestGenericInferenceRunnerEthosU:
+ """Test for the class GenericInferenceRunnerEthosU."""
+
+ @staticmethod
+ @pytest.mark.parametrize(
+ "device, backend, expected_system, expected_app",
+ [
+ [
+ DeviceInfo("ethos-u55", 256, memory_mode="Shared_Sram"),
+ "Corstone-300",
+ "Corstone-300: Cortex-M55+Ethos-U55",
+ "Generic Inference Runner: Ethos-U55/65 Shared SRAM",
+ ],
+ [
+ DeviceInfo("ethos-u55", 256, memory_mode="Shared_Sram"),
+ "Corstone-310",
+ "Corstone-310: Cortex-M85+Ethos-U55",
+ "Generic Inference Runner: Ethos-U55/65 Shared SRAM",
+ ],
+ [
+ DeviceInfo("ethos-u55", 256, memory_mode="Sram"),
+ "Corstone-310",
+ "Corstone-310: Cortex-M85+Ethos-U55",
+ "Generic Inference Runner: Ethos-U55 SRAM",
+ ],
+ [
+ DeviceInfo("ethos-u55", 256, memory_mode="Sram"),
+ "Corstone-300",
+ "Corstone-300: Cortex-M55+Ethos-U55",
+ "Generic Inference Runner: Ethos-U55 SRAM",
+ ],
+ [
+ DeviceInfo("ethos-u65", 256, memory_mode="Shared_Sram"),
+ "Corstone-300",
+ "Corstone-300: Cortex-M55+Ethos-U65",
+ "Generic Inference Runner: Ethos-U55/65 Shared SRAM",
+ ],
+ [
+ DeviceInfo("ethos-u65", 256, memory_mode="Dedicated_Sram"),
+ "Corstone-300",
+ "Corstone-300: Cortex-M55+Ethos-U65",
+ "Generic Inference Runner: Ethos-U65 Dedicated SRAM",
+ ],
+ ],
+ )
+ def test_artifact_resolver(
+ device: DeviceInfo, backend: str, expected_system: str, expected_app: str
+ ) -> None:
+ """Test artifact resolving based on the provided parameters."""
+ generic_runner = get_generic_runner(device, backend)
+ assert isinstance(generic_runner, GenericInferenceRunnerEthosU)
+
+ assert generic_runner.system_name == expected_system
+ assert generic_runner.app_name == expected_app
+
+ @staticmethod
+ def test_artifact_resolver_unsupported_backend() -> None:
+ """Test that it should be not possible to use unsupported backends."""
+ with pytest.raises(
+ RuntimeError, match="Unsupported device ethos-u65 for backend test_backend"
+ ):
+ get_generic_runner(
+ DeviceInfo("ethos-u65", 256, memory_mode="Shared_Sram"), "test_backend"
+ )
+
+ @staticmethod
+ def test_artifact_resolver_unsupported_memory_mode() -> None:
+ """Test that it should be not possible to use unsupported memory modes."""
+ with pytest.raises(
+ RuntimeError, match="Unsupported memory mode test_memory_mode"
+ ):
+ get_generic_runner(
+ DeviceInfo(
+ "ethos-u65",
+ 256,
+ memory_mode="test_memory_mode", # type: ignore
+ ),
+ "Corstone-300",
+ )
+
+ @staticmethod
+ @pytest.mark.parametrize("backend", ("Corstone-300", "Corstone-310"))
+ def test_inference_should_fail_if_system_not_installed(
+ aiet_runner: MagicMock, test_tflite_model: Path, backend: str
+ ) -> None:
+ """Test that inference should fail if system is not installed."""
+ aiet_runner.is_system_installed.return_value = False
+
+ generic_runner = get_generic_runner(
+ DeviceInfo("ethos-u55", 256, memory_mode="Shared_Sram"), backend
+ )
+ with pytest.raises(
+ Exception,
+ match=r"System Corstone-3[01]0: Cortex-M[58]5\+Ethos-U55 is not installed",
+ ):
+ generic_runner.run(ModelInfo(test_tflite_model), [])
+
+ @staticmethod
+ @pytest.mark.parametrize("backend", ("Corstone-300", "Corstone-310"))
+    def test_inference_should_fail_if_apps_not_installed(
+ aiet_runner: MagicMock, test_tflite_model: Path, backend: str
+ ) -> None:
+ """Test that inference should fail if apps are not installed."""
+ aiet_runner.is_system_installed.return_value = True
+ aiet_runner.is_application_installed.return_value = False
+
+ generic_runner = get_generic_runner(
+ DeviceInfo("ethos-u55", 256, memory_mode="Shared_Sram"), backend
+ )
+ with pytest.raises(
+ Exception,
+ match="Application Generic Inference Runner: Ethos-U55/65 Shared SRAM"
+ r" for the system Corstone-3[01]0: Cortex-M[58]5\+Ethos-U55 is not "
+ r"installed",
+ ):
+ generic_runner.run(ModelInfo(test_tflite_model), [])
+
+
+@pytest.fixture(name="aiet_runner")
+def fixture_aiet_runner(monkeypatch: pytest.MonkeyPatch) -> MagicMock:
+ """Mock AIET runner."""
+ aiet_runner_mock = MagicMock(spec=AIETRunner)
+ monkeypatch.setattr(
+ "mlia.tools.aiet_wrapper.get_aiet_runner",
+ MagicMock(return_value=aiet_runner_mock),
+ )
+ return aiet_runner_mock
diff --git a/tests/mlia/test_tools_metadata_common.py b/tests/mlia/test_tools_metadata_common.py
new file mode 100644
index 0000000..7663b83
--- /dev/null
+++ b/tests/mlia/test_tools_metadata_common.py
@@ -0,0 +1,196 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for commmon installation related functions."""
+from pathlib import Path
+from typing import Any
+from typing import List
+from typing import Optional
+from unittest.mock import call
+from unittest.mock import MagicMock
+from unittest.mock import PropertyMock
+
+import pytest
+
+from mlia.tools.metadata.common import DefaultInstallationManager
+from mlia.tools.metadata.common import DownloadAndInstall
+from mlia.tools.metadata.common import Installation
+from mlia.tools.metadata.common import InstallationType
+from mlia.tools.metadata.common import InstallFromPath
+
+
+def get_installation_mock(
+ name: str,
+ already_installed: bool = False,
+ could_be_installed: bool = False,
+ supported_install_type: Optional[type] = None,
+) -> MagicMock:
+ """Get mock instance for the installation."""
+ mock = MagicMock(spec=Installation)
+
+ def supports(install_type: InstallationType) -> bool:
+ if supported_install_type is None:
+ return False
+
+ return isinstance(install_type, supported_install_type)
+
+ mock.supports.side_effect = supports
+
+ props = {
+ "name": name,
+ "already_installed": already_installed,
+ "could_be_installed": could_be_installed,
+ }
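+    # PropertyMock has to be attached to the mock's type, not the instance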
+ for prop, value in props.items():
+ setattr(type(mock), prop, PropertyMock(return_value=value))
+
+ return mock
+
+
+def _already_installed_mock() -> MagicMock:
+ return get_installation_mock(
+ name="already_installed",
+ already_installed=True,
+ )
+
+
+def _ready_for_installation_mock() -> MagicMock:
+ return get_installation_mock(
+ name="ready_for_installation",
+ already_installed=False,
+ could_be_installed=True,
+ )
+
+
+def _could_be_downloaded_and_installed_mock() -> MagicMock:
+ return get_installation_mock(
+ name="could_be_downloaded_and_installed",
+ already_installed=False,
+ could_be_installed=True,
+ supported_install_type=DownloadAndInstall,
+ )
+
+
+def _could_be_installed_from_mock() -> MagicMock:
+ return get_installation_mock(
+ name="could_be_installed_from",
+ already_installed=False,
+ could_be_installed=True,
+ supported_install_type=InstallFromPath,
+ )
+
+
+def get_installation_manager(
+ noninteractive: bool,
+ installations: List[Any],
+ monkeypatch: pytest.MonkeyPatch,
+ yes_response: bool = True,
+) -> DefaultInstallationManager:
+ """Get installation manager instance."""
+ if not noninteractive:
+ monkeypatch.setattr(
+ "mlia.tools.metadata.common.yes", MagicMock(return_value=yes_response)
+ )
+
+ return DefaultInstallationManager(installations, noninteractive=noninteractive)
+
+
+def test_installation_manager_filtering() -> None:
+ """Test default installation manager."""
+ already_installed = _already_installed_mock()
+ ready_for_installation = _ready_for_installation_mock()
+ could_be_downloaded_and_installed = _could_be_downloaded_and_installed_mock()
+
+ manager = DefaultInstallationManager(
+ [
+ already_installed,
+ ready_for_installation,
+ could_be_downloaded_and_installed,
+ ]
+ )
+ assert manager.already_installed() == [already_installed]
+ assert manager.ready_for_installation() == [
+ ready_for_installation,
+ could_be_downloaded_and_installed,
+ ]
+ assert manager.could_be_downloaded_and_installed() == [
+ could_be_downloaded_and_installed
+ ]
+ assert manager.could_be_downloaded_and_installed("some_installation") == []
+
+
+@pytest.mark.parametrize("noninteractive", [True, False])
+@pytest.mark.parametrize(
+ "install_mock, eula_agreement, backend_name, expected_call",
+ [
+ [
+ _could_be_downloaded_and_installed_mock(),
+ True,
+ None,
+ [call(DownloadAndInstall(eula_agreement=True))],
+ ],
+ [
+ _could_be_downloaded_and_installed_mock(),
+ False,
+ None,
+ [call(DownloadAndInstall(eula_agreement=False))],
+ ],
+ [
+ _could_be_downloaded_and_installed_mock(),
+ False,
+ "unknown",
+ [],
+ ],
+ ],
+)
+def test_installation_manager_download_and_install(
+ install_mock: MagicMock,
+ noninteractive: bool,
+ eula_agreement: bool,
+ backend_name: Optional[str],
+ expected_call: Any,
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """Test installation process."""
+ install_mock.reset_mock()
+
+ manager = get_installation_manager(noninteractive, [install_mock], monkeypatch)
+
+ manager.download_and_install(backend_name, eula_agreement=eula_agreement)
+ assert install_mock.install.mock_calls == expected_call
+
+
+@pytest.mark.parametrize("noninteractive", [True, False])
+@pytest.mark.parametrize(
+ "install_mock, backend_name, expected_call",
+ [
+ [
+ _could_be_installed_from_mock(),
+ None,
+ [call(InstallFromPath(Path("some_path")))],
+ ],
+ [
+ _could_be_installed_from_mock(),
+ "unknown",
+ [],
+ ],
+ [
+ _already_installed_mock(),
+ "already_installed",
+ [],
+ ],
+ ],
+)
+def test_installation_manager_install_from(
+ install_mock: MagicMock,
+ noninteractive: bool,
+ backend_name: Optional[str],
+ expected_call: Any,
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """Test installation process."""
+ install_mock.reset_mock()
+
+ manager = get_installation_manager(noninteractive, [install_mock], monkeypatch)
+ manager.install_from(Path("some_path"), backend_name)
+
+ assert install_mock.install.mock_calls == expected_call
diff --git a/tests/mlia/test_tools_metadata_corstone.py b/tests/mlia/test_tools_metadata_corstone.py
new file mode 100644
index 0000000..2ce3610
--- /dev/null
+++ b/tests/mlia/test_tools_metadata_corstone.py
@@ -0,0 +1,419 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for Corstone related installation functions.."""
+import tarfile
+from pathlib import Path
+from typing import List
+from typing import Optional
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.tools.aiet_wrapper import AIETRunner
+from mlia.tools.metadata.common import DownloadAndInstall
+from mlia.tools.metadata.common import InstallFromPath
+from mlia.tools.metadata.corstone import AIETBasedInstallation
+from mlia.tools.metadata.corstone import AIETMetadata
+from mlia.tools.metadata.corstone import BackendInfo
+from mlia.tools.metadata.corstone import BackendInstaller
+from mlia.tools.metadata.corstone import CompoundPathChecker
+from mlia.tools.metadata.corstone import Corstone300Installer
+from mlia.tools.metadata.corstone import get_corstone_installations
+from mlia.tools.metadata.corstone import PackagePathChecker
+from mlia.tools.metadata.corstone import PathChecker
+from mlia.tools.metadata.corstone import StaticPathChecker
+
+
+@pytest.fixture(name="test_mlia_resources")
+def fixture_test_mlia_resources(
+ tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+) -> Path:
+ """Redirect MLIA resources resolution to the temp directory."""
+ mlia_resources = tmp_path / "resources"
+ mlia_resources.mkdir()
+
+ monkeypatch.setattr(
+ "mlia.tools.metadata.corstone.get_mlia_resources",
+ MagicMock(return_value=mlia_resources),
+ )
+
+ return mlia_resources
+
+
+def get_aiet_based_installation( # pylint: disable=too-many-arguments
+ aiet_runner_mock: MagicMock = MagicMock(),
+ name: str = "test_name",
+ description: str = "test_description",
+ download_artifact: Optional[MagicMock] = None,
+ path_checker: PathChecker = MagicMock(),
+ apps_resources: Optional[List[str]] = None,
+ system_config: Optional[str] = None,
+ backend_installer: BackendInstaller = MagicMock(),
+ supported_platforms: Optional[List[str]] = None,
+) -> AIETBasedInstallation:
+ """Get AIET based installation."""
+ return AIETBasedInstallation(
+ aiet_runner=aiet_runner_mock,
+ metadata=AIETMetadata(
+ name=name,
+ description=description,
+ system_config=system_config or "",
+ apps_resources=apps_resources or [],
+ fvp_dir_name="sample_dir",
+ download_artifact=download_artifact,
+ supported_platforms=supported_platforms,
+ ),
+ path_checker=path_checker,
+ backend_installer=backend_installer,
+ )
+
+
+@pytest.mark.parametrize(
+ "platform, supported_platforms, expected_result",
+ [
+ ["Linux", ["Linux"], True],
+ ["Linux", [], True],
+ ["Linux", None, True],
+ ["Windows", ["Linux"], False],
+ ],
+)
+def test_could_be_installed_depends_on_platform(
+ platform: str,
+ supported_platforms: Optional[List[str]],
+ expected_result: bool,
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """Test that installation could not be installed on unsupported platform."""
+ monkeypatch.setattr(
+ "mlia.tools.metadata.corstone.platform.system", MagicMock(return_value=platform)
+ )
+ monkeypatch.setattr(
+ "mlia.tools.metadata.corstone.all_paths_valid", MagicMock(return_value=True)
+ )
+ aiet_runner_mock = MagicMock(spec=AIETRunner)
+
+ installation = get_aiet_based_installation(
+ aiet_runner_mock,
+ supported_platforms=supported_platforms,
+ )
+ assert installation.could_be_installed == expected_result
+
+
+def test_get_corstone_installations() -> None:
+ """Test function get_corstone_installation."""
+ installs = get_corstone_installations()
+ assert len(installs) == 2
+ assert all(isinstance(install, AIETBasedInstallation) for install in installs)
+
+
+def test_aiet_based_installation_metadata_resolving() -> None:
+ """Test AIET based installation metadata resolving."""
+ aiet_runner_mock = MagicMock(spec=AIETRunner)
+ installation = get_aiet_based_installation(aiet_runner_mock)
+
+ assert installation.name == "test_name"
+ assert installation.description == "test_description"
+
+ aiet_runner_mock.all_installed.return_value = False
+ assert installation.already_installed is False
+
+ assert installation.could_be_installed is True
+
+
+def test_aiet_based_installation_supported_install_types(tmp_path: Path) -> None:
+ """Test supported installation types."""
+ installation_no_download_artifact = get_aiet_based_installation()
+ assert installation_no_download_artifact.supports(DownloadAndInstall()) is False
+
+ installation_with_download_artifact = get_aiet_based_installation(
+ download_artifact=MagicMock()
+ )
+ assert installation_with_download_artifact.supports(DownloadAndInstall()) is True
+
+ path_checker_mock = MagicMock(return_value=BackendInfo(tmp_path))
+ installation_can_install_from_dir = get_aiet_based_installation(
+ path_checker=path_checker_mock
+ )
+ assert installation_can_install_from_dir.supports(InstallFromPath(tmp_path)) is True
+
+ any_installation = get_aiet_based_installation()
+ assert any_installation.supports("unknown_install_type") is False # type: ignore
+
+
+def test_aiet_based_installation_install_wrong_type() -> None:
+ """Test that operation should fail if wrong install type provided."""
+ with pytest.raises(Exception, match="Unable to install wrong_install_type"):
+ aiet_runner_mock = MagicMock(spec=AIETRunner)
+ installation = get_aiet_based_installation(aiet_runner_mock)
+
+ installation.install("wrong_install_type") # type: ignore
+
+
+def test_aiet_based_installation_install_from_path(
+ tmp_path: Path, test_mlia_resources: Path
+) -> None:
+ """Test installation from the path."""
+ system_config = test_mlia_resources / "example_config.json"
+ system_config.touch()
+
+ sample_app = test_mlia_resources / "sample_app"
+ sample_app.mkdir()
+
+ dist_dir = tmp_path / "dist"
+ dist_dir.mkdir()
+
+ path_checker_mock = MagicMock(return_value=BackendInfo(dist_dir))
+
+ aiet_runner_mock = MagicMock(spec=AIETRunner)
+ installation = get_aiet_based_installation(
+ aiet_runner_mock=aiet_runner_mock,
+ path_checker=path_checker_mock,
+ apps_resources=[sample_app.name],
+ system_config="example_config.json",
+ )
+
+ assert installation.supports(InstallFromPath(dist_dir)) is True
+ installation.install(InstallFromPath(dist_dir))
+
+ aiet_runner_mock.install_system.assert_called_once()
+ aiet_runner_mock.install_application.assert_called_once_with(sample_app)
+
+
+@pytest.mark.parametrize("copy_source", [True, False])
+def test_aiet_based_installation_install_from_static_path(
+ tmp_path: Path, test_mlia_resources: Path, copy_source: bool
+) -> None:
+ """Test installation from the predefined path."""
+ system_config = test_mlia_resources / "example_config.json"
+ system_config.touch()
+
+ custom_system_config = test_mlia_resources / "custom_config.json"
+ custom_system_config.touch()
+
+ sample_app = test_mlia_resources / "sample_app"
+ sample_app.mkdir()
+
+ predefined_location = tmp_path / "backend"
+ predefined_location.mkdir()
+
+ predefined_location_file = predefined_location / "file.txt"
+ predefined_location_file.touch()
+
+ predefined_location_dir = predefined_location / "folder"
+ predefined_location_dir.mkdir()
+ nested_file = predefined_location_dir / "nested_file.txt"
+ nested_file.touch()
+
+ aiet_runner_mock = MagicMock(spec=AIETRunner)
+
+ def check_install_dir(install_dir: Path) -> None:
+ """Check content of the install dir."""
+ assert install_dir.is_dir()
+ files = list(install_dir.iterdir())
+
+ if copy_source:
+ assert len(files) == 3
+ assert all(install_dir / item in files for item in ["file.txt", "folder"])
+ assert (install_dir / "folder/nested_file.txt").is_file()
+ else:
+ assert len(files) == 1
+
+ assert install_dir / "custom_config.json" in files
+
+ aiet_runner_mock.install_system.side_effect = check_install_dir
+
+ installation = get_aiet_based_installation(
+ aiet_runner_mock=aiet_runner_mock,
+ path_checker=StaticPathChecker(
+ predefined_location,
+ ["file.txt"],
+ copy_source=copy_source,
+ system_config=str(custom_system_config),
+ ),
+ apps_resources=[sample_app.name],
+ system_config="example_config.json",
+ )
+
+ assert installation.supports(InstallFromPath(predefined_location)) is True
+ installation.install(InstallFromPath(predefined_location))
+
+ aiet_runner_mock.install_system.assert_called_once()
+ aiet_runner_mock.install_application.assert_called_once_with(sample_app)
+
+
+def create_sample_fvp_archive(tmp_path: Path) -> Path:
+ """Create sample FVP tar archive."""
+ fvp_archive_dir = tmp_path / "archive"
+ fvp_archive_dir.mkdir()
+
+ sample_file = fvp_archive_dir / "sample.txt"
+ sample_file.write_text("Sample file")
+
+ sample_dir = fvp_archive_dir / "sample_dir"
+ sample_dir.mkdir()
+
+ fvp_archive = tmp_path / "archive.tgz"
+ with tarfile.open(fvp_archive, "w:gz") as fvp_archive_tar:
+ fvp_archive_tar.add(fvp_archive_dir, arcname=fvp_archive_dir.name)
+
+ return fvp_archive
+
+
+def test_aiet_based_installation_download_and_install(
+ test_mlia_resources: Path, tmp_path: Path
+) -> None:
+ """Test downloading and installation process."""
+ fvp_archive = create_sample_fvp_archive(tmp_path)
+
+ system_config = test_mlia_resources / "example_config.json"
+ system_config.touch()
+
+ download_artifact_mock = MagicMock()
+ download_artifact_mock.download_to.return_value = fvp_archive
+
+ path_checker = PackagePathChecker(["archive/sample.txt"], "archive/sample_dir")
+
+ def installer(_eula_agreement: bool, dist_dir: Path) -> Path:
+ """Sample installer."""
+ return dist_dir
+
+ aiet_runner_mock = MagicMock(spec=AIETRunner)
+ installation = get_aiet_based_installation(
+ aiet_runner_mock,
+ download_artifact=download_artifact_mock,
+ backend_installer=installer,
+ path_checker=path_checker,
+ system_config="example_config.json",
+ )
+
+ installation.install(DownloadAndInstall())
+
+ aiet_runner_mock.install_system.assert_called_once()
+
+
+@pytest.mark.parametrize(
+ "dir_content, expected_result",
+ [
+ [
+ ["models/", "file1.txt", "file2.txt"],
+ "models",
+ ],
+ [
+ ["file1.txt", "file2.txt"],
+ None,
+ ],
+ [
+ ["models/", "file2.txt"],
+ None,
+ ],
+ ],
+)
+def test_corstone_path_checker_valid_path(
+ tmp_path: Path, dir_content: List[str], expected_result: Optional[str]
+) -> None:
+ """Test Corstone path checker valid scenario."""
+ path_checker = PackagePathChecker(["file1.txt", "file2.txt"], "models")
+
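+    # Entries ending with "/" are created as directories, the rest as plain files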
+ for item in dir_content:
+ if item.endswith("/"):
+ item_dir = tmp_path / item
+ item_dir.mkdir()
+ else:
+ item_file = tmp_path / item
+ item_file.touch()
+
+ result = path_checker(tmp_path)
+ expected = (
+ None if expected_result is None else BackendInfo(tmp_path / expected_result)
+ )
+
+ assert result == expected
+
+
+@pytest.mark.parametrize("system_config", [None, "system_config"])
+@pytest.mark.parametrize("copy_source", [True, False])
+def test_static_path_checker(
+ tmp_path: Path, copy_source: bool, system_config: Optional[str]
+) -> None:
+ """Test static path checker."""
+ static_checker = StaticPathChecker(
+ tmp_path, [], copy_source=copy_source, system_config=system_config
+ )
+ assert static_checker(tmp_path) == BackendInfo(
+ tmp_path, copy_source=copy_source, system_config=system_config
+ )
+
+
+def test_static_path_checker_not_valid_path(tmp_path: Path) -> None:
+ """Test static path checker should return None if path is not valid."""
+ static_checker = StaticPathChecker(tmp_path, ["file.txt"])
+ assert static_checker(tmp_path / "backend") is None
+
+
+def test_static_path_checker_not_valid_structure(tmp_path: Path) -> None:
+ """Test static path checker should return None if files are missing."""
+ static_checker = StaticPathChecker(tmp_path, ["file.txt"])
+ assert static_checker(tmp_path) is None
+
+ missing_file = tmp_path / "file.txt"
+ missing_file.touch()
+
+ assert static_checker(tmp_path) == BackendInfo(tmp_path, copy_source=False)
+
+
+def test_compound_path_checker(tmp_path: Path) -> None:
+ """Test compound path checker."""
+ path_checker_path_valid_path = MagicMock(return_value=BackendInfo(tmp_path))
+ path_checker_path_not_valid_path = MagicMock(return_value=None)
+
+ checker = CompoundPathChecker(
+ path_checker_path_valid_path, path_checker_path_not_valid_path
+ )
+ assert checker(tmp_path) == BackendInfo(tmp_path)
+
+ checker = CompoundPathChecker(path_checker_path_not_valid_path)
+ assert checker(tmp_path) is None
+
+
+@pytest.mark.parametrize(
+ "eula_agreement, expected_command",
+ [
+ [
+ True,
+ [
+ "./FVP_Corstone_SSE-300.sh",
+ "-q",
+ "-d",
+ "corstone-300",
+ ],
+ ],
+ [
+ False,
+ [
+ "./FVP_Corstone_SSE-300.sh",
+ "-q",
+ "-d",
+ "corstone-300",
+ "--nointeractive",
+ "--i-agree-to-the-contained-eula",
+ ],
+ ],
+ ],
+)
+def test_corstone_300_installer(
+ tmp_path: Path,
+ monkeypatch: pytest.MonkeyPatch,
+ eula_agreement: bool,
+ expected_command: List[str],
+) -> None:
+ """Test Corstone-300 installer."""
+ command_mock = MagicMock()
+
+ monkeypatch.setattr(
+ "mlia.tools.metadata.corstone.subprocess.check_call", command_mock
+ )
+ installer = Corstone300Installer()
+ result = installer(eula_agreement, tmp_path)
+
+ command_mock.assert_called_once_with(expected_command)
+ assert result == tmp_path / "corstone-300"
diff --git a/tests/mlia/test_tools_vela_wrapper.py b/tests/mlia/test_tools_vela_wrapper.py
new file mode 100644
index 0000000..875d2ff
--- /dev/null
+++ b/tests/mlia/test_tools_vela_wrapper.py
@@ -0,0 +1,285 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module tools/vela_wrapper."""
+from pathlib import Path
+from unittest.mock import MagicMock
+
+import pytest
+from ethosu.vela.compiler_driver import TensorAllocator
+from ethosu.vela.scheduler import OptimizationStrategy
+
+from mlia.devices.ethosu.config import EthosUConfiguration
+from mlia.tools.vela_wrapper import estimate_performance
+from mlia.tools.vela_wrapper import generate_supported_operators_report
+from mlia.tools.vela_wrapper import NpuSupported
+from mlia.tools.vela_wrapper import Operator
+from mlia.tools.vela_wrapper import Operators
+from mlia.tools.vela_wrapper import optimize_model
+from mlia.tools.vela_wrapper import OptimizedModel
+from mlia.tools.vela_wrapper import PerformanceMetrics
+from mlia.tools.vela_wrapper import supported_operators
+from mlia.tools.vela_wrapper import VelaCompiler
+from mlia.tools.vela_wrapper import VelaCompilerOptions
+from mlia.utils.proc import working_directory
+
+
+def test_default_vela_compiler() -> None:
+ """Test default Vela compiler instance."""
+ default_compiler_options = VelaCompilerOptions(accelerator_config="ethos-u55-256")
+ default_compiler = VelaCompiler(default_compiler_options)
+
+ assert default_compiler.config_files is None
+ assert default_compiler.system_config == "internal-default"
+ assert default_compiler.memory_mode == "internal-default"
+ assert default_compiler.accelerator_config == "ethos-u55-256"
+ assert default_compiler.max_block_dependency == 3
+ assert default_compiler.arena_cache_size is None
+ assert default_compiler.tensor_allocator == TensorAllocator.HillClimb
+ assert default_compiler.cpu_tensor_alignment == 16
+ assert default_compiler.optimization_strategy == OptimizationStrategy.Performance
+ assert default_compiler.output_dir is None
+
+ assert default_compiler.get_config() == {
+ "accelerator_config": "ethos-u55-256",
+ "system_config": "internal-default",
+ "core_clock": 500000000.0,
+ "axi0_port": "Sram",
+ "axi1_port": "OffChipFlash",
+ "memory_mode": "internal-default",
+ "const_mem_area": "Axi1",
+ "arena_mem_area": "Axi0",
+ "cache_mem_area": "Axi0",
+ "arena_cache_size": 4294967296,
+ "permanent_storage_mem_area": "OffChipFlash",
+ "feature_map_storage_mem_area": "Sram",
+ "fast_storage_mem_area": "Sram",
+ "memory_area": {
+ "Sram": {
+ "clock_scales": 1.0,
+ "burst_length": 32,
+ "read_latency": 32,
+ "write_latency": 32,
+ },
+ "Dram": {
+ "clock_scales": 1.0,
+ "burst_length": 1,
+ "read_latency": 0,
+ "write_latency": 0,
+ },
+ "OnChipFlash": {
+ "clock_scales": 1.0,
+ "burst_length": 1,
+ "read_latency": 0,
+ "write_latency": 0,
+ },
+ "OffChipFlash": {
+ "clock_scales": 0.125,
+ "burst_length": 128,
+ "read_latency": 64,
+ "write_latency": 64,
+ },
+ },
+ }
+
+
+def test_vela_compiler_with_parameters(test_resources_path: Path) -> None:
+ """Test creation of Vela compiler instance with non-default params."""
+ vela_ini_path = str(test_resources_path / "vela/sample_vela.ini")
+
+ compiler_options = VelaCompilerOptions(
+ config_files=vela_ini_path,
+ system_config="Ethos_U65_High_End",
+ memory_mode="Shared_Sram",
+ accelerator_config="ethos-u65-256",
+ max_block_dependency=1,
+ arena_cache_size=10,
+ tensor_allocator="Greedy",
+ cpu_tensor_alignment=4,
+ optimization_strategy="Size",
+ output_dir="output",
+ )
+ compiler = VelaCompiler(compiler_options)
+
+ assert compiler.config_files == vela_ini_path
+ assert compiler.system_config == "Ethos_U65_High_End"
+ assert compiler.memory_mode == "Shared_Sram"
+ assert compiler.accelerator_config == "ethos-u65-256"
+ assert compiler.max_block_dependency == 1
+ assert compiler.arena_cache_size == 10
+ assert compiler.tensor_allocator == TensorAllocator.Greedy
+ assert compiler.cpu_tensor_alignment == 4
+ assert compiler.optimization_strategy == OptimizationStrategy.Size
+ assert compiler.output_dir == "output"
+
+ assert compiler.get_config() == {
+ "accelerator_config": "ethos-u65-256",
+ "system_config": "Ethos_U65_High_End",
+ "core_clock": 1000000000.0,
+ "axi0_port": "Sram",
+ "axi1_port": "Dram",
+ "memory_mode": "Shared_Sram",
+ "const_mem_area": "Axi1",
+ "arena_mem_area": "Axi0",
+ "cache_mem_area": "Axi0",
+ "arena_cache_size": 10,
+ "permanent_storage_mem_area": "Dram",
+ "feature_map_storage_mem_area": "Sram",
+ "fast_storage_mem_area": "Sram",
+ "memory_area": {
+ "Sram": {
+ "clock_scales": 1.0,
+ "burst_length": 32,
+ "read_latency": 32,
+ "write_latency": 32,
+ },
+ "Dram": {
+ "clock_scales": 0.234375,
+ "burst_length": 128,
+ "read_latency": 500,
+ "write_latency": 250,
+ },
+ "OnChipFlash": {
+ "clock_scales": 1.0,
+ "burst_length": 1,
+ "read_latency": 0,
+ "write_latency": 0,
+ },
+ "OffChipFlash": {
+ "clock_scales": 1.0,
+ "burst_length": 1,
+ "read_latency": 0,
+ "write_latency": 0,
+ },
+ },
+ }
+
+
+def test_compile_model(test_tflite_model: Path) -> None:
+ """Test model optimization."""
+ compiler = VelaCompiler(EthosUConfiguration("ethos-u55-256").compiler_options)
+
+ optimized_model = compiler.compile_model(test_tflite_model)
+ assert isinstance(optimized_model, OptimizedModel)
+
+
+def test_optimize_model(tmp_path: Path, test_tflite_model: Path) -> None:
+ """Test model optimization and saving into file."""
+ tmp_file = tmp_path / "temp.tflite"
+
+ device = EthosUConfiguration("ethos-u55-256")
+ optimize_model(test_tflite_model, device.compiler_options, tmp_file.absolute())
+
+ assert tmp_file.is_file()
+ assert tmp_file.stat().st_size > 0
+
+
+@pytest.mark.parametrize(
+ "model, expected_ops",
+ [
+ (
+ "test_model.tflite",
+ Operators(
+ ops=[
+ Operator(
+ name="sequential/conv1/Relu;sequential/conv1/BiasAdd;"
+ "sequential/conv2/Conv2D;sequential/conv1/Conv2D",
+ op_type="CONV_2D",
+ run_on_npu=NpuSupported(supported=True, reasons=[]),
+ ),
+ Operator(
+ name="sequential/conv2/Relu;sequential/conv2/BiasAdd;"
+ "sequential/conv2/Conv2D",
+ op_type="CONV_2D",
+ run_on_npu=NpuSupported(supported=True, reasons=[]),
+ ),
+ Operator(
+ name="sequential/max_pooling2d/MaxPool",
+ op_type="MAX_POOL_2D",
+ run_on_npu=NpuSupported(supported=True, reasons=[]),
+ ),
+ Operator(
+ name="sequential/flatten/Reshape",
+ op_type="RESHAPE",
+ run_on_npu=NpuSupported(supported=True, reasons=[]),
+ ),
+ Operator(
+ name="Identity",
+ op_type="FULLY_CONNECTED",
+ run_on_npu=NpuSupported(supported=True, reasons=[]),
+ ),
+ ]
+ ),
+ )
+ ],
+)
+def test_operators(
+ test_models_path: Path, model: str, expected_ops: Operators
+) -> None:
+ """Test operators function."""
+ device = EthosUConfiguration("ethos-u55-256")
+
+ operators = supported_operators(test_models_path / model, device.compiler_options)
+ for expected, actual in zip(expected_ops.ops, operators.ops):
+ # do not compare names as they can differ between model generations
+ assert expected.op_type == actual.op_type
+ assert expected.run_on_npu == actual.run_on_npu
+
+
+def test_estimate_performance(test_tflite_model: Path) -> None:
+ """Test getting performance estimations."""
+ device = EthosUConfiguration("ethos-u55-256")
+ perf_metrics = estimate_performance(test_tflite_model, device.compiler_options)
+
+ assert isinstance(perf_metrics, PerformanceMetrics)
+
+
+def test_estimate_performance_already_optimized(
+ tmp_path: Path, test_tflite_model: Path
+) -> None:
+ """Test that performance estimation should fail for already optimized model."""
+ device = EthosUConfiguration("ethos-u55-256")
+
+ optimized_model_path = tmp_path / "optimized_model.tflite"
+
+ optimize_model(test_tflite_model, device.compiler_options, optimized_model_path)
+
+ with pytest.raises(
+ Exception, match="Unable to estimate performance for the given optimized model"
+ ):
+ estimate_performance(optimized_model_path, device.compiler_options)
+
+
+def test_generate_supported_operators_report(tmp_path: Path) -> None:
+ """Test generating supported operators report."""
+ with working_directory(tmp_path):
+ generate_supported_operators_report()
+
+ md_file = tmp_path / "SUPPORTED_OPS.md"
+ assert md_file.is_file()
+ assert md_file.stat().st_size > 0
+
+
+def test_read_invalid_model(test_tflite_invalid_model: Path) -> None:
+ """Test that reading invalid model should fail with exception."""
+ with pytest.raises(
+ Exception, match=f"Unable to read model {test_tflite_invalid_model}"
+ ):
+ device = EthosUConfiguration("ethos-u55-256")
+ estimate_performance(test_tflite_invalid_model, device.compiler_options)
+
+
+def test_compile_invalid_model(
+ test_tflite_model: Path, monkeypatch: pytest.MonkeyPatch, tmp_path: Path
+) -> None:
+ """Test that if model could not be compiled then correct exception raised."""
+ mock_compiler = MagicMock()
+ mock_compiler.side_effect = Exception("Bad model!")
+
+ monkeypatch.setattr("mlia.tools.vela_wrapper.compiler_driver", mock_compiler)
+
+ model_path = tmp_path / "optimized_model.tflite"
+ with pytest.raises(
+ Exception, match="Model could not be optimized with Vela compiler"
+ ):
+ device = EthosUConfiguration("ethos-u55-256")
+ optimize_model(test_tflite_model, device.compiler_options, model_path)
+
+ assert not model_path.exists()
diff --git a/tests/mlia/test_utils_console.py b/tests/mlia/test_utils_console.py
new file mode 100644
index 0000000..36975f8
--- /dev/null
+++ b/tests/mlia/test_utils_console.py
@@ -0,0 +1,100 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for console utility functions."""
+from typing import Iterable
+from typing import List
+from typing import Optional
+
+import pytest
+
+from mlia.utils.console import apply_style
+from mlia.utils.console import create_section_header
+from mlia.utils.console import produce_table
+from mlia.utils.console import remove_ascii_codes
+
+
+@pytest.mark.parametrize(
+ "rows, headers, table_style, expected_result",
+ [
+ [[], [], "no_borders", ""],
+ [
+ [["1", "2", "3"]],
+ ["Col 1", "Col 2", "Col 3"],
+ "default",
+ """
+┌───────┬───────┬───────┐
+│ Col 1 │ Col 2 │ Col 3 │
+╞═══════╪═══════╪═══════╡
+│ 1 │ 2 │ 3 │
+└───────┴───────┴───────┘
+""".strip(),
+ ],
+ [
+ [["1", "2", "3"]],
+ ["Col 1", "Col 2", "Col 3"],
+ "nested",
+ "Col 1 Col 2 Col 3 \n \n1 2 3",
+ ],
+ [
+ [["1", "2", "3"]],
+ ["Col 1", "Col 2", "Col 3"],
+ "no_borders",
+ " Col 1 Col 2 Col 3 \n 1 2 3",
+ ],
+ ],
+)
+def test_produce_table(
+ rows: Iterable, headers: Optional[List[str]], table_style: str, expected_result: str
+) -> None:
+ """Test produce_table function."""
+ result = produce_table(rows, headers, table_style)
+ assert remove_ascii_codes(result) == expected_result
+
+
+def test_produce_table_unknown_style() -> None:
+ """Test that function should fail if unknown style provided."""
+ with pytest.raises(Exception, match="Unsupported table style unknown_style"):
+ produce_table([["1", "2", "3"]], [], "unknown_style")
+
+
+@pytest.mark.parametrize(
+ "value, expected_result",
+ [
+ ["some text", "some text"],
+ ["\033[32msome text\033[0m", "some text"],
+ ],
+)
+def test_remove_ascii_codes(value: str, expected_result: str) -> None:
+ """Test remove_ascii_codes function."""
+ assert remove_ascii_codes(value) == expected_result
+
+
+def test_apply_style() -> None:
+ """Test function apply_style."""
+ assert apply_style("some text", "green") == "[green]some text"
+
+
+@pytest.mark.parametrize(
+ "section_header, expected_result",
+ [
+ [
+ "Section header",
+ "\n--- Section header -------------------------------"
+ "------------------------------\n",
+ ],
+ [
+ "",
+ f"\n{'-' * 80}\n",
+ ],
+ ],
+)
+def test_create_section_header(section_header: str, expected_result: str) -> None:
+ """Test function test_create_section."""
+ assert create_section_header(section_header) == expected_result
+
+
+def test_create_section_header_too_long_value() -> None:
+ """Test that header could not be created for the too long section names."""
+ section_name = "section name" * 100
+ with pytest.raises(ValueError, match="Section name too long"):
+ create_section_header(section_name)
diff --git a/tests/mlia/test_utils_download.py b/tests/mlia/test_utils_download.py
new file mode 100644
index 0000000..4f8e2dc
--- /dev/null
+++ b/tests/mlia/test_utils_download.py
@@ -0,0 +1,147 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for download functionality."""
+from contextlib import ExitStack as does_not_raise
+from pathlib import Path
+from typing import Any
+from typing import Iterable
+from typing import Optional
+from unittest.mock import MagicMock
+from unittest.mock import PropertyMock
+
+import pytest
+import requests
+
+from mlia.utils.download import download
+from mlia.utils.download import DownloadArtifact
+
+
+def response_mock(
+ content_length: Optional[str], content_chunks: Iterable[bytes]
+) -> MagicMock:
+ """Mock response object."""
+ mock = MagicMock(spec=requests.Response)
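+ # let the mock act as a context manager that returns itself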
+ mock.__enter__.return_value = mock
+
+ type(mock).headers = PropertyMock(return_value={"Content-Length": content_length})
+ mock.iter_content.return_value = content_chunks
+
+ return mock
+
+
+@pytest.mark.parametrize("show_progress", [True, False])
+@pytest.mark.parametrize(
+ "content_length, content_chunks, label",
+ [
+ [
+ "5",
+ [bytes(range(5))],
+ "Downloading artifact",
+ ],
+ [
+ "10",
+ [bytes(range(5)), bytes(range(5))],
+ None,
+ ],
+ [
+ None,
+ [bytes(range(5))],
+ "Downlading no size",
+ ],
+ [
+ "abc",
+ [bytes(range(5))],
+ "Downloading wrong size",
+ ],
+ ],
+)
+def test_download(
+ show_progress: bool,
+ tmp_path: Path,
+ monkeypatch: pytest.MonkeyPatch,
+ content_length: Optional[str],
+ content_chunks: Iterable[bytes],
+ label: Optional[str],
+) -> None:
+ """Test function download."""
+ monkeypatch.setattr(
+ "mlia.utils.download.requests.get",
+ MagicMock(return_value=response_mock(content_length, content_chunks)),
+ )
+
+ dest = tmp_path / "sample.bin"
+ download("some_url", dest, show_progress=show_progress, label=label)
+
+ assert dest.is_file()
+ assert dest.read_bytes() == bytes(
+ byte for chunk in content_chunks for byte in chunk
+ )
+
+
+@pytest.mark.parametrize(
+ "content_length, content_chunks, sha256_hash, expected_error",
+ [
+ [
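+ # digest matches the mocked content, so no error is expected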
+ "10",
+ [bytes(range(10))],
+ "1f825aa2f0020ef7cf91dfa30da4668d791c5d4824fc8e41354b89ec05795ab3",
+ does_not_raise(),
+ ],
+ [
+ "10",
+ [bytes(range(10))],
+ "bad_hash",
+ pytest.raises(ValueError, match="Digests do not match"),
+ ],
+ ],
+)
+def test_download_artifact_download_to(
+ monkeypatch: pytest.MonkeyPatch,
+ content_length: Optional[str],
+ content_chunks: Iterable[bytes],
+ sha256_hash: str,
+ expected_error: Any,
+ tmp_path: Path,
+) -> None:
+ """Test artifact downloading."""
+ monkeypatch.setattr(
+ "mlia.utils.download.requests.get",
+ MagicMock(return_value=response_mock(content_length, content_chunks)),
+ )
+
+ with expected_error:
+ artifact = DownloadArtifact(
+ "test_artifact",
+ "some_url",
+ "artifact_filename",
+ "1.0",
+ sha256_hash,
+ )
+
+ dest = artifact.download_to(tmp_path)
+ assert isinstance(dest, Path)
+ assert dest.name == "artifact_filename"
+
+
+def test_download_artifact_unable_to_overwrite(
+ monkeypatch: pytest.MonkeyPatch, tmp_path: Path
+) -> None:
+ """Test that download process cannot overwrite file."""
+ monkeypatch.setattr(
+ "mlia.utils.download.requests.get",
+ MagicMock(return_value=response_mock("10", [bytes(range(10))])),
+ )
+
+ artifact = DownloadArtifact(
+ "test_artifact",
+ "some_url",
+ "artifact_filename",
+ "1.0",
+ "sha256_hash",
+ )
+
+ existing_file = tmp_path / "artifact_filename"
+ existing_file.touch()
+
+ with pytest.raises(ValueError, match=f"{existing_file} already exists"):
+ artifact.download_to(tmp_path)
diff --git a/tests/mlia/test_utils_filesystem.py b/tests/mlia/test_utils_filesystem.py
new file mode 100644
index 0000000..4d8d955
--- /dev/null
+++ b/tests/mlia/test_utils_filesystem.py
@@ -0,0 +1,166 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the filesystem module."""
+import contextlib
+import json
+from pathlib import Path
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.utils.filesystem import all_files_exist
+from mlia.utils.filesystem import all_paths_valid
+from mlia.utils.filesystem import copy_all
+from mlia.utils.filesystem import get_mlia_resources
+from mlia.utils.filesystem import get_profile
+from mlia.utils.filesystem import get_profiles_data
+from mlia.utils.filesystem import get_profiles_file
+from mlia.utils.filesystem import get_supported_profile_names
+from mlia.utils.filesystem import get_vela_config
+from mlia.utils.filesystem import sha256
+from mlia.utils.filesystem import temp_directory
+from mlia.utils.filesystem import temp_file
+
+
+def test_get_mlia_resources() -> None:
+ """Test resources getter."""
+ assert get_mlia_resources().is_dir()
+
+
+def test_get_vela_config() -> None:
+ """Test Vela config files getter."""
+ assert get_vela_config().is_file()
+ assert get_vela_config().name == "vela.ini"
+
+
+def test_profiles_file() -> None:
+ """Test profiles file getter."""
+ assert get_profiles_file().is_file()
+ assert get_profiles_file().name == "profiles.json"
+
+
+def test_profiles_data() -> None:
+ """Test profiles data getter."""
+ assert list(get_profiles_data().keys()) == [
+ "ethos-u55-256",
+ "ethos-u55-128",
+ "ethos-u65-512",
+ ]
+
+
+def test_profiles_data_wrong_format(
+ monkeypatch: pytest.MonkeyPatch, tmp_path: Path
+) -> None:
+ """Test if profile data has wrong format."""
+ wrong_profile_data = tmp_path / "bad.json"
+ with open(wrong_profile_data, "w", encoding="utf-8") as file:
+ json.dump([], file)
+
+ monkeypatch.setattr(
+ "mlia.utils.filesystem.get_profiles_file",
+ MagicMock(return_value=wrong_profile_data),
+ )
+
+ with pytest.raises(Exception, match="Profiles data format is not valid"):
+ get_profiles_data()
+
+
+def test_get_supported_profile_names() -> None:
+ """Test profile names getter."""
+ assert list(get_supported_profile_names()) == [
+ "ethos-u55-256",
+ "ethos-u55-128",
+ "ethos-u65-512",
+ ]
+
+
+def test_get_profile() -> None:
+ """Test getting profile data."""
+ assert get_profile("ethos-u55-256") == {
+ "target": "ethos-u55",
+ "mac": 256,
+ "system_config": "Ethos_U55_High_End_Embedded",
+ "memory_mode": "Shared_Sram",
+ }
+
+ with pytest.raises(Exception, match="Unable to find target profile unknown"):
+ get_profile("unknown")
+
+
+@pytest.mark.parametrize("raise_exception", [True, False])
+def test_temp_file(raise_exception: bool) -> None:
+ """Test temp_file context manager."""
+ with contextlib.suppress(Exception):
+ with temp_file() as tmp_path:
+ assert tmp_path.is_file()
+
+ if raise_exception:
+ raise Exception("Error!")
+
+ assert not tmp_path.exists()
+
+
+def test_sha256(tmp_path: Path) -> None:
+ """Test getting sha256 hash."""
+ sample = tmp_path / "sample.txt"
+
+ with open(sample, "w", encoding="utf-8") as file:
+ file.write("123")
+
+ expected_hash = "a665a45920422f9d417e4867efdc4fb8a04a1f3fff1fa07e998e86f7f7a27ae3"
+ assert sha256(sample) == expected_hash
+
+
+def test_temp_dir_context_manager() -> None:
+ """Test context manager for temporary directories."""
+ with temp_directory() as tmpdir:
+ assert isinstance(tmpdir, Path)
+ assert tmpdir.is_dir()
+
+ assert not tmpdir.exists()
+
+
+def test_all_files_exist(tmp_path: Path) -> None:
+ """Test function all_files_exist."""
+ sample1 = tmp_path / "sample1.txt"
+ sample1.touch()
+
+ sample2 = tmp_path / "sample2.txt"
+ sample2.touch()
+
+ sample3 = tmp_path / "sample3.txt"
+
+ assert all_files_exist([sample1, sample2]) is True
+ assert all_files_exist([sample1, sample2, sample3]) is False
+
+
+def test_all_paths_valid(tmp_path: Path) -> None:
+ """Test function all_paths_valid."""
+ sample = tmp_path / "sample.txt"
+ sample.touch()
+
+ sample_dir = tmp_path / "sample_dir"
+ sample_dir.mkdir()
+
+ unknown = tmp_path / "unknown.txt"
+
+ assert all_paths_valid([sample, sample_dir]) is True
+ assert all_paths_valid([sample, sample_dir, unknown]) is False
+
+
+def test_copy_all(tmp_path: Path) -> None:
+ """Test function copy_all."""
+ sample = tmp_path / "sample1.txt"
+ sample.touch()
+
+ sample_dir = tmp_path / "sample_dir"
+ sample_dir.mkdir()
+
+ sample_nested_file = sample_dir / "sample_nested.txt"
+ sample_nested_file.touch()
+
+ dest_dir = tmp_path / "dest"
+ copy_all(sample, sample_dir, dest=dest_dir)
+
+ assert (dest_dir / sample.name).is_file()
+ assert (dest_dir / sample_nested_file.name).is_file()
diff --git a/tests/mlia/test_utils_logging.py b/tests/mlia/test_utils_logging.py
new file mode 100644
index 0000000..75ebceb
--- /dev/null
+++ b/tests/mlia/test_utils_logging.py
@@ -0,0 +1,63 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the logging utility functions."""
+import logging
+import sys
+from contextlib import ExitStack as does_not_raise
+from pathlib import Path
+from typing import Any
+from typing import Optional
+
+import pytest
+
+from mlia.cli.logging import create_log_handler
+
+
+@pytest.mark.parametrize(
+ "file_path, stream, log_filter, delay, expected_error, expected_class",
+ [
+ (
+ "test.log",
+ None,
+ None,
+ True,
+ does_not_raise(),
+ logging.FileHandler,
+ ),
+ (
+ None,
+ sys.stdout,
+ None,
+ None,
+ does_not_raise(),
+ logging.StreamHandler,
+ ),
+ (
+ None,
+ None,
+ None,
+ None,
+ pytest.raises(Exception, match="Unable to create logging handler"),
+ None,
+ ),
+ ],
+)
+def test_create_log_handler(
+ file_path: Optional[Path],
+ stream: Optional[Any],
+ log_filter: Optional[logging.Filter],
+ delay: bool,
+ expected_error: Any,
+ expected_class: type,
+) -> None:
+ """Test function test_create_log_handler."""
+ with expected_error:
+ handler = create_log_handler(
+ file_path=file_path,
+ stream=stream,
+ log_level=logging.INFO,
+ log_format="%(name)s - %(message)s",
+ log_filter=log_filter,
+ delay=delay,
+ )
+ assert isinstance(handler, expected_class)
diff --git a/tests/mlia/test_utils_misc.py b/tests/mlia/test_utils_misc.py
new file mode 100644
index 0000000..011d09e
--- /dev/null
+++ b/tests/mlia/test_utils_misc.py
@@ -0,0 +1,25 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for misc util functions."""
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.utils.misc import yes
+
+
+@pytest.mark.parametrize(
+ "response, expected_result",
+ [
+ ["Y", True],
+ ["y", True],
+ ["N", False],
+ ["n", False],
+ ],
+)
+def test_yes(
+ monkeypatch: pytest.MonkeyPatch, expected_result: bool, response: str
+) -> None:
+ """Test yes function."""
+ monkeypatch.setattr("builtins.input", MagicMock(return_value=response))
+ assert yes("some_prompt") == expected_result
diff --git a/tests/mlia/test_utils_proc.py b/tests/mlia/test_utils_proc.py
new file mode 100644
index 0000000..8316ca5
--- /dev/null
+++ b/tests/mlia/test_utils_proc.py
@@ -0,0 +1,149 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the module utils/proc."""
+import signal
+import subprocess
+import time
+from pathlib import Path
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.utils.proc import CommandExecutor
+from mlia.utils.proc import working_directory
+
+
+class TestCommandExecutor:
+ """Tests for class CommandExecutor."""
+
+ @staticmethod
+ def test_execute() -> None:
+ """Test command execution."""
+ executor = CommandExecutor()
+
+ retcode, stdout, stderr = executor.execute(["echo", "hello world!"])
+ assert retcode == 0
+ assert stdout.decode().strip() == "hello world!"
+ assert stderr.decode() == ""
+
+ @staticmethod
+ def test_submit() -> None:
+ """Test command submittion."""
+ executor = CommandExecutor()
+
+ running_command = executor.submit(["sleep", "10"])
+ assert running_command.is_alive() is True
+ assert running_command.exit_code() is None
+
+ running_command.kill()
+ for _ in range(3):
+ time.sleep(0.5)
+ if not running_command.is_alive():
+ break
+
+ assert running_command.is_alive() is False
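+ # a negative exit code is the terminating signal number: -9 == SIGKILL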
+ assert running_command.exit_code() == -9
+
+ with pytest.raises(subprocess.CalledProcessError):
+ executor.execute(["sleep", "-1"])
+
+ @staticmethod
+ @pytest.mark.parametrize("wait", [True, False])
+ def test_stop(wait: bool) -> None:
+ """Test command termination."""
+ executor = CommandExecutor()
+
+ running_command = executor.submit(["sleep", "10"])
+ running_command.stop(wait=wait)
+
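+ # only assert liveness when stop() waited for the process to finish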
+ if wait:
+ assert running_command.is_alive() is False
+
+ @staticmethod
+ def test_unable_to_stop(monkeypatch: pytest.MonkeyPatch) -> None:
+ """Test when command could not be stopped."""
+ running_command_mock = MagicMock()
+ running_command_mock.poll.return_value = None
+
+ monkeypatch.setattr(
+ "mlia.utils.proc.subprocess.Popen",
+ MagicMock(return_value=running_command_mock),
+ )
+
+ with pytest.raises(Exception, match="Unable to stop running command"):
+ executor = CommandExecutor()
+ running_command = executor.submit(["sleep", "10"])
+
+ running_command.stop(num_of_attempts=1, interval=0.1)
+
+ running_command_mock.send_signal.assert_called_once_with(signal.SIGINT)
+
+ @staticmethod
+ def test_stop_after_several_attempts(monkeypatch: pytest.MonkeyPatch) -> None:
+ """Test when command could be stopped after several attempts."""
+ running_command_mock = MagicMock()
+ running_command_mock.poll.side_effect = [None, 0]
+
+ monkeypatch.setattr(
+ "mlia.utils.proc.subprocess.Popen",
+ MagicMock(return_value=running_command_mock),
+ )
+
+ executor = CommandExecutor()
+ running_command = executor.submit(["sleep", "10"])
+
+ running_command.stop(num_of_attempts=1, interval=0.1)
+ running_command_mock.send_signal.assert_called_once_with(signal.SIGINT)
+
+ @staticmethod
+ def test_send_signal() -> None:
+ """Test sending signal."""
+ executor = CommandExecutor()
+ running_command = executor.submit(["sleep", "10"])
+ running_command.send_signal(signal.SIGINT)
+
+ # wait a bit for a signal processing
+ time.sleep(1)
+
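+ # a negative exit code is the terminating signal number: -2 == SIGINT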
+ assert running_command.is_alive() is False
+ assert running_command.exit_code() == -2
+
+ @staticmethod
+ @pytest.mark.parametrize(
+ "redirect_output, expected_output", [[True, "hello\n"], [False, ""]]
+ )
+ def test_wait(
+ capsys: pytest.CaptureFixture, redirect_output: bool, expected_output: str
+ ) -> None:
+ """Test wait completion functionality."""
+ executor = CommandExecutor()
+
+ running_command = executor.submit(["echo", "hello"])
+ running_command.wait(redirect_output=redirect_output)
+
+ out, _ = capsys.readouterr()
+ assert out == expected_output
+
+
+@pytest.mark.parametrize(
+ "should_exist, create_dir",
+ [
+ [True, False],
+ [False, True],
+ ],
+)
+def test_working_directory_context_manager(
+ tmp_path: Path, should_exist: bool, create_dir: bool
+) -> None:
+ """Test working_directory context manager."""
+ prev_wd = Path.cwd()
+
+ working_dir = tmp_path / "work_dir"
+ if should_exist:
+ working_dir.mkdir()
+
+ with working_directory(working_dir, create_dir=create_dir) as current_working_dir:
+ assert current_working_dir.is_dir()
+ assert Path.cwd() == current_working_dir
+
+ assert Path.cwd() == prev_wd
diff --git a/tests/mlia/test_utils_types.py b/tests/mlia/test_utils_types.py
new file mode 100644
index 0000000..4909efe
--- /dev/null
+++ b/tests/mlia/test_utils_types.py
@@ -0,0 +1,77 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the types related utility functions."""
+from typing import Any
+from typing import Iterable
+from typing import Optional
+
+import pytest
+
+from mlia.utils.types import is_list_of
+from mlia.utils.types import is_number
+from mlia.utils.types import only_one_selected
+from mlia.utils.types import parse_int
+
+
+@pytest.mark.parametrize(
+ "value, expected_result",
+ [
+ ["", False],
+ ["abc", False],
+ ["123", True],
+ ["123.1", True],
+ ["-123", True],
+ ["-123.1", True],
+ ["0", True],
+ ["1.e10", True],
+ ],
+)
+def test_is_number(value: str, expected_result: bool) -> None:
+ """Test function is_number."""
+ assert is_number(value) == expected_result
+
+
+@pytest.mark.parametrize(
+ "data, cls, elem_num, expected_result",
+ [
+ [(1, 2), int, 2, True],
+ [[1, 2], int, 2, True],
+ [[1, 2], int, 3, False],
+ [["1", "2", "3"], str, None, True],
+ [["1", "2", "3"], int, None, False],
+ ],
+)
+def test_is_list(
+ data: Any, cls: type, elem_num: Optional[int], expected_result: bool
+) -> None:
+ """Test function is_list."""
+ assert is_list_of(data, cls, elem_num) == expected_result
+
+
+@pytest.mark.parametrize(
+ "options, expected_result",
+ [
+ [[True], True],
+ [[False], False],
+ [[True, True, False, False], False],
+ ],
+)
+def test_only_one_selected(options: Iterable[bool], expected_result: bool) -> None:
+ """Test function only_one_selected."""
+ assert only_one_selected(*options) == expected_result
+
+
+@pytest.mark.parametrize(
+ "value, default, expected_int",
+ [
+ ["1", None, 1],
+ ["abc", 123, 123],
+ [None, None, None],
+ [None, 11, 11],
+ ],
+)
+def test_parse_int(
+ value: Any, default: Optional[int], expected_int: Optional[int]
+) -> None:
+ """Test function parse_int."""
+ assert parse_int(value, default) == expected_int
diff --git a/tests/mlia/utils/__init__.py b/tests/mlia/utils/__init__.py
new file mode 100644
index 0000000..27166ef
--- /dev/null
+++ b/tests/mlia/utils/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Test utils module."""
diff --git a/tests/mlia/utils/common.py b/tests/mlia/utils/common.py
new file mode 100644
index 0000000..4313cde
--- /dev/null
+++ b/tests/mlia/utils/common.py
@@ -0,0 +1,32 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Common test utils module."""
+from typing import Tuple
+
+import numpy as np
+import tensorflow as tf
+
+
+def get_dataset() -> Tuple[np.ndarray, np.ndarray]:
+ """Return sample dataset."""
+ mnist = tf.keras.datasets.mnist
+ (x_train, y_train), _ = mnist.load_data()
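+ # scale pixel values from [0, 255] to [0, 1]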
+ x_train = x_train / 255.0
+
+ # Use only a small subset of the 60000 examples to keep the unit tests fast.
+ x_train = x_train[0:1]
+ y_train = y_train[0:1]
+
+ return x_train, y_train
+
+
+def train_model(model: tf.keras.Model) -> None:
+ """Train model using sample dataset."""
+ num_epochs = 1
+
+ loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+ model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"])
+
+ x_train, y_train = get_dataset()
+
+ model.fit(x_train, y_train, epochs=num_epochs)
diff --git a/tests/mlia/utils/logging.py b/tests/mlia/utils/logging.py
new file mode 100644
index 0000000..d223fb2
--- /dev/null
+++ b/tests/mlia/utils/logging.py
@@ -0,0 +1,13 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Utils for logging."""
+import logging
+
+
+def clear_loggers() -> None:
+ """Close the log handlers."""
+ for _, logger in logging.Logger.manager.loggerDict.items():
+ if not isinstance(logger, logging.PlaceHolder):
+ for handler in logger.handlers:
+ handler.close()
+ logger.removeHandler(handler)