author     Raul Farkas <raul.farkas@arm.com>    2022-11-29 13:29:04 +0000
committer  Raul Farkas <raul.farkas@arm.com>    2023-01-10 10:46:07 +0000
commit     5800fc990ed1e36ce7d06670f911fbb12a0ec771 (patch)
tree       294605295cd2624ba63e6ad3df335a2a4b2700ab /tests
parent     dcd0bd31985c27e1d07333351b26cf8ad12ad1fd (diff)
MLIA-650 Implement new CLI changes
Breaking change in the CLI and API: the sub-commands "optimization", "operators", and "performance" were replaced by "check", which incorporates compatibility and performance checks, and "optimize", which is used for optimization. The "get_advice" API was adapted to these CLI changes.

API changes:

* Remove the previous advice category "all" that would perform all three operations (when possible). Replace it with the ability to pass a set of advice categories.
* Update the api.get_advice method docstring to reflect the new changes.
* Set the default advice category to COMPATIBILITY.
* Update core.common.AdviceCategory by changing the "OPERATORS" advice category to "COMPATIBILITY" and removing the "ALL" enum type. Update all methods that previously used "OPERATORS" to use "COMPATIBILITY".
* Update core.context.ExecutionContext to have "COMPATIBILITY" as the default advice_category instead of "ALL".
* Remove api.generate_supported_operators_report and all related functions from cli.commands, cli.helpers, cli.main, cli.options, core.helpers.
* Update tests to reflect the new API changes.

CLI changes:

* Update README.md to contain information on the new CLI.
* Remove the ability to generate a supported operators report from the MLIA CLI.
* Replace `mlia ops` and `mlia perf` with the new `mlia check` command that can be used to perform both operations.
* Replace `mlia opt` with the new `mlia optimize` command.
* Replace the `--evaluate-on` flag with the `--backend` flag.
* Replace the `--verbose` flag with the `--debug` flag (no behaviour change).
* Remove the ability for the user to select the MLIA working directory. Create and use a temporary directory under /tmp instead.
* Change the behaviour of the `--output` flag so that it no longer formats the content automatically based on the file extension; it now simply redirects the output to a file.
* Add the `--json` flag to specify that the output format should be JSON.
* Add command validators that are used to validate inter-dependent flags (e.g. backend validation based on target_profile).
* Add support for selecting built-in backends for both the `check` and `optimize` commands.
* Add new unit tests and update old ones to test the new CLI changes.
* Update RELEASES.md.
* Update the copyright notice.

Change-Id: Ia6340797c7bee3acbbd26601950e5a16ad5602db
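For context, a minimal sketch of the adapted API call, mirroring the updated tests in tests/test_api.py below; the model path and the chosen advice category are illustrative only, and a real run requires an installed mlia with the relevant backends:

    from pathlib import Path

    from mlia.api import get_advice
    from mlia.core.context import ExecutionContext

    # Hypothetical model path, used purely for illustration.
    model = Path("sample_model.h5")

    # Advice categories are now passed as a set of strings rather than a
    # single value such as "all" or "operators"; several categories can be
    # combined, e.g. {"compatibility", "performance"}.
    get_advice(
        "ethos-u55-256",
        model,
        {"compatibility"},
        context=ExecutionContext(),
    )

The CLI follows the same split: `mlia check` for compatibility/performance checks and `mlia optimize` for optimization.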
Diffstat (limited to 'tests')
-rw-r--r--  tests/test_api.py                               |  98
-rw-r--r--  tests/test_backend_config.py                    |  12
-rw-r--r--  tests/test_backend_registry.py                  |   8
-rw-r--r--  tests/test_cli_command_validators.py            | 167
-rw-r--r--  tests/test_cli_commands.py                      |  97
-rw-r--r--  tests/test_cli_config.py                        |   8
-rw-r--r--  tests/test_cli_helpers.py                       |  62
-rw-r--r--  tests/test_cli_main.py                          | 228
-rw-r--r--  tests/test_cli_options.py                       | 179
-rw-r--r--  tests/test_core_advice_generation.py            |  10
-rw-r--r--  tests/test_core_context.py                      |  46
-rw-r--r--  tests/test_core_helpers.py                      |   3
-rw-r--r--  tests/test_core_mixins.py                       |   6
-rw-r--r--  tests/test_core_reporting.py                    |  22
-rw-r--r--  tests/test_target_config.py                     |   6
-rw-r--r--  tests/test_target_cortex_a_advice_generation.py |  18
-rw-r--r--  tests/test_target_ethos_u_advice_generation.py  |  70
-rw-r--r--  tests/test_target_registry.py                   |  12
-rw-r--r--  tests/test_target_tosa_advice_generation.py     |   8
19 files changed, 584 insertions, 476 deletions
diff --git a/tests/test_api.py b/tests/test_api.py
index fbc558b..0bbc3ae 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -1,15 +1,13 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the API functions."""
from __future__ import annotations
from pathlib import Path
from unittest.mock import MagicMock
-from unittest.mock import patch
import pytest
-from mlia.api import generate_supported_operators_report
from mlia.api import get_advice
from mlia.api import get_advisor
from mlia.core.common import AdviceCategory
@@ -22,63 +20,68 @@ from mlia.target.tosa.advisor import TOSAInferenceAdvisor
def test_get_advice_no_target_provided(test_keras_model: Path) -> None:
"""Test getting advice when no target provided."""
with pytest.raises(Exception, match="Target profile is not provided"):
- get_advice(None, test_keras_model, "all") # type: ignore
+ get_advice(None, test_keras_model, {"compatibility"}) # type: ignore
def test_get_advice_wrong_category(test_keras_model: Path) -> None:
"""Test getting advice when wrong advice category provided."""
with pytest.raises(Exception, match="Invalid advice category unknown"):
- get_advice("ethos-u55-256", test_keras_model, "unknown") # type: ignore
+ get_advice("ethos-u55-256", test_keras_model, {"unknown"})
@pytest.mark.parametrize(
"category, context, expected_category",
[
[
- "all",
+ {"compatibility", "optimization"},
None,
- AdviceCategory.ALL,
+ {AdviceCategory.COMPATIBILITY, AdviceCategory.OPTIMIZATION},
],
[
- "optimization",
+ {"optimization"},
None,
- AdviceCategory.OPTIMIZATION,
+ {AdviceCategory.OPTIMIZATION},
],
[
- "operators",
+ {"compatibility"},
None,
- AdviceCategory.OPERATORS,
+ {AdviceCategory.COMPATIBILITY},
],
[
- "performance",
+ {"performance"},
None,
- AdviceCategory.PERFORMANCE,
+ {AdviceCategory.PERFORMANCE},
],
[
- "all",
+ {"compatibility", "optimization"},
ExecutionContext(),
- AdviceCategory.ALL,
+ {AdviceCategory.COMPATIBILITY, AdviceCategory.OPTIMIZATION},
],
[
- "all",
- ExecutionContext(advice_category=AdviceCategory.PERFORMANCE),
- AdviceCategory.ALL,
+ {"compatibility", "optimization"},
+ ExecutionContext(
+ advice_category={
+ AdviceCategory.COMPATIBILITY,
+ AdviceCategory.OPTIMIZATION,
+ }
+ ),
+ {AdviceCategory.COMPATIBILITY, AdviceCategory.OPTIMIZATION},
],
[
- "all",
+ {"compatibility", "optimization"},
ExecutionContext(config_parameters={"param": "value"}),
- AdviceCategory.ALL,
+ {AdviceCategory.COMPATIBILITY, AdviceCategory.OPTIMIZATION},
],
[
- "all",
+ {"compatibility", "optimization"},
ExecutionContext(event_handlers=[MagicMock()]),
- AdviceCategory.ALL,
+ {AdviceCategory.COMPATIBILITY, AdviceCategory.OPTIMIZATION},
],
],
)
def test_get_advice(
monkeypatch: pytest.MonkeyPatch,
- category: str,
+ category: set[str],
context: ExecutionContext,
expected_category: AdviceCategory,
test_keras_model: Path,
@@ -90,7 +93,7 @@ def test_get_advice(
get_advice(
"ethos-u55-256",
test_keras_model,
- category, # type: ignore
+ category,
context=context,
)
@@ -111,50 +114,3 @@ def test_get_advisor(
tosa_advisor = get_advisor(ExecutionContext(), "tosa", str(test_keras_model))
assert isinstance(tosa_advisor, TOSAInferenceAdvisor)
-
-
-@pytest.mark.parametrize(
- ["target_profile", "required_calls", "exception_msg"],
- [
- [
- "ethos-u55-128",
- "mlia.target.ethos_u.operators.generate_supported_operators_report",
- None,
- ],
- [
- "ethos-u65-256",
- "mlia.target.ethos_u.operators.generate_supported_operators_report",
- None,
- ],
- [
- "tosa",
- None,
- "Generating a supported operators report is not "
- "currently supported with TOSA target profile.",
- ],
- [
- "cortex-a",
- None,
- "Generating a supported operators report is not "
- "currently supported with Cortex-A target profile.",
- ],
- [
- "Unknown",
- None,
- "Unable to find target profile Unknown",
- ],
- ],
-)
-def test_supported_ops_report_generator(
- target_profile: str, required_calls: str | None, exception_msg: str | None
-) -> None:
- """Test supported operators report generator with different target profiles."""
- if exception_msg:
- with pytest.raises(Exception) as exc:
- generate_supported_operators_report(target_profile)
- assert str(exc.value) == exception_msg
-
- if required_calls:
- with patch(required_calls) as mock_method:
- generate_supported_operators_report(target_profile)
- mock_method.assert_called_once()
diff --git a/tests/test_backend_config.py b/tests/test_backend_config.py
index bd50945..700534f 100644
--- a/tests/test_backend_config.py
+++ b/tests/test_backend_config.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the backend config module."""
from mlia.backend.config import BackendConfiguration
@@ -20,14 +20,14 @@ def test_system() -> None:
def test_backend_config() -> None:
"""Test the class 'BackendConfiguration'."""
cfg = BackendConfiguration(
- [AdviceCategory.OPERATORS], [System.CURRENT], BackendType.CUSTOM
+ [AdviceCategory.COMPATIBILITY], [System.CURRENT], BackendType.CUSTOM
)
- assert cfg.supported_advice == [AdviceCategory.OPERATORS]
+ assert cfg.supported_advice == [AdviceCategory.COMPATIBILITY]
assert cfg.supported_systems == [System.CURRENT]
assert cfg.type == BackendType.CUSTOM
assert str(cfg)
assert cfg.is_supported()
- assert cfg.is_supported(advice=AdviceCategory.OPERATORS)
+ assert cfg.is_supported(advice=AdviceCategory.COMPATIBILITY)
assert not cfg.is_supported(advice=AdviceCategory.PERFORMANCE)
assert cfg.is_supported(check_system=True)
assert cfg.is_supported(check_system=False)
@@ -37,6 +37,6 @@ def test_backend_config() -> None:
cfg.supported_systems = [UNSUPPORTED_SYSTEM]
assert not cfg.is_supported(check_system=True)
assert cfg.is_supported(check_system=False)
- assert not cfg.is_supported(advice=AdviceCategory.OPERATORS, check_system=True)
- assert cfg.is_supported(advice=AdviceCategory.OPERATORS, check_system=False)
+ assert not cfg.is_supported(advice=AdviceCategory.COMPATIBILITY, check_system=True)
+ assert cfg.is_supported(advice=AdviceCategory.COMPATIBILITY, check_system=False)
assert not cfg.is_supported(advice=AdviceCategory.PERFORMANCE, check_system=False)
diff --git a/tests/test_backend_registry.py b/tests/test_backend_registry.py
index 31a20a0..703e699 100644
--- a/tests/test_backend_registry.py
+++ b/tests/test_backend_registry.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the backend registry module."""
from __future__ import annotations
@@ -18,7 +18,7 @@ from mlia.core.common import AdviceCategory
(
(
"ArmNNTFLiteDelegate",
- [AdviceCategory.OPERATORS],
+ [AdviceCategory.COMPATIBILITY],
None,
BackendType.BUILTIN,
),
@@ -36,14 +36,14 @@ from mlia.core.common import AdviceCategory
),
(
"TOSA-Checker",
- [AdviceCategory.OPERATORS],
+ [AdviceCategory.COMPATIBILITY],
[System.LINUX_AMD64],
BackendType.WHEEL,
),
(
"Vela",
[
- AdviceCategory.OPERATORS,
+ AdviceCategory.COMPATIBILITY,
AdviceCategory.PERFORMANCE,
AdviceCategory.OPTIMIZATION,
],
diff --git a/tests/test_cli_command_validators.py b/tests/test_cli_command_validators.py
new file mode 100644
index 0000000..13514a5
--- /dev/null
+++ b/tests/test_cli_command_validators.py
@@ -0,0 +1,167 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for cli.command_validators module."""
+from __future__ import annotations
+
+import argparse
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.cli.command_validators import validate_backend
+from mlia.cli.command_validators import validate_check_target_profile
+
+
+@pytest.mark.parametrize(
+ "target_profile, category, expected_warnings, sys_exits",
+ [
+ ["ethos-u55-256", {"compatibility", "performance"}, [], False],
+ ["ethos-u55-256", {"compatibility"}, [], False],
+ ["ethos-u55-256", {"performance"}, [], False],
+ [
+ "tosa",
+ {"compatibility", "performance"},
+ [
+ (
+ "\nWARNING: Performance checks skipped as they cannot be "
+ "performed with target profile tosa."
+ )
+ ],
+ False,
+ ],
+ [
+ "tosa",
+ {"performance"},
+ [
+ (
+ "\nWARNING: Performance checks skipped as they cannot be "
+ "performed with target profile tosa. No operation was performed."
+ )
+ ],
+ True,
+ ],
+ ["tosa", "compatibility", [], False],
+ [
+ "cortex-a",
+ {"performance"},
+ [
+ (
+ "\nWARNING: Performance checks skipped as they cannot be "
+ "performed with target profile cortex-a. "
+ "No operation was performed."
+ )
+ ],
+ True,
+ ],
+ [
+ "cortex-a",
+ {"compatibility", "performance"},
+ [
+ (
+ "\nWARNING: Performance checks skipped as they cannot be "
+ "performed with target profile cortex-a."
+ )
+ ],
+ False,
+ ],
+ ["cortex-a", "compatibility", [], False],
+ ],
+)
+def test_validate_check_target_profile(
+ caplog: pytest.LogCaptureFixture,
+ target_profile: str,
+ category: set[str],
+ expected_warnings: list[str],
+ sys_exits: bool,
+) -> None:
+ """Test outcomes of category dependent target profile validation."""
+ # Capture if program terminates
+ if sys_exits:
+ with pytest.raises(SystemExit) as sys_ex:
+ validate_check_target_profile(target_profile, category)
+ assert sys_ex.value.code == 0
+ return
+
+ validate_check_target_profile(target_profile, category)
+
+ log_records = caplog.records
+ # Get all log records with level 30 (warning level)
+ warning_messages = {x.message for x in log_records if x.levelno == 30}
+ # Ensure the warnings coincide with the expected ones
+ assert warning_messages == set(expected_warnings)
+
+
+@pytest.mark.parametrize(
+ "input_target_profile, input_backends, throws_exception,"
+ "exception_message, output_backends",
+ [
+ [
+ "tosa",
+ ["Vela"],
+ True,
+ "Vela backend not supported with target-profile tosa.",
+ None,
+ ],
+ [
+ "tosa",
+ ["Corstone-300, Vela"],
+ True,
+ "Corstone-300, Vela backend not supported with target-profile tosa.",
+ None,
+ ],
+ [
+ "cortex-a",
+ ["Corstone-310", "tosa-checker"],
+ True,
+ "Corstone-310, tosa-checker backend not supported "
+ "with target-profile cortex-a.",
+ None,
+ ],
+ [
+ "ethos-u55-256",
+ ["tosa-checker", "Corstone-310"],
+ True,
+ "tosa-checker backend not supported with target-profile ethos-u55-256.",
+ None,
+ ],
+ ["tosa", None, False, None, ["tosa-checker"]],
+ ["cortex-a", None, False, None, ["armnn-tflitedelegate"]],
+ ["tosa", ["tosa-checker"], False, None, ["tosa-checker"]],
+ ["cortex-a", ["armnn-tflitedelegate"], False, None, ["armnn-tflitedelegate"]],
+ [
+ "ethos-u55-256",
+ ["Vela", "Corstone-300"],
+ False,
+ None,
+ ["Vela", "Corstone-300"],
+ ],
+ [
+ "ethos-u55-256",
+ None,
+ False,
+ None,
+ ["Vela", "Corstone-300"],
+ ],
+ ],
+)
+def test_validate_backend(
+ monkeypatch: pytest.MonkeyPatch,
+ input_target_profile: str,
+ input_backends: list[str] | None,
+ throws_exception: bool,
+ exception_message: str,
+ output_backends: list[str] | None,
+) -> None:
+ """Test backend validation with target-profiles and backends."""
+ monkeypatch.setattr(
+ "mlia.cli.config.get_available_backends",
+ MagicMock(return_value=["Vela", "Corstone-300"]),
+ )
+
+ if throws_exception:
+ with pytest.raises(argparse.ArgumentError) as err:
+ validate_backend(input_target_profile, input_backends)
+ assert str(err.value.message) == exception_message
+ return
+
+ assert validate_backend(input_target_profile, input_backends) == output_backends
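As a quick illustration of the behaviour exercised above, a sketch of calling validate_backend directly; the expected values simply restate the parametrised cases in this test and assume the same installed-backend set as the mocked get_available_backends (Vela, Corstone-300):

    import argparse

    from mlia.cli.command_validators import validate_backend

    # With no backends given, a built-in default is resolved per target profile.
    assert validate_backend("tosa", None) == ["tosa-checker"]
    assert validate_backend("cortex-a", None) == ["armnn-tflitedelegate"]

    # Incompatible combinations are rejected with an argparse.ArgumentError.
    try:
        validate_backend("tosa", ["Vela"])
    except argparse.ArgumentError as err:
        # "Vela backend not supported with target-profile tosa."
        print(err.message)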
diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py
index aed5c42..03ee9d2 100644
--- a/tests/test_cli_commands.py
+++ b/tests/test_cli_commands.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for cli.commands module."""
from __future__ import annotations
@@ -14,9 +14,8 @@ from mlia.backend.manager import DefaultInstallationManager
from mlia.cli.commands import backend_install
from mlia.cli.commands import backend_list
from mlia.cli.commands import backend_uninstall
-from mlia.cli.commands import operators
-from mlia.cli.commands import optimization
-from mlia.cli.commands import performance
+from mlia.cli.commands import check
+from mlia.cli.commands import optimize
from mlia.core.context import ExecutionContext
from mlia.target.ethos_u.config import EthosUConfiguration
from mlia.target.ethos_u.performance import MemoryUsage
@@ -27,7 +26,7 @@ from mlia.target.ethos_u.performance import PerformanceMetrics
def test_operators_expected_parameters(sample_context: ExecutionContext) -> None:
"""Test operators command wrong parameters."""
with pytest.raises(Exception, match="Model is not provided"):
- operators(sample_context, "ethos-u55-256")
+ check(sample_context, "ethos-u55-256")
def test_performance_unknown_target(
@@ -35,93 +34,45 @@ def test_performance_unknown_target(
) -> None:
"""Test that command should fail if unknown target passed."""
with pytest.raises(Exception, match="Unable to find target profile unknown"):
- performance(
- sample_context, model=str(test_tflite_model), target_profile="unknown"
+ check(
+ sample_context,
+ model=str(test_tflite_model),
+ target_profile="unknown",
+ performance=True,
)
@pytest.mark.parametrize(
- "target_profile, optimization_type, optimization_target, expected_error",
+ "target_profile, pruning, clustering, pruning_target, clustering_target",
[
- [
- "ethos-u55-256",
- None,
- "0.5",
- pytest.raises(Exception, match="Optimization type is not provided"),
- ],
- [
- "ethos-u65-512",
- "unknown",
- "16",
- pytest.raises(Exception, match="Unsupported optimization type: unknown"),
- ],
- [
- "ethos-u55-256",
- "pruning",
- None,
- pytest.raises(Exception, match="Optimization target is not provided"),
- ],
- [
- "ethos-u65-512",
- "clustering",
- None,
- pytest.raises(Exception, match="Optimization target is not provided"),
- ],
- [
- "unknown",
- "clustering",
- "16",
- pytest.raises(Exception, match="Unable to find target profile unknown"),
- ],
- ],
-)
-def test_opt_expected_parameters(
- sample_context: ExecutionContext,
- target_profile: str,
- monkeypatch: pytest.MonkeyPatch,
- optimization_type: str,
- optimization_target: str,
- expected_error: Any,
- test_keras_model: Path,
-) -> None:
- """Test that command should fail if no or unknown optimization type provided."""
- mock_performance_estimation(monkeypatch)
-
- with expected_error:
- optimization(
- ctx=sample_context,
- target_profile=target_profile,
- model=str(test_keras_model),
- optimization_type=optimization_type,
- optimization_target=optimization_target,
- )
-
-
-@pytest.mark.parametrize(
- "target_profile, optimization_type, optimization_target",
- [
- ["ethos-u55-256", "pruning", "0.5"],
- ["ethos-u65-512", "clustering", "32"],
- ["ethos-u55-256", "pruning,clustering", "0.5,32"],
+ ["ethos-u55-256", True, False, 0.5, None],
+ ["ethos-u65-512", False, True, 0.5, 32],
+ ["ethos-u55-256", True, True, 0.5, None],
+ ["ethos-u55-256", False, False, 0.5, None],
+ ["ethos-u55-256", False, True, "invalid", 32],
],
)
def test_opt_valid_optimization_target(
target_profile: str,
sample_context: ExecutionContext,
- optimization_type: str,
- optimization_target: str,
+ pruning: bool,
+ clustering: bool,
+ pruning_target: float | None,
+ clustering_target: int | None,
monkeypatch: pytest.MonkeyPatch,
test_keras_model: Path,
) -> None:
"""Test that command should not fail with valid optimization targets."""
mock_performance_estimation(monkeypatch)
- optimization(
+ optimize(
ctx=sample_context,
target_profile=target_profile,
model=str(test_keras_model),
- optimization_type=optimization_type,
- optimization_target=optimization_target,
+ pruning=pruning,
+ clustering=clustering,
+ pruning_target=pruning_target,
+ clustering_target=clustering_target,
)
diff --git a/tests/test_cli_config.py b/tests/test_cli_config.py
index 1a7cb3f..b007052 100644
--- a/tests/test_cli_config.py
+++ b/tests/test_cli_config.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for cli.config module."""
from __future__ import annotations
@@ -7,7 +7,7 @@ from unittest.mock import MagicMock
import pytest
-from mlia.cli.config import get_default_backends
+from mlia.cli.config import get_ethos_u_default_backends
from mlia.cli.config import is_corstone_backend
@@ -29,7 +29,7 @@ from mlia.cli.config import is_corstone_backend
],
],
)
-def test_get_default_backends(
+def test_get_ethos_u_default_backends(
monkeypatch: pytest.MonkeyPatch,
available_backends: list[str],
expected_default_backends: list[str],
@@ -40,7 +40,7 @@ def test_get_default_backends(
MagicMock(return_value=available_backends),
)
- assert get_default_backends() == expected_default_backends
+ assert get_ethos_u_default_backends() == expected_default_backends
def test_is_corstone_backend() -> None:
diff --git a/tests/test_cli_helpers.py b/tests/test_cli_helpers.py
index c8aeebe..8f7e4b0 100644
--- a/tests/test_cli_helpers.py
+++ b/tests/test_cli_helpers.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the helper classes."""
from __future__ import annotations
@@ -28,40 +28,39 @@ class TestCliActionResolver:
{},
[
"Note: you will need a Keras model for that.",
- "For example: mlia optimization --optimization-type "
- "pruning,clustering --optimization-target 0.5,32 "
- "/path/to/keras_model",
- "For more info: mlia optimization --help",
+ "For example: mlia optimize /path/to/keras_model "
+ "--pruning --clustering "
+ "--pruning-target 0.5 --clustering-target 32",
+ "For more info: mlia optimize --help",
],
],
[
{"model": "model.h5"},
{},
[
- "For example: mlia optimization --optimization-type "
- "pruning,clustering --optimization-target 0.5,32 model.h5",
- "For more info: mlia optimization --help",
+ "For example: mlia optimize model.h5 --pruning --clustering "
+ "--pruning-target 0.5 --clustering-target 32",
+ "For more info: mlia optimize --help",
],
],
[
{"model": "model.h5"},
{"opt_settings": [OptimizationSettings("pruning", 0.5, None)]},
[
- "For more info: mlia optimization --help",
+ "For more info: mlia optimize --help",
"Optimization command: "
- "mlia optimization --optimization-type pruning "
- "--optimization-target 0.5 model.h5",
+ "mlia optimize model.h5 --pruning "
+ "--pruning-target 0.5",
],
],
[
{"model": "model.h5", "target_profile": "target_profile"},
{"opt_settings": [OptimizationSettings("pruning", 0.5, None)]},
[
- "For more info: mlia optimization --help",
+ "For more info: mlia optimize --help",
"Optimization command: "
- "mlia optimization --optimization-type pruning "
- "--optimization-target 0.5 "
- "--target-profile target_profile model.h5",
+ "mlia optimize model.h5 --target-profile target_profile "
+ "--pruning --pruning-target 0.5",
],
],
],
@@ -76,20 +75,11 @@ class TestCliActionResolver:
assert resolver.apply_optimizations(**params) == expected_result
@staticmethod
- def test_supported_operators_info() -> None:
- """Test supported operators info."""
- resolver = CLIActionResolver({})
- assert resolver.supported_operators_info() == [
- "For guidance on supported operators, run: mlia operators "
- "--supported-ops-report",
- ]
-
- @staticmethod
def test_operator_compatibility_details() -> None:
"""Test operator compatibility details info."""
resolver = CLIActionResolver({})
assert resolver.operator_compatibility_details() == [
- "For more details, run: mlia operators --help"
+ "For more details, run: mlia check --help"
]
@staticmethod
@@ -97,7 +87,7 @@ class TestCliActionResolver:
"""Test optimization details info."""
resolver = CLIActionResolver({})
assert resolver.optimization_details() == [
- "For more info, see: mlia optimization --help"
+ "For more info, see: mlia optimize --help"
]
@staticmethod
@@ -109,19 +99,12 @@ class TestCliActionResolver:
[],
],
[
- {"model": "model.tflite"},
- [
- "Check the estimated performance by running the "
- "following command: ",
- "mlia performance model.tflite",
- ],
- ],
- [
{"model": "model.tflite", "target_profile": "target_profile"},
[
"Check the estimated performance by running the "
"following command: ",
- "mlia performance --target-profile target_profile model.tflite",
+ "mlia check model.tflite "
+ "--target-profile target_profile --performance",
],
],
],
@@ -142,17 +125,10 @@ class TestCliActionResolver:
[],
],
[
- {"model": "model.tflite"},
- [
- "Try running the following command to verify that:",
- "mlia operators model.tflite",
- ],
- ],
- [
{"model": "model.tflite", "target_profile": "target_profile"},
[
"Try running the following command to verify that:",
- "mlia operators --target-profile target_profile model.tflite",
+ "mlia check model.tflite --target-profile target_profile",
],
],
],
diff --git a/tests/test_cli_main.py b/tests/test_cli_main.py
index 925f1e4..5a9c0c9 100644
--- a/tests/test_cli_main.py
+++ b/tests/test_cli_main.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for main module."""
from __future__ import annotations
@@ -19,7 +19,6 @@ from mlia.backend.errors import BackendUnavailableError
from mlia.cli.main import backend_main
from mlia.cli.main import CommandInfo
from mlia.cli.main import main
-from mlia.core.context import ExecutionContext
from mlia.core.errors import InternalError
from tests.utils.logging import clear_loggers
@@ -62,35 +61,23 @@ def test_command_info(is_default: bool, expected_command_help: str) -> None:
assert command_info.command_help == expected_command_help
-def test_default_command(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
+def test_default_command(monkeypatch: pytest.MonkeyPatch) -> None:
"""Test adding default command."""
- def mock_command(
- func_mock: MagicMock, name: str, with_working_dir: bool
- ) -> Callable[..., None]:
+ def mock_command(func_mock: MagicMock, name: str) -> Callable[..., None]:
"""Mock cli command."""
def sample_cmd_1(*args: Any, **kwargs: Any) -> None:
"""Sample command."""
func_mock(*args, **kwargs)
- def sample_cmd_2(ctx: ExecutionContext, **kwargs: Any) -> None:
- """Another sample command."""
- func_mock(ctx=ctx, **kwargs)
-
- ret_func = sample_cmd_2 if with_working_dir else sample_cmd_1
+ ret_func = sample_cmd_1
ret_func.__name__ = name
- return ret_func # type: ignore
+ return ret_func
- default_command = MagicMock()
non_default_command = MagicMock()
- def default_command_params(parser: argparse.ArgumentParser) -> None:
- """Add parameters for default command."""
- parser.add_argument("--sample")
- parser.add_argument("--default_arg", default="123")
-
def non_default_command_params(parser: argparse.ArgumentParser) -> None:
"""Add parameters for non default command."""
parser.add_argument("--param")
@@ -100,15 +87,7 @@ def test_default_command(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> Non
MagicMock(
return_value=[
CommandInfo(
- func=mock_command(default_command, "default_command", True),
- aliases=["command1"],
- opt_groups=[default_command_params],
- is_default=True,
- ),
- CommandInfo(
- func=mock_command(
- non_default_command, "non_default_command", False
- ),
+ func=mock_command(non_default_command, "non_default_command"),
aliases=["command2"],
opt_groups=[non_default_command_params],
is_default=False,
@@ -117,11 +96,7 @@ def test_default_command(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> Non
),
)
- tmp_working_dir = str(tmp_path)
- main(["--working-dir", tmp_working_dir, "--sample", "1"])
main(["command2", "--param", "test"])
-
- default_command.assert_called_once_with(ctx=ANY, sample="1", default_arg="123")
non_default_command.assert_called_once_with(param="test")
@@ -140,134 +115,168 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
"params, expected_call",
[
[
- ["operators", "sample_model.tflite"],
+ ["check", "sample_model.tflite", "--target-profile", "ethos-u55-256"],
call(
ctx=ANY,
target_profile="ethos-u55-256",
model="sample_model.tflite",
+ compatibility=False,
+ performance=False,
output=None,
- supported_ops_report=False,
+ json=False,
+ backend=None,
),
],
[
- ["ops", "sample_model.tflite"],
- call(
- ctx=ANY,
- target_profile="ethos-u55-256",
- model="sample_model.tflite",
- output=None,
- supported_ops_report=False,
- ),
- ],
- [
- ["operators", "sample_model.tflite", "--target-profile", "ethos-u55-128"],
+ ["check", "sample_model.tflite", "--target-profile", "ethos-u55-128"],
call(
ctx=ANY,
target_profile="ethos-u55-128",
model="sample_model.tflite",
+ compatibility=False,
+ performance=False,
output=None,
- supported_ops_report=False,
+ json=False,
+ backend=None,
),
],
[
- ["operators"],
- call(
- ctx=ANY,
- target_profile="ethos-u55-256",
- model=None,
- output=None,
- supported_ops_report=False,
- ),
- ],
- [
- ["operators", "--supported-ops-report"],
+ [
+ "check",
+ "sample_model.h5",
+ "--performance",
+ "--compatibility",
+ "--target-profile",
+ "ethos-u55-256",
+ ],
call(
ctx=ANY,
target_profile="ethos-u55-256",
- model=None,
+ model="sample_model.h5",
output=None,
- supported_ops_report=True,
+ json=False,
+ compatibility=True,
+ performance=True,
+ backend=None,
),
],
[
[
- "all_tests",
+ "check",
"sample_model.h5",
- "--optimization-type",
- "pruning",
- "--optimization-target",
- "0.5",
+ "--performance",
+ "--target-profile",
+ "ethos-u55-256",
+ "--output",
+ "result.json",
+ "--json",
],
call(
ctx=ANY,
target_profile="ethos-u55-256",
model="sample_model.h5",
- optimization_type="pruning",
- optimization_target="0.5",
- output=None,
- evaluate_on=["Vela"],
+ performance=True,
+ compatibility=False,
+ output=Path("result.json"),
+ json=True,
+ backend=None,
),
],
[
- ["sample_model.h5"],
+ [
+ "check",
+ "sample_model.h5",
+ "--performance",
+ "--target-profile",
+ "ethos-u55-128",
+ ],
call(
ctx=ANY,
- target_profile="ethos-u55-256",
+ target_profile="ethos-u55-128",
model="sample_model.h5",
- optimization_type="pruning,clustering",
- optimization_target="0.5,32",
+ compatibility=False,
+ performance=True,
output=None,
- evaluate_on=["Vela"],
+ json=False,
+ backend=None,
),
],
[
- ["performance", "sample_model.h5", "--output", "result.json"],
+ [
+ "optimize",
+ "sample_model.h5",
+ "--target-profile",
+ "ethos-u55-256",
+ "--pruning",
+ "--clustering",
+ ],
call(
ctx=ANY,
target_profile="ethos-u55-256",
model="sample_model.h5",
- output="result.json",
- evaluate_on=["Vela"],
- ),
- ],
- [
- ["perf", "sample_model.h5", "--target-profile", "ethos-u55-128"],
- call(
- ctx=ANY,
- target_profile="ethos-u55-128",
- model="sample_model.h5",
+ pruning=True,
+ clustering=True,
+ pruning_target=None,
+ clustering_target=None,
output=None,
- evaluate_on=["Vela"],
+ json=False,
+ backend=None,
),
],
[
- ["optimization", "sample_model.h5"],
+ [
+ "optimize",
+ "sample_model.h5",
+ "--target-profile",
+ "ethos-u55-256",
+ "--pruning",
+ "--clustering",
+ "--pruning-target",
+ "0.5",
+ "--clustering-target",
+ "32",
+ ],
call(
ctx=ANY,
target_profile="ethos-u55-256",
model="sample_model.h5",
- optimization_type="pruning,clustering",
- optimization_target="0.5,32",
+ pruning=True,
+ clustering=True,
+ pruning_target=0.5,
+ clustering_target=32,
output=None,
- evaluate_on=["Vela"],
+ json=False,
+ backend=None,
),
],
[
- ["optimization", "sample_model.h5", "--evaluate-on", "some_backend"],
+ [
+ "optimize",
+ "sample_model.h5",
+ "--target-profile",
+ "ethos-u55-256",
+ "--pruning",
+ "--backend",
+ "some_backend",
+ ],
call(
ctx=ANY,
target_profile="ethos-u55-256",
model="sample_model.h5",
- optimization_type="pruning,clustering",
- optimization_target="0.5,32",
+ pruning=True,
+ clustering=False,
+ pruning_target=None,
+ clustering_target=None,
output=None,
- evaluate_on=["some_backend"],
+ json=False,
+ backend=["some_backend"],
),
],
[
[
- "operators",
+ "check",
"sample_model.h5",
+ "--compatibility",
"--target-profile",
"cortex-a",
],
@@ -275,8 +284,11 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
ctx=ANY,
target_profile="cortex-a",
model="sample_model.h5",
+ compatibility=True,
+ performance=False,
output=None,
- supported_ops_report=False,
+ json=False,
+ backend=None,
),
],
],
@@ -288,15 +300,11 @@ def test_commands_execution(
mock = MagicMock()
monkeypatch.setattr(
- "mlia.cli.options.get_default_backends", MagicMock(return_value=["Vela"])
- )
-
- monkeypatch.setattr(
"mlia.cli.options.get_available_backends",
MagicMock(return_value=["Vela", "some_backend"]),
)
- for command in ["all_tests", "operators", "performance", "optimization"]:
+ for command in ["check", "optimize"]:
monkeypatch.setattr(
f"mlia.cli.main.{command}",
wrap_mock_command(mock, getattr(mlia.cli.main, command)),
@@ -335,15 +343,15 @@ def test_commands_execution_backend_main(
@pytest.mark.parametrize(
- "verbose, exc_mock, expected_output",
+ "debug, exc_mock, expected_output",
[
[
True,
MagicMock(side_effect=Exception("Error")),
[
"Execution finished with error: Error",
- f"Please check the log files in the {Path.cwd()/'mlia_output/logs'} "
- "for more details",
+ "Please check the log files in the /tmp/mlia-",
+ "/logs for more details",
],
],
[
@@ -351,8 +359,8 @@ def test_commands_execution_backend_main(
MagicMock(side_effect=Exception("Error")),
[
"Execution finished with error: Error",
- f"Please check the log files in the {Path.cwd()/'mlia_output/logs'} "
- "for more details, or enable verbose mode (--verbose)",
+ "Please check the log files in the /tmp/mlia-",
+ "/logs for more details, or enable debug mode (--debug)",
],
],
[
@@ -389,18 +397,18 @@ def test_commands_execution_backend_main(
],
],
)
-def test_verbose_output(
+def test_debug_output(
monkeypatch: pytest.MonkeyPatch,
capsys: pytest.CaptureFixture,
- verbose: bool,
+ debug: bool,
exc_mock: MagicMock,
expected_output: list[str],
) -> None:
- """Test flag --verbose."""
+ """Test flag --debug."""
def command_params(parser: argparse.ArgumentParser) -> None:
"""Add parameters for non default command."""
- parser.add_argument("--verbose", action="store_true")
+ parser.add_argument("--debug", action="store_true")
def command() -> None:
"""Run test command."""
@@ -420,8 +428,8 @@ def test_verbose_output(
)
params = ["command"]
- if verbose:
- params.append("--verbose")
+ if debug:
+ params.append("--debug")
exit_code = main(params)
assert exit_code == 1
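For orientation, the new command-line shape exercised by the tests above, driven through the programmatic entry point; the model file name is illustrative and a real run needs an existing model plus the required backends:

    from mlia.cli.main import main

    # Compatibility and performance check, writing JSON output to a file.
    main([
        "check",
        "sample_model.h5",
        "--target-profile", "ethos-u55-256",
        "--compatibility", "--performance",
        "--output", "result.json",
        "--json",
    ])

    # Optimization with explicit pruning/clustering targets.
    main([
        "optimize",
        "sample_model.h5",
        "--target-profile", "ethos-u55-256",
        "--pruning", "--pruning-target", "0.5",
        "--clustering", "--clustering-target", "32",
    ])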
diff --git a/tests/test_cli_options.py b/tests/test_cli_options.py
index d75f7c0..a889a93 100644
--- a/tests/test_cli_options.py
+++ b/tests/test_cli_options.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for module options."""
from __future__ import annotations
@@ -13,14 +13,19 @@ import pytest
from mlia.cli.options import add_output_options
from mlia.cli.options import get_target_profile_opts
from mlia.cli.options import parse_optimization_parameters
+from mlia.cli.options import parse_output_parameters
+from mlia.core.common import FormattedFilePath
@pytest.mark.parametrize(
- "optimization_type, optimization_target, expected_error, expected_result",
+ "pruning, clustering, pruning_target, clustering_target, expected_error,"
+ "expected_result",
[
- (
- "pruning",
- "0.5",
+ [
+ False,
+ False,
+ None,
+ None,
does_not_raise(),
[
dict(
@@ -29,39 +34,40 @@ from mlia.cli.options import parse_optimization_parameters
layers_to_optimize=None,
)
],
- ),
- (
- "clustering",
- "32",
+ ],
+ [
+ True,
+ False,
+ None,
+ None,
does_not_raise(),
[
dict(
- optimization_type="clustering",
- optimization_target=32.0,
+ optimization_type="pruning",
+ optimization_target=0.5,
layers_to_optimize=None,
)
],
- ),
- (
- "pruning,clustering",
- "0.5,32",
+ ],
+ [
+ False,
+ True,
+ None,
+ None,
does_not_raise(),
[
dict(
- optimization_type="pruning",
- optimization_target=0.5,
- layers_to_optimize=None,
- ),
- dict(
optimization_type="clustering",
- optimization_target=32.0,
+ optimization_target=32,
layers_to_optimize=None,
- ),
+ )
],
- ),
- (
- "pruning, clustering",
- "0.5, 32",
+ ],
+ [
+ True,
+ True,
+ None,
+ None,
does_not_raise(),
[
dict(
@@ -71,50 +77,66 @@ from mlia.cli.options import parse_optimization_parameters
),
dict(
optimization_type="clustering",
- optimization_target=32.0,
+ optimization_target=32,
layers_to_optimize=None,
),
],
- ),
- (
- "pruning,clustering",
- "0.5",
- pytest.raises(
- Exception, match="Wrong number of optimization targets and types"
- ),
- None,
- ),
- (
- "",
- "0.5",
- pytest.raises(Exception, match="Optimization type is not provided"),
+ ],
+ [
+ False,
+ False,
+ 0.4,
None,
- ),
- (
- "pruning,clustering",
- "",
- pytest.raises(Exception, match="Optimization target is not provided"),
+ does_not_raise(),
+ [
+ dict(
+ optimization_type="pruning",
+ optimization_target=0.4,
+ layers_to_optimize=None,
+ )
+ ],
+ ],
+ [
+ False,
+ False,
None,
- ),
- (
- "pruning,",
- "0.5,abc",
+ 32,
pytest.raises(
- Exception, match="Non numeric value for the optimization target"
+ argparse.ArgumentError,
+ match="To enable clustering optimization you need to include "
+ "the `--clustering` flag in your command.",
),
None,
- ),
+ ],
+ [
+ False,
+ True,
+ None,
+ 32.2,
+ does_not_raise(),
+ [
+ dict(
+ optimization_type="clustering",
+ optimization_target=32.2,
+ layers_to_optimize=None,
+ )
+ ],
+ ],
],
)
def test_parse_optimization_parameters(
- optimization_type: str,
- optimization_target: str,
+ pruning: bool,
+ clustering: bool,
+ pruning_target: float | None,
+ clustering_target: int | None,
expected_error: Any,
expected_result: Any,
) -> None:
"""Test function parse_optimization_parameters."""
with expected_error:
- result = parse_optimization_parameters(optimization_type, optimization_target)
+ result = parse_optimization_parameters(
+ pruning, clustering, pruning_target, clustering_target
+ )
assert result == expected_result
@@ -155,28 +177,41 @@ def test_output_options(output_parameters: list[str], expected_path: str) -> Non
add_output_options(parser)
args = parser.parse_args(output_parameters)
- assert args.output == expected_path
+ assert str(args.output) == expected_path
@pytest.mark.parametrize(
- "output_filename",
+ "path, json, expected_error, output",
[
- "report.txt",
- "report.TXT",
- "report",
- "report.pdf",
+ [
+ None,
+ True,
+ pytest.raises(
+ argparse.ArgumentError,
+ match=r"To enable JSON output you need to specify the output path. "
+ r"\(e.g. --output out.json --json\)",
+ ),
+ None,
+ ],
+ [None, False, does_not_raise(), None],
+ [
+ Path("test_path"),
+ False,
+ does_not_raise(),
+ FormattedFilePath(Path("test_path"), "plain_text"),
+ ],
+ [
+ Path("test_path"),
+ True,
+ does_not_raise(),
+ FormattedFilePath(Path("test_path"), "json"),
+ ],
],
)
-def test_output_options_bad_parameters(
- output_filename: str, capsys: pytest.CaptureFixture
+def test_parse_output_parameters(
+ path: Path | None, json: bool, expected_error: Any, output: FormattedFilePath | None
) -> None:
- """Test that args parsing should fail if format is not supported."""
- parser = argparse.ArgumentParser()
- add_output_options(parser)
-
- with pytest.raises(SystemExit):
- parser.parse_args(["--output", output_filename])
-
- err_output = capsys.readouterr().err
- suffix = Path(output_filename).suffix[1:]
- assert f"Unsupported format '{suffix}'" in err_output
+ """Test parsing for output parameters."""
+ with expected_error:
+ formatted_output = parse_output_parameters(path, json)
+ assert formatted_output == output
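For reference, a sketch of the two option parsers exercised above, called positionally the same way the tests do; the expected values restate the parametrised expectations and assume nothing beyond them:

    from pathlib import Path

    from mlia.cli.options import parse_optimization_parameters
    from mlia.cli.options import parse_output_parameters
    from mlia.core.common import FormattedFilePath

    # --pruning with no explicit target falls back to the default of 0.5.
    assert parse_optimization_parameters(True, False, None, None) == [
        dict(
            optimization_type="pruning",
            optimization_target=0.5,
            layers_to_optimize=None,
        )
    ]

    # --output out.json --json yields a FormattedFilePath in "json" format.
    assert parse_output_parameters(Path("out.json"), True) == FormattedFilePath(
        Path("out.json"), "json"
    )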
diff --git a/tests/test_core_advice_generation.py b/tests/test_core_advice_generation.py
index 3d985eb..2e0038f 100644
--- a/tests/test_core_advice_generation.py
+++ b/tests/test_core_advice_generation.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for module advice_generation."""
from __future__ import annotations
@@ -35,17 +35,17 @@ def test_advice_generation() -> None:
"category, expected_advice",
[
[
- AdviceCategory.OPERATORS,
+ {AdviceCategory.COMPATIBILITY},
[Advice(["Good advice!"])],
],
[
- AdviceCategory.PERFORMANCE,
+ {AdviceCategory.PERFORMANCE},
[],
],
],
)
def test_advice_category_decorator(
- category: AdviceCategory,
+ category: set[AdviceCategory],
expected_advice: list[Advice],
sample_context: Context,
) -> None:
@@ -54,7 +54,7 @@ def test_advice_category_decorator(
class SampleAdviceProducer(FactBasedAdviceProducer):
"""Sample advice producer."""
- @advice_category(AdviceCategory.OPERATORS)
+ @advice_category(AdviceCategory.COMPATIBILITY)
def produce_advice(self, data_item: DataItem) -> None:
"""Produce the advice."""
self.add_advice(["Good advice!"])
diff --git a/tests/test_core_context.py b/tests/test_core_context.py
index 44eb976..dcdbef3 100644
--- a/tests/test_core_context.py
+++ b/tests/test_core_context.py
@@ -1,17 +1,53 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the module context."""
+from __future__ import annotations
+
from pathlib import Path
+import pytest
+
from mlia.core.common import AdviceCategory
from mlia.core.context import ExecutionContext
from mlia.core.events import DefaultEventPublisher
+@pytest.mark.parametrize(
+ "context_advice_category, expected_enabled_categories",
+ [
+ [
+ {
+ AdviceCategory.COMPATIBILITY,
+ },
+ [AdviceCategory.COMPATIBILITY],
+ ],
+ [
+ {
+ AdviceCategory.PERFORMANCE,
+ },
+ [AdviceCategory.PERFORMANCE],
+ ],
+ [
+ {AdviceCategory.COMPATIBILITY, AdviceCategory.PERFORMANCE},
+ [AdviceCategory.PERFORMANCE, AdviceCategory.COMPATIBILITY],
+ ],
+ ],
+)
+def test_execution_context_category_enabled(
+ context_advice_category: set[AdviceCategory],
+ expected_enabled_categories: list[AdviceCategory],
+) -> None:
+ """Test category enabled method of execution context."""
+ for category in expected_enabled_categories:
+ assert ExecutionContext(
+ advice_category=context_advice_category
+ ).category_enabled(category)
+
+
def test_execution_context(tmpdir: str) -> None:
"""Test execution context."""
publisher = DefaultEventPublisher()
- category = AdviceCategory.OPERATORS
+ category = {AdviceCategory.COMPATIBILITY}
context = ExecutionContext(
advice_category=category,
@@ -35,13 +71,13 @@ def test_execution_context(tmpdir: str) -> None:
assert str(context) == (
f"ExecutionContext: "
f"working_dir={tmpdir}, "
- "advice_category=OPERATORS, "
+ "advice_category={'COMPATIBILITY'}, "
"config_parameters={'param': 'value'}, "
"verbose=True"
)
context_with_default_params = ExecutionContext(working_dir=tmpdir)
- assert context_with_default_params.advice_category is AdviceCategory.ALL
+ assert context_with_default_params.advice_category == {AdviceCategory.COMPATIBILITY}
assert context_with_default_params.config_parameters is None
assert context_with_default_params.event_handlers is None
assert isinstance(
@@ -55,7 +91,7 @@ def test_execution_context(tmpdir: str) -> None:
expected_str = (
f"ExecutionContext: working_dir={tmpdir}, "
- "advice_category=ALL, "
+ "advice_category={'COMPATIBILITY'}, "
"config_parameters=None, "
"verbose=False"
)
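A small sketch of the set-based advice_category shown above; it restates what the test asserts, namely that category_enabled checks membership and that the default is now COMPATIBILITY only:

    from mlia.core.common import AdviceCategory
    from mlia.core.context import ExecutionContext

    ctx = ExecutionContext(
        advice_category={AdviceCategory.COMPATIBILITY, AdviceCategory.PERFORMANCE}
    )
    assert ctx.category_enabled(AdviceCategory.PERFORMANCE)

    # With no explicit category the context defaults to COMPATIBILITY only.
    assert ExecutionContext().advice_category == {AdviceCategory.COMPATIBILITY}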
diff --git a/tests/test_core_helpers.py b/tests/test_core_helpers.py
index 8577617..03ec3f0 100644
--- a/tests/test_core_helpers.py
+++ b/tests/test_core_helpers.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the helper classes."""
from mlia.core.helpers import APIActionResolver
@@ -10,7 +10,6 @@ def test_api_action_resolver() -> None:
# pylint: disable=use-implicit-booleaness-not-comparison
assert helper.apply_optimizations() == []
- assert helper.supported_operators_info() == []
assert helper.check_performance() == []
assert helper.check_operator_compatibility() == []
assert helper.operator_compatibility_details() == []
diff --git a/tests/test_core_mixins.py b/tests/test_core_mixins.py
index 3834fb3..47ed815 100644
--- a/tests/test_core_mixins.py
+++ b/tests/test_core_mixins.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the module mixins."""
import pytest
@@ -36,7 +36,7 @@ class TestParameterResolverMixin:
self.context = sample_context
self.context.update(
- advice_category=AdviceCategory.OPERATORS,
+ advice_category={AdviceCategory.COMPATIBILITY},
event_handlers=[],
config_parameters={"section": {"param": 123}},
)
@@ -83,7 +83,7 @@ class TestParameterResolverMixin:
"""Init sample object."""
self.context = sample_context
self.context.update(
- advice_category=AdviceCategory.OPERATORS,
+ advice_category={AdviceCategory.COMPATIBILITY},
event_handlers=[],
config_parameters={"section": ["param"]},
)
diff --git a/tests/test_core_reporting.py b/tests/test_core_reporting.py
index feff5cc..7b26173 100644
--- a/tests/test_core_reporting.py
+++ b/tests/test_core_reporting.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for reporting module."""
from __future__ import annotations
@@ -13,11 +13,8 @@ from mlia.core.reporting import CyclesCell
from mlia.core.reporting import Format
from mlia.core.reporting import NestedReport
from mlia.core.reporting import ReportItem
-from mlia.core.reporting import resolve_output_format
from mlia.core.reporting import SingleRow
from mlia.core.reporting import Table
-from mlia.core.typing import OutputFormat
-from mlia.core.typing import PathOrFileLike
from mlia.utils.console import remove_ascii_codes
@@ -338,20 +335,3 @@ Single row example:
alias="simple_row_example",
)
wrong_single_row.to_plain_text()
-
-
-@pytest.mark.parametrize(
- "output, expected_output_format",
- [
- [None, "plain_text"],
- ["", "plain_text"],
- ["some_file", "plain_text"],
- ["some_format.some_ext", "plain_text"],
- ["output.json", "json"],
- ],
-)
-def test_resolve_output_format(
- output: PathOrFileLike | None, expected_output_format: OutputFormat
-) -> None:
- """Test function resolve_output_format."""
- assert resolve_output_format(output) == expected_output_format
diff --git a/tests/test_target_config.py b/tests/test_target_config.py
index 66ebed6..48f0a58 100644
--- a/tests/test_target_config.py
+++ b/tests/test_target_config.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the backend config module."""
from __future__ import annotations
@@ -25,7 +25,7 @@ def test_ip_config() -> None:
(
(None, False, True),
(None, True, True),
- (AdviceCategory.OPERATORS, True, True),
+ (AdviceCategory.COMPATIBILITY, True, True),
(AdviceCategory.OPTIMIZATION, True, False),
),
)
@@ -42,7 +42,7 @@ def test_target_info(
backend_registry.register(
"backend",
BackendConfiguration(
- [AdviceCategory.OPERATORS],
+ [AdviceCategory.COMPATIBILITY],
[System.CURRENT],
BackendType.BUILTIN,
),
diff --git a/tests/test_target_cortex_a_advice_generation.py b/tests/test_target_cortex_a_advice_generation.py
index 6effe4c..1997c52 100644
--- a/tests/test_target_cortex_a_advice_generation.py
+++ b/tests/test_target_cortex_a_advice_generation.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for advice generation."""
from __future__ import annotations
@@ -31,7 +31,7 @@ BACKEND_INFO = (
[
[
ModelIsNotCortexACompatible(BACKEND_INFO, {"UNSUPPORTED_OP"}, {}),
- AdviceCategory.OPERATORS,
+ {AdviceCategory.COMPATIBILITY},
[
Advice(
[
@@ -61,7 +61,7 @@ BACKEND_INFO = (
)
},
),
- AdviceCategory.OPERATORS,
+ {AdviceCategory.COMPATIBILITY},
[
Advice(
[
@@ -93,7 +93,7 @@ BACKEND_INFO = (
],
[
ModelIsCortexACompatible(BACKEND_INFO),
- AdviceCategory.OPERATORS,
+ {AdviceCategory.COMPATIBILITY},
[
Advice(
[
@@ -108,7 +108,7 @@ BACKEND_INFO = (
flex_ops=["flex_op1", "flex_op2"],
custom_ops=["custom_op1", "custom_op2"],
),
- AdviceCategory.OPERATORS,
+ {AdviceCategory.COMPATIBILITY},
[
Advice(
[
@@ -142,7 +142,7 @@ BACKEND_INFO = (
],
[
ModelIsNotTFLiteCompatible(),
- AdviceCategory.OPERATORS,
+ {AdviceCategory.COMPATIBILITY},
[
Advice(
[
@@ -154,7 +154,7 @@ BACKEND_INFO = (
],
[
ModelHasCustomOperators(),
- AdviceCategory.OPERATORS,
+ {AdviceCategory.COMPATIBILITY},
[
Advice(
[
@@ -166,7 +166,7 @@ BACKEND_INFO = (
],
[
TFLiteCompatibilityCheckFailed(),
- AdviceCategory.OPERATORS,
+ {AdviceCategory.COMPATIBILITY},
[
Advice(
[
@@ -181,7 +181,7 @@ BACKEND_INFO = (
def test_cortex_a_advice_producer(
tmpdir: str,
input_data: DataItem,
- advice_category: AdviceCategory,
+ advice_category: set[AdviceCategory],
expected_advice: list[Advice],
) -> None:
"""Test Cortex-A advice producer."""
diff --git a/tests/test_target_ethos_u_advice_generation.py b/tests/test_target_ethos_u_advice_generation.py
index 1569592..e93eeba 100644
--- a/tests/test_target_ethos_u_advice_generation.py
+++ b/tests/test_target_ethos_u_advice_generation.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for Ethos-U advice generation."""
from __future__ import annotations
@@ -28,7 +28,7 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff
[
[
AllOperatorsSupportedOnNPU(),
- AdviceCategory.OPERATORS,
+ {AdviceCategory.COMPATIBILITY},
APIActionResolver(),
[
Advice(
@@ -41,7 +41,7 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff
],
[
AllOperatorsSupportedOnNPU(),
- AdviceCategory.OPERATORS,
+ {AdviceCategory.COMPATIBILITY},
CLIActionResolver(
{
"target_profile": "sample_target",
@@ -55,15 +55,15 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff
"run completely on NPU.",
"Check the estimated performance by running the "
"following command: ",
- "mlia performance --target-profile sample_target "
- "sample_model.tflite",
+ "mlia check sample_model.tflite --target-profile sample_target "
+ "--performance",
]
)
],
],
[
HasCPUOnlyOperators(cpu_only_ops=["OP1", "OP2", "OP3"]),
- AdviceCategory.OPERATORS,
+ {AdviceCategory.COMPATIBILITY},
APIActionResolver(),
[
Advice(
@@ -78,7 +78,7 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff
],
[
HasCPUOnlyOperators(cpu_only_ops=["OP1", "OP2", "OP3"]),
- AdviceCategory.OPERATORS,
+ {AdviceCategory.COMPATIBILITY},
CLIActionResolver({}),
[
Advice(
@@ -87,15 +87,13 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff
"OP1,OP2,OP3.",
"Using operators that are supported by the NPU will "
"improve performance.",
- "For guidance on supported operators, run: mlia operators "
- "--supported-ops-report",
]
)
],
],
[
HasUnsupportedOnNPUOperators(npu_unsupported_ratio=0.4),
- AdviceCategory.OPERATORS,
+ {AdviceCategory.COMPATIBILITY},
APIActionResolver(),
[
Advice(
@@ -110,7 +108,7 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff
],
[
HasUnsupportedOnNPUOperators(npu_unsupported_ratio=0.4),
- AdviceCategory.OPERATORS,
+ {AdviceCategory.COMPATIBILITY},
CLIActionResolver({}),
[
Advice(
@@ -138,7 +136,7 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff
),
]
),
- AdviceCategory.OPTIMIZATION,
+ {AdviceCategory.OPTIMIZATION},
APIActionResolver(),
[
Advice(
@@ -178,7 +176,7 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff
),
]
),
- AdviceCategory.OPTIMIZATION,
+ {AdviceCategory.OPTIMIZATION},
CLIActionResolver({"model": "sample_model.h5"}),
[
Advice(
@@ -192,10 +190,10 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff
"You can try to push the optimization target higher "
"(e.g. pruning: 0.6) "
"to check if those results can be further improved.",
- "For more info: mlia optimization --help",
+ "For more info: mlia optimize --help",
"Optimization command: "
- "mlia optimization --optimization-type pruning "
- "--optimization-target 0.6 sample_model.h5",
+ "mlia optimize sample_model.h5 --pruning "
+ "--pruning-target 0.6",
]
),
Advice(
@@ -225,7 +223,7 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff
),
]
),
- AdviceCategory.OPTIMIZATION,
+ {AdviceCategory.OPTIMIZATION},
APIActionResolver(),
[
Advice(
@@ -267,7 +265,7 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff
),
]
),
- AdviceCategory.OPTIMIZATION,
+ {AdviceCategory.OPTIMIZATION},
APIActionResolver(),
[
Advice(
@@ -304,7 +302,7 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff
),
]
),
- AdviceCategory.OPTIMIZATION,
+ {AdviceCategory.OPTIMIZATION},
APIActionResolver(),
[
Advice(
@@ -354,7 +352,7 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff
),
]
),
- AdviceCategory.OPTIMIZATION,
+ {AdviceCategory.OPTIMIZATION},
APIActionResolver(),
[], # no advice for more than one optimization result
],
@@ -364,7 +362,7 @@ def test_ethosu_advice_producer(
tmpdir: str,
input_data: DataItem,
expected_advice: list[Advice],
- advice_category: AdviceCategory,
+ advice_category: set[AdviceCategory] | None,
action_resolver: ActionResolver,
) -> None:
"""Test Ethos-U Advice producer."""
@@ -386,17 +384,17 @@ def test_ethosu_advice_producer(
"advice_category, action_resolver, expected_advice",
[
[
- AdviceCategory.ALL,
+ {AdviceCategory.COMPATIBILITY, AdviceCategory.PERFORMANCE},
None,
[],
],
[
- AdviceCategory.OPERATORS,
+ {AdviceCategory.COMPATIBILITY},
None,
[],
],
[
- AdviceCategory.PERFORMANCE,
+ {AdviceCategory.PERFORMANCE},
APIActionResolver(),
[
Advice(
@@ -414,31 +412,33 @@ def test_ethosu_advice_producer(
],
],
[
- AdviceCategory.PERFORMANCE,
- CLIActionResolver({"model": "test_model.h5"}),
+ {AdviceCategory.PERFORMANCE},
+ CLIActionResolver(
+ {"model": "test_model.h5", "target_profile": "sample_target"}
+ ),
[
Advice(
[
"You can improve the inference time by using only operators "
"that are supported by the NPU.",
"Try running the following command to verify that:",
- "mlia operators test_model.h5",
+ "mlia check test_model.h5 --target-profile sample_target",
]
),
Advice(
[
"Check if you can improve the performance by applying "
"tooling techniques to your model.",
- "For example: mlia optimization --optimization-type "
- "pruning,clustering --optimization-target 0.5,32 "
- "test_model.h5",
- "For more info: mlia optimization --help",
+ "For example: mlia optimize test_model.h5 "
+ "--pruning --clustering "
+ "--pruning-target 0.5 --clustering-target 32",
+ "For more info: mlia optimize --help",
]
),
],
],
[
- AdviceCategory.OPTIMIZATION,
+ {AdviceCategory.OPTIMIZATION},
APIActionResolver(),
[
Advice(
@@ -450,14 +450,14 @@ def test_ethosu_advice_producer(
],
],
[
- AdviceCategory.OPTIMIZATION,
+ {AdviceCategory.OPTIMIZATION},
CLIActionResolver({"model": "test_model.h5"}),
[
Advice(
[
"For better performance, make sure that all the operators "
"of your final TensorFlow Lite model are supported by the NPU.",
- "For more details, run: mlia operators --help",
+ "For more details, run: mlia check --help",
]
)
],
@@ -466,7 +466,7 @@ def test_ethosu_advice_producer(
)
def test_ethosu_static_advice_producer(
tmpdir: str,
- advice_category: AdviceCategory,
+ advice_category: set[AdviceCategory] | None,
action_resolver: ActionResolver,
expected_advice: list[Advice],
) -> None:
diff --git a/tests/test_target_registry.py b/tests/test_target_registry.py
index e6ee296..e6028a9 100644
--- a/tests/test_target_registry.py
+++ b/tests/test_target_registry.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the target registry module."""
from __future__ import annotations
@@ -26,11 +26,11 @@ def test_target_registry(expected_target: str) -> None:
@pytest.mark.parametrize(
("target_name", "expected_advices"),
(
- ("Cortex-A", [AdviceCategory.OPERATORS]),
+ ("Cortex-A", [AdviceCategory.COMPATIBILITY]),
(
"Ethos-U55",
[
- AdviceCategory.OPERATORS,
+ AdviceCategory.COMPATIBILITY,
AdviceCategory.OPTIMIZATION,
AdviceCategory.PERFORMANCE,
],
@@ -38,12 +38,12 @@ def test_target_registry(expected_target: str) -> None:
(
"Ethos-U65",
[
- AdviceCategory.OPERATORS,
+ AdviceCategory.COMPATIBILITY,
AdviceCategory.OPTIMIZATION,
AdviceCategory.PERFORMANCE,
],
),
- ("TOSA", [AdviceCategory.OPERATORS]),
+ ("TOSA", [AdviceCategory.COMPATIBILITY]),
),
)
def test_supported_advice(
@@ -72,7 +72,7 @@ def test_supported_backends(target_name: str, expected_backends: list[str]) -> N
@pytest.mark.parametrize(
("advice", "expected_targets"),
(
- (AdviceCategory.OPERATORS, ["Cortex-A", "Ethos-U55", "Ethos-U65", "TOSA"]),
+ (AdviceCategory.COMPATIBILITY, ["Cortex-A", "Ethos-U55", "Ethos-U65", "TOSA"]),
(AdviceCategory.OPTIMIZATION, ["Ethos-U55", "Ethos-U65"]),
(AdviceCategory.PERFORMANCE, ["Ethos-U55", "Ethos-U65"]),
),
diff --git a/tests/test_target_tosa_advice_generation.py b/tests/test_target_tosa_advice_generation.py
index e8e06f8..d5ebbd7 100644
--- a/tests/test_target_tosa_advice_generation.py
+++ b/tests/test_target_tosa_advice_generation.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for advice generation."""
from __future__ import annotations
@@ -19,7 +19,7 @@ from mlia.target.tosa.data_analysis import ModelIsTOSACompatible
[
[
ModelIsNotTOSACompatible(),
- AdviceCategory.OPERATORS,
+ {AdviceCategory.COMPATIBILITY},
[
Advice(
[
@@ -31,7 +31,7 @@ from mlia.target.tosa.data_analysis import ModelIsTOSACompatible
],
[
ModelIsTOSACompatible(),
- AdviceCategory.OPERATORS,
+ {AdviceCategory.COMPATIBILITY},
[Advice(["Model is fully TOSA compatible."])],
],
],
@@ -39,7 +39,7 @@ from mlia.target.tosa.data_analysis import ModelIsTOSACompatible
def test_tosa_advice_producer(
tmpdir: str,
input_data: DataItem,
- advice_category: AdviceCategory,
+ advice_category: set[AdviceCategory],
expected_advice: list[Advice],
) -> None:
"""Test TOSA advice producer."""