diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/test_api.py | 98 | ||||
-rw-r--r-- | tests/test_backend_config.py | 12 | ||||
-rw-r--r-- | tests/test_backend_registry.py | 8 | ||||
-rw-r--r-- | tests/test_cli_command_validators.py | 167 | ||||
-rw-r--r-- | tests/test_cli_commands.py | 97 | ||||
-rw-r--r-- | tests/test_cli_config.py | 8 | ||||
-rw-r--r-- | tests/test_cli_helpers.py | 62 | ||||
-rw-r--r-- | tests/test_cli_main.py | 228 | ||||
-rw-r--r-- | tests/test_cli_options.py | 179 | ||||
-rw-r--r-- | tests/test_core_advice_generation.py | 10 | ||||
-rw-r--r-- | tests/test_core_context.py | 46 | ||||
-rw-r--r-- | tests/test_core_helpers.py | 3 | ||||
-rw-r--r-- | tests/test_core_mixins.py | 6 | ||||
-rw-r--r-- | tests/test_core_reporting.py | 22 | ||||
-rw-r--r-- | tests/test_target_config.py | 6 | ||||
-rw-r--r-- | tests/test_target_cortex_a_advice_generation.py | 18 | ||||
-rw-r--r-- | tests/test_target_ethos_u_advice_generation.py | 70 | ||||
-rw-r--r-- | tests/test_target_registry.py | 12 | ||||
-rw-r--r-- | tests/test_target_tosa_advice_generation.py | 8 |
19 files changed, 584 insertions, 476 deletions
diff --git a/tests/test_api.py b/tests/test_api.py index fbc558b..0bbc3ae 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,15 +1,13 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the API functions.""" from __future__ import annotations from pathlib import Path from unittest.mock import MagicMock -from unittest.mock import patch import pytest -from mlia.api import generate_supported_operators_report from mlia.api import get_advice from mlia.api import get_advisor from mlia.core.common import AdviceCategory @@ -22,63 +20,68 @@ from mlia.target.tosa.advisor import TOSAInferenceAdvisor def test_get_advice_no_target_provided(test_keras_model: Path) -> None: """Test getting advice when no target provided.""" with pytest.raises(Exception, match="Target profile is not provided"): - get_advice(None, test_keras_model, "all") # type: ignore + get_advice(None, test_keras_model, {"compatibility"}) # type: ignore def test_get_advice_wrong_category(test_keras_model: Path) -> None: """Test getting advice when wrong advice category provided.""" with pytest.raises(Exception, match="Invalid advice category unknown"): - get_advice("ethos-u55-256", test_keras_model, "unknown") # type: ignore + get_advice("ethos-u55-256", test_keras_model, {"unknown"}) @pytest.mark.parametrize( "category, context, expected_category", [ [ - "all", + {"compatibility", "optimization"}, None, - AdviceCategory.ALL, + {AdviceCategory.COMPATIBILITY, AdviceCategory.OPTIMIZATION}, ], [ - "optimization", + {"optimization"}, None, - AdviceCategory.OPTIMIZATION, + {AdviceCategory.OPTIMIZATION}, ], [ - "operators", + {"compatibility"}, None, - AdviceCategory.OPERATORS, + {AdviceCategory.COMPATIBILITY}, ], [ - "performance", + {"performance"}, None, - AdviceCategory.PERFORMANCE, + {AdviceCategory.PERFORMANCE}, ], [ - "all", + {"compatibility", "optimization"}, ExecutionContext(), - AdviceCategory.ALL, + {AdviceCategory.COMPATIBILITY, AdviceCategory.OPTIMIZATION}, ], [ - "all", - ExecutionContext(advice_category=AdviceCategory.PERFORMANCE), - AdviceCategory.ALL, + {"compatibility", "optimization"}, + ExecutionContext( + advice_category={ + AdviceCategory.COMPATIBILITY, + AdviceCategory.OPTIMIZATION, + } + ), + {AdviceCategory.COMPATIBILITY, AdviceCategory.OPTIMIZATION}, ], [ - "all", + {"compatibility", "optimization"}, ExecutionContext(config_parameters={"param": "value"}), - AdviceCategory.ALL, + {AdviceCategory.COMPATIBILITY, AdviceCategory.OPTIMIZATION}, ], [ - "all", + {"compatibility", "optimization"}, ExecutionContext(event_handlers=[MagicMock()]), - AdviceCategory.ALL, + {AdviceCategory.COMPATIBILITY, AdviceCategory.OPTIMIZATION}, ], ], ) def test_get_advice( monkeypatch: pytest.MonkeyPatch, - category: str, + category: set[str], context: ExecutionContext, expected_category: AdviceCategory, test_keras_model: Path, @@ -90,7 +93,7 @@ def test_get_advice( get_advice( "ethos-u55-256", test_keras_model, - category, # type: ignore + category, context=context, ) @@ -111,50 +114,3 @@ def test_get_advisor( tosa_advisor = get_advisor(ExecutionContext(), "tosa", str(test_keras_model)) assert isinstance(tosa_advisor, TOSAInferenceAdvisor) - - -@pytest.mark.parametrize( - ["target_profile", "required_calls", "exception_msg"], - [ - [ - "ethos-u55-128", - "mlia.target.ethos_u.operators.generate_supported_operators_report", - None, - ], - [ - "ethos-u65-256", - "mlia.target.ethos_u.operators.generate_supported_operators_report", - None, - ], - [ - "tosa", - None, - "Generating a supported operators report is not " - "currently supported with TOSA target profile.", - ], - [ - "cortex-a", - None, - "Generating a supported operators report is not " - "currently supported with Cortex-A target profile.", - ], - [ - "Unknown", - None, - "Unable to find target profile Unknown", - ], - ], -) -def test_supported_ops_report_generator( - target_profile: str, required_calls: str | None, exception_msg: str | None -) -> None: - """Test supported operators report generator with different target profiles.""" - if exception_msg: - with pytest.raises(Exception) as exc: - generate_supported_operators_report(target_profile) - assert str(exc.value) == exception_msg - - if required_calls: - with patch(required_calls) as mock_method: - generate_supported_operators_report(target_profile) - mock_method.assert_called_once() diff --git a/tests/test_backend_config.py b/tests/test_backend_config.py index bd50945..700534f 100644 --- a/tests/test_backend_config.py +++ b/tests/test_backend_config.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the backend config module.""" from mlia.backend.config import BackendConfiguration @@ -20,14 +20,14 @@ def test_system() -> None: def test_backend_config() -> None: """Test the class 'BackendConfiguration'.""" cfg = BackendConfiguration( - [AdviceCategory.OPERATORS], [System.CURRENT], BackendType.CUSTOM + [AdviceCategory.COMPATIBILITY], [System.CURRENT], BackendType.CUSTOM ) - assert cfg.supported_advice == [AdviceCategory.OPERATORS] + assert cfg.supported_advice == [AdviceCategory.COMPATIBILITY] assert cfg.supported_systems == [System.CURRENT] assert cfg.type == BackendType.CUSTOM assert str(cfg) assert cfg.is_supported() - assert cfg.is_supported(advice=AdviceCategory.OPERATORS) + assert cfg.is_supported(advice=AdviceCategory.COMPATIBILITY) assert not cfg.is_supported(advice=AdviceCategory.PERFORMANCE) assert cfg.is_supported(check_system=True) assert cfg.is_supported(check_system=False) @@ -37,6 +37,6 @@ def test_backend_config() -> None: cfg.supported_systems = [UNSUPPORTED_SYSTEM] assert not cfg.is_supported(check_system=True) assert cfg.is_supported(check_system=False) - assert not cfg.is_supported(advice=AdviceCategory.OPERATORS, check_system=True) - assert cfg.is_supported(advice=AdviceCategory.OPERATORS, check_system=False) + assert not cfg.is_supported(advice=AdviceCategory.COMPATIBILITY, check_system=True) + assert cfg.is_supported(advice=AdviceCategory.COMPATIBILITY, check_system=False) assert not cfg.is_supported(advice=AdviceCategory.PERFORMANCE, check_system=False) diff --git a/tests/test_backend_registry.py b/tests/test_backend_registry.py index 31a20a0..703e699 100644 --- a/tests/test_backend_registry.py +++ b/tests/test_backend_registry.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the backend registry module.""" from __future__ import annotations @@ -18,7 +18,7 @@ from mlia.core.common import AdviceCategory ( ( "ArmNNTFLiteDelegate", - [AdviceCategory.OPERATORS], + [AdviceCategory.COMPATIBILITY], None, BackendType.BUILTIN, ), @@ -36,14 +36,14 @@ from mlia.core.common import AdviceCategory ), ( "TOSA-Checker", - [AdviceCategory.OPERATORS], + [AdviceCategory.COMPATIBILITY], [System.LINUX_AMD64], BackendType.WHEEL, ), ( "Vela", [ - AdviceCategory.OPERATORS, + AdviceCategory.COMPATIBILITY, AdviceCategory.PERFORMANCE, AdviceCategory.OPTIMIZATION, ], diff --git a/tests/test_cli_command_validators.py b/tests/test_cli_command_validators.py new file mode 100644 index 0000000..13514a5 --- /dev/null +++ b/tests/test_cli_command_validators.py @@ -0,0 +1,167 @@ +# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for cli.command_validators module.""" +from __future__ import annotations + +import argparse +from unittest.mock import MagicMock + +import pytest + +from mlia.cli.command_validators import validate_backend +from mlia.cli.command_validators import validate_check_target_profile + + +@pytest.mark.parametrize( + "target_profile, category, expected_warnings, sys_exits", + [ + ["ethos-u55-256", {"compatibility", "performance"}, [], False], + ["ethos-u55-256", {"compatibility"}, [], False], + ["ethos-u55-256", {"performance"}, [], False], + [ + "tosa", + {"compatibility", "performance"}, + [ + ( + "\nWARNING: Performance checks skipped as they cannot be " + "performed with target profile tosa." + ) + ], + False, + ], + [ + "tosa", + {"performance"}, + [ + ( + "\nWARNING: Performance checks skipped as they cannot be " + "performed with target profile tosa. No operation was performed." + ) + ], + True, + ], + ["tosa", "compatibility", [], False], + [ + "cortex-a", + {"performance"}, + [ + ( + "\nWARNING: Performance checks skipped as they cannot be " + "performed with target profile cortex-a. " + "No operation was performed." + ) + ], + True, + ], + [ + "cortex-a", + {"compatibility", "performance"}, + [ + ( + "\nWARNING: Performance checks skipped as they cannot be " + "performed with target profile cortex-a." + ) + ], + False, + ], + ["cortex-a", "compatibility", [], False], + ], +) +def test_validate_check_target_profile( + caplog: pytest.LogCaptureFixture, + target_profile: str, + category: set[str], + expected_warnings: list[str], + sys_exits: bool, +) -> None: + """Test outcomes of category dependent target profile validation.""" + # Capture if program terminates + if sys_exits: + with pytest.raises(SystemExit) as sys_ex: + validate_check_target_profile(target_profile, category) + assert sys_ex.value.code == 0 + return + + validate_check_target_profile(target_profile, category) + + log_records = caplog.records + # Get all log records with level 30 (warning level) + warning_messages = {x.message for x in log_records if x.levelno == 30} + # Ensure the warnings coincide with the expected ones + assert warning_messages == set(expected_warnings) + + +@pytest.mark.parametrize( + "input_target_profile, input_backends, throws_exception," + "exception_message, output_backends", + [ + [ + "tosa", + ["Vela"], + True, + "Vela backend not supported with target-profile tosa.", + None, + ], + [ + "tosa", + ["Corstone-300, Vela"], + True, + "Corstone-300, Vela backend not supported with target-profile tosa.", + None, + ], + [ + "cortex-a", + ["Corstone-310", "tosa-checker"], + True, + "Corstone-310, tosa-checker backend not supported " + "with target-profile cortex-a.", + None, + ], + [ + "ethos-u55-256", + ["tosa-checker", "Corstone-310"], + True, + "tosa-checker backend not supported with target-profile ethos-u55-256.", + None, + ], + ["tosa", None, False, None, ["tosa-checker"]], + ["cortex-a", None, False, None, ["armnn-tflitedelegate"]], + ["tosa", ["tosa-checker"], False, None, ["tosa-checker"]], + ["cortex-a", ["armnn-tflitedelegate"], False, None, ["armnn-tflitedelegate"]], + [ + "ethos-u55-256", + ["Vela", "Corstone-300"], + False, + None, + ["Vela", "Corstone-300"], + ], + [ + "ethos-u55-256", + None, + False, + None, + ["Vela", "Corstone-300"], + ], + ], +) +def test_validate_backend( + monkeypatch: pytest.MonkeyPatch, + input_target_profile: str, + input_backends: list[str] | None, + throws_exception: bool, + exception_message: str, + output_backends: list[str] | None, +) -> None: + """Test backend validation with target-profiles and backends.""" + monkeypatch.setattr( + "mlia.cli.config.get_available_backends", + MagicMock(return_value=["Vela", "Corstone-300"]), + ) + + if throws_exception: + with pytest.raises(argparse.ArgumentError) as err: + validate_backend(input_target_profile, input_backends) + assert str(err.value.message) == exception_message + return + + assert validate_backend(input_target_profile, input_backends) == output_backends diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py index aed5c42..03ee9d2 100644 --- a/tests/test_cli_commands.py +++ b/tests/test_cli_commands.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for cli.commands module.""" from __future__ import annotations @@ -14,9 +14,8 @@ from mlia.backend.manager import DefaultInstallationManager from mlia.cli.commands import backend_install from mlia.cli.commands import backend_list from mlia.cli.commands import backend_uninstall -from mlia.cli.commands import operators -from mlia.cli.commands import optimization -from mlia.cli.commands import performance +from mlia.cli.commands import check +from mlia.cli.commands import optimize from mlia.core.context import ExecutionContext from mlia.target.ethos_u.config import EthosUConfiguration from mlia.target.ethos_u.performance import MemoryUsage @@ -27,7 +26,7 @@ from mlia.target.ethos_u.performance import PerformanceMetrics def test_operators_expected_parameters(sample_context: ExecutionContext) -> None: """Test operators command wrong parameters.""" with pytest.raises(Exception, match="Model is not provided"): - operators(sample_context, "ethos-u55-256") + check(sample_context, "ethos-u55-256") def test_performance_unknown_target( @@ -35,93 +34,45 @@ def test_performance_unknown_target( ) -> None: """Test that command should fail if unknown target passed.""" with pytest.raises(Exception, match="Unable to find target profile unknown"): - performance( - sample_context, model=str(test_tflite_model), target_profile="unknown" + check( + sample_context, + model=str(test_tflite_model), + target_profile="unknown", + performance=True, ) @pytest.mark.parametrize( - "target_profile, optimization_type, optimization_target, expected_error", + "target_profile, pruning, clustering, pruning_target, clustering_target", [ - [ - "ethos-u55-256", - None, - "0.5", - pytest.raises(Exception, match="Optimization type is not provided"), - ], - [ - "ethos-u65-512", - "unknown", - "16", - pytest.raises(Exception, match="Unsupported optimization type: unknown"), - ], - [ - "ethos-u55-256", - "pruning", - None, - pytest.raises(Exception, match="Optimization target is not provided"), - ], - [ - "ethos-u65-512", - "clustering", - None, - pytest.raises(Exception, match="Optimization target is not provided"), - ], - [ - "unknown", - "clustering", - "16", - pytest.raises(Exception, match="Unable to find target profile unknown"), - ], - ], -) -def test_opt_expected_parameters( - sample_context: ExecutionContext, - target_profile: str, - monkeypatch: pytest.MonkeyPatch, - optimization_type: str, - optimization_target: str, - expected_error: Any, - test_keras_model: Path, -) -> None: - """Test that command should fail if no or unknown optimization type provided.""" - mock_performance_estimation(monkeypatch) - - with expected_error: - optimization( - ctx=sample_context, - target_profile=target_profile, - model=str(test_keras_model), - optimization_type=optimization_type, - optimization_target=optimization_target, - ) - - -@pytest.mark.parametrize( - "target_profile, optimization_type, optimization_target", - [ - ["ethos-u55-256", "pruning", "0.5"], - ["ethos-u65-512", "clustering", "32"], - ["ethos-u55-256", "pruning,clustering", "0.5,32"], + ["ethos-u55-256", True, False, 0.5, None], + ["ethos-u65-512", False, True, 0.5, 32], + ["ethos-u55-256", True, True, 0.5, None], + ["ethos-u55-256", False, False, 0.5, None], + ["ethos-u55-256", False, True, "invalid", 32], ], ) def test_opt_valid_optimization_target( target_profile: str, sample_context: ExecutionContext, - optimization_type: str, - optimization_target: str, + pruning: bool, + clustering: bool, + pruning_target: float | None, + clustering_target: int | None, monkeypatch: pytest.MonkeyPatch, test_keras_model: Path, ) -> None: """Test that command should not fail with valid optimization targets.""" mock_performance_estimation(monkeypatch) - optimization( + optimize( ctx=sample_context, target_profile=target_profile, model=str(test_keras_model), - optimization_type=optimization_type, - optimization_target=optimization_target, + pruning=pruning, + clustering=clustering, + pruning_target=pruning_target, + clustering_target=clustering_target, ) diff --git a/tests/test_cli_config.py b/tests/test_cli_config.py index 1a7cb3f..b007052 100644 --- a/tests/test_cli_config.py +++ b/tests/test_cli_config.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for cli.config module.""" from __future__ import annotations @@ -7,7 +7,7 @@ from unittest.mock import MagicMock import pytest -from mlia.cli.config import get_default_backends +from mlia.cli.config import get_ethos_u_default_backends from mlia.cli.config import is_corstone_backend @@ -29,7 +29,7 @@ from mlia.cli.config import is_corstone_backend ], ], ) -def test_get_default_backends( +def test_get_ethos_u_default_backends( monkeypatch: pytest.MonkeyPatch, available_backends: list[str], expected_default_backends: list[str], @@ -40,7 +40,7 @@ def test_get_default_backends( MagicMock(return_value=available_backends), ) - assert get_default_backends() == expected_default_backends + assert get_ethos_u_default_backends() == expected_default_backends def test_is_corstone_backend() -> None: diff --git a/tests/test_cli_helpers.py b/tests/test_cli_helpers.py index c8aeebe..8f7e4b0 100644 --- a/tests/test_cli_helpers.py +++ b/tests/test_cli_helpers.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the helper classes.""" from __future__ import annotations @@ -28,40 +28,39 @@ class TestCliActionResolver: {}, [ "Note: you will need a Keras model for that.", - "For example: mlia optimization --optimization-type " - "pruning,clustering --optimization-target 0.5,32 " - "/path/to/keras_model", - "For more info: mlia optimization --help", + "For example: mlia optimize /path/to/keras_model " + "--pruning --clustering " + "--pruning-target 0.5 --clustering-target 32", + "For more info: mlia optimize --help", ], ], [ {"model": "model.h5"}, {}, [ - "For example: mlia optimization --optimization-type " - "pruning,clustering --optimization-target 0.5,32 model.h5", - "For more info: mlia optimization --help", + "For example: mlia optimize model.h5 --pruning --clustering " + "--pruning-target 0.5 --clustering-target 32", + "For more info: mlia optimize --help", ], ], [ {"model": "model.h5"}, {"opt_settings": [OptimizationSettings("pruning", 0.5, None)]}, [ - "For more info: mlia optimization --help", + "For more info: mlia optimize --help", "Optimization command: " - "mlia optimization --optimization-type pruning " - "--optimization-target 0.5 model.h5", + "mlia optimize model.h5 --pruning " + "--pruning-target 0.5", ], ], [ {"model": "model.h5", "target_profile": "target_profile"}, {"opt_settings": [OptimizationSettings("pruning", 0.5, None)]}, [ - "For more info: mlia optimization --help", + "For more info: mlia optimize --help", "Optimization command: " - "mlia optimization --optimization-type pruning " - "--optimization-target 0.5 " - "--target-profile target_profile model.h5", + "mlia optimize model.h5 --target-profile target_profile " + "--pruning --pruning-target 0.5", ], ], ], @@ -76,20 +75,11 @@ class TestCliActionResolver: assert resolver.apply_optimizations(**params) == expected_result @staticmethod - def test_supported_operators_info() -> None: - """Test supported operators info.""" - resolver = CLIActionResolver({}) - assert resolver.supported_operators_info() == [ - "For guidance on supported operators, run: mlia operators " - "--supported-ops-report", - ] - - @staticmethod def test_operator_compatibility_details() -> None: """Test operator compatibility details info.""" resolver = CLIActionResolver({}) assert resolver.operator_compatibility_details() == [ - "For more details, run: mlia operators --help" + "For more details, run: mlia check --help" ] @staticmethod @@ -97,7 +87,7 @@ class TestCliActionResolver: """Test optimization details info.""" resolver = CLIActionResolver({}) assert resolver.optimization_details() == [ - "For more info, see: mlia optimization --help" + "For more info, see: mlia optimize --help" ] @staticmethod @@ -109,19 +99,12 @@ class TestCliActionResolver: [], ], [ - {"model": "model.tflite"}, - [ - "Check the estimated performance by running the " - "following command: ", - "mlia performance model.tflite", - ], - ], - [ {"model": "model.tflite", "target_profile": "target_profile"}, [ "Check the estimated performance by running the " "following command: ", - "mlia performance --target-profile target_profile model.tflite", + "mlia check model.tflite " + "--target-profile target_profile --performance", ], ], ], @@ -142,17 +125,10 @@ class TestCliActionResolver: [], ], [ - {"model": "model.tflite"}, - [ - "Try running the following command to verify that:", - "mlia operators model.tflite", - ], - ], - [ {"model": "model.tflite", "target_profile": "target_profile"}, [ "Try running the following command to verify that:", - "mlia operators --target-profile target_profile model.tflite", + "mlia check model.tflite --target-profile target_profile", ], ], ], diff --git a/tests/test_cli_main.py b/tests/test_cli_main.py index 925f1e4..5a9c0c9 100644 --- a/tests/test_cli_main.py +++ b/tests/test_cli_main.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for main module.""" from __future__ import annotations @@ -19,7 +19,6 @@ from mlia.backend.errors import BackendUnavailableError from mlia.cli.main import backend_main from mlia.cli.main import CommandInfo from mlia.cli.main import main -from mlia.core.context import ExecutionContext from mlia.core.errors import InternalError from tests.utils.logging import clear_loggers @@ -62,35 +61,23 @@ def test_command_info(is_default: bool, expected_command_help: str) -> None: assert command_info.command_help == expected_command_help -def test_default_command(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: +def test_default_command(monkeypatch: pytest.MonkeyPatch) -> None: """Test adding default command.""" - def mock_command( - func_mock: MagicMock, name: str, with_working_dir: bool - ) -> Callable[..., None]: + def mock_command(func_mock: MagicMock, name: str) -> Callable[..., None]: """Mock cli command.""" def sample_cmd_1(*args: Any, **kwargs: Any) -> None: """Sample command.""" func_mock(*args, **kwargs) - def sample_cmd_2(ctx: ExecutionContext, **kwargs: Any) -> None: - """Another sample command.""" - func_mock(ctx=ctx, **kwargs) - - ret_func = sample_cmd_2 if with_working_dir else sample_cmd_1 + ret_func = sample_cmd_1 ret_func.__name__ = name - return ret_func # type: ignore + return ret_func - default_command = MagicMock() non_default_command = MagicMock() - def default_command_params(parser: argparse.ArgumentParser) -> None: - """Add parameters for default command.""" - parser.add_argument("--sample") - parser.add_argument("--default_arg", default="123") - def non_default_command_params(parser: argparse.ArgumentParser) -> None: """Add parameters for non default command.""" parser.add_argument("--param") @@ -100,15 +87,7 @@ def test_default_command(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> Non MagicMock( return_value=[ CommandInfo( - func=mock_command(default_command, "default_command", True), - aliases=["command1"], - opt_groups=[default_command_params], - is_default=True, - ), - CommandInfo( - func=mock_command( - non_default_command, "non_default_command", False - ), + func=mock_command(non_default_command, "non_default_command"), aliases=["command2"], opt_groups=[non_default_command_params], is_default=False, @@ -117,11 +96,7 @@ def test_default_command(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> Non ), ) - tmp_working_dir = str(tmp_path) - main(["--working-dir", tmp_working_dir, "--sample", "1"]) main(["command2", "--param", "test"]) - - default_command.assert_called_once_with(ctx=ANY, sample="1", default_arg="123") non_default_command.assert_called_once_with(param="test") @@ -140,134 +115,168 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable: "params, expected_call", [ [ - ["operators", "sample_model.tflite"], + ["check", "sample_model.tflite", "--target-profile", "ethos-u55-256"], call( ctx=ANY, target_profile="ethos-u55-256", model="sample_model.tflite", + compatibility=False, + performance=False, output=None, - supported_ops_report=False, + json=False, + backend=None, ), ], [ - ["ops", "sample_model.tflite"], - call( - ctx=ANY, - target_profile="ethos-u55-256", - model="sample_model.tflite", - output=None, - supported_ops_report=False, - ), - ], - [ - ["operators", "sample_model.tflite", "--target-profile", "ethos-u55-128"], + ["check", "sample_model.tflite", "--target-profile", "ethos-u55-128"], call( ctx=ANY, target_profile="ethos-u55-128", model="sample_model.tflite", + compatibility=False, + performance=False, output=None, - supported_ops_report=False, + json=False, + backend=None, ), ], [ - ["operators"], - call( - ctx=ANY, - target_profile="ethos-u55-256", - model=None, - output=None, - supported_ops_report=False, - ), - ], - [ - ["operators", "--supported-ops-report"], + [ + "check", + "sample_model.h5", + "--performance", + "--compatibility", + "--target-profile", + "ethos-u55-256", + ], call( ctx=ANY, target_profile="ethos-u55-256", - model=None, + model="sample_model.h5", output=None, - supported_ops_report=True, + json=False, + compatibility=True, + performance=True, + backend=None, ), ], [ [ - "all_tests", + "check", "sample_model.h5", - "--optimization-type", - "pruning", - "--optimization-target", - "0.5", + "--performance", + "--target-profile", + "ethos-u55-256", + "--output", + "result.json", + "--json", ], call( ctx=ANY, target_profile="ethos-u55-256", model="sample_model.h5", - optimization_type="pruning", - optimization_target="0.5", - output=None, - evaluate_on=["Vela"], + performance=True, + compatibility=False, + output=Path("result.json"), + json=True, + backend=None, ), ], [ - ["sample_model.h5"], + [ + "check", + "sample_model.h5", + "--performance", + "--target-profile", + "ethos-u55-128", + ], call( ctx=ANY, - target_profile="ethos-u55-256", + target_profile="ethos-u55-128", model="sample_model.h5", - optimization_type="pruning,clustering", - optimization_target="0.5,32", + compatibility=False, + performance=True, output=None, - evaluate_on=["Vela"], + json=False, + backend=None, ), ], [ - ["performance", "sample_model.h5", "--output", "result.json"], + [ + "optimize", + "sample_model.h5", + "--target-profile", + "ethos-u55-256", + "--pruning", + "--clustering", + ], call( ctx=ANY, target_profile="ethos-u55-256", model="sample_model.h5", - output="result.json", - evaluate_on=["Vela"], - ), - ], - [ - ["perf", "sample_model.h5", "--target-profile", "ethos-u55-128"], - call( - ctx=ANY, - target_profile="ethos-u55-128", - model="sample_model.h5", + pruning=True, + clustering=True, + pruning_target=None, + clustering_target=None, output=None, - evaluate_on=["Vela"], + json=False, + backend=None, ), ], [ - ["optimization", "sample_model.h5"], + [ + "optimize", + "sample_model.h5", + "--target-profile", + "ethos-u55-256", + "--pruning", + "--clustering", + "--pruning-target", + "0.5", + "--clustering-target", + "32", + ], call( ctx=ANY, target_profile="ethos-u55-256", model="sample_model.h5", - optimization_type="pruning,clustering", - optimization_target="0.5,32", + pruning=True, + clustering=True, + pruning_target=0.5, + clustering_target=32, output=None, - evaluate_on=["Vela"], + json=False, + backend=None, ), ], [ - ["optimization", "sample_model.h5", "--evaluate-on", "some_backend"], + [ + "optimize", + "sample_model.h5", + "--target-profile", + "ethos-u55-256", + "--pruning", + "--backend", + "some_backend", + ], call( ctx=ANY, target_profile="ethos-u55-256", model="sample_model.h5", - optimization_type="pruning,clustering", - optimization_target="0.5,32", + pruning=True, + clustering=False, + pruning_target=None, + clustering_target=None, output=None, - evaluate_on=["some_backend"], + json=False, + backend=["some_backend"], ), ], [ [ - "operators", + "check", "sample_model.h5", + "--compatibility", "--target-profile", "cortex-a", ], @@ -275,8 +284,11 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable: ctx=ANY, target_profile="cortex-a", model="sample_model.h5", + compatibility=True, + performance=False, output=None, - supported_ops_report=False, + json=False, + backend=None, ), ], ], @@ -288,15 +300,11 @@ def test_commands_execution( mock = MagicMock() monkeypatch.setattr( - "mlia.cli.options.get_default_backends", MagicMock(return_value=["Vela"]) - ) - - monkeypatch.setattr( "mlia.cli.options.get_available_backends", MagicMock(return_value=["Vela", "some_backend"]), ) - for command in ["all_tests", "operators", "performance", "optimization"]: + for command in ["check", "optimize"]: monkeypatch.setattr( f"mlia.cli.main.{command}", wrap_mock_command(mock, getattr(mlia.cli.main, command)), @@ -335,15 +343,15 @@ def test_commands_execution_backend_main( @pytest.mark.parametrize( - "verbose, exc_mock, expected_output", + "debug, exc_mock, expected_output", [ [ True, MagicMock(side_effect=Exception("Error")), [ "Execution finished with error: Error", - f"Please check the log files in the {Path.cwd()/'mlia_output/logs'} " - "for more details", + "Please check the log files in the /tmp/mlia-", + "/logs for more details", ], ], [ @@ -351,8 +359,8 @@ def test_commands_execution_backend_main( MagicMock(side_effect=Exception("Error")), [ "Execution finished with error: Error", - f"Please check the log files in the {Path.cwd()/'mlia_output/logs'} " - "for more details, or enable verbose mode (--verbose)", + "Please check the log files in the /tmp/mlia-", + "/logs for more details, or enable debug mode (--debug)", ], ], [ @@ -389,18 +397,18 @@ def test_commands_execution_backend_main( ], ], ) -def test_verbose_output( +def test_debug_output( monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture, - verbose: bool, + debug: bool, exc_mock: MagicMock, expected_output: list[str], ) -> None: - """Test flag --verbose.""" + """Test flag --debug.""" def command_params(parser: argparse.ArgumentParser) -> None: """Add parameters for non default command.""" - parser.add_argument("--verbose", action="store_true") + parser.add_argument("--debug", action="store_true") def command() -> None: """Run test command.""" @@ -420,8 +428,8 @@ def test_verbose_output( ) params = ["command"] - if verbose: - params.append("--verbose") + if debug: + params.append("--debug") exit_code = main(params) assert exit_code == 1 diff --git a/tests/test_cli_options.py b/tests/test_cli_options.py index d75f7c0..a889a93 100644 --- a/tests/test_cli_options.py +++ b/tests/test_cli_options.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for module options.""" from __future__ import annotations @@ -13,14 +13,19 @@ import pytest from mlia.cli.options import add_output_options from mlia.cli.options import get_target_profile_opts from mlia.cli.options import parse_optimization_parameters +from mlia.cli.options import parse_output_parameters +from mlia.core.common import FormattedFilePath @pytest.mark.parametrize( - "optimization_type, optimization_target, expected_error, expected_result", + "pruning, clustering, pruning_target, clustering_target, expected_error," + "expected_result", [ - ( - "pruning", - "0.5", + [ + False, + False, + None, + None, does_not_raise(), [ dict( @@ -29,39 +34,40 @@ from mlia.cli.options import parse_optimization_parameters layers_to_optimize=None, ) ], - ), - ( - "clustering", - "32", + ], + [ + True, + False, + None, + None, does_not_raise(), [ dict( - optimization_type="clustering", - optimization_target=32.0, + optimization_type="pruning", + optimization_target=0.5, layers_to_optimize=None, ) ], - ), - ( - "pruning,clustering", - "0.5,32", + ], + [ + False, + True, + None, + None, does_not_raise(), [ dict( - optimization_type="pruning", - optimization_target=0.5, - layers_to_optimize=None, - ), - dict( optimization_type="clustering", - optimization_target=32.0, + optimization_target=32, layers_to_optimize=None, - ), + ) ], - ), - ( - "pruning, clustering", - "0.5, 32", + ], + [ + True, + True, + None, + None, does_not_raise(), [ dict( @@ -71,50 +77,66 @@ from mlia.cli.options import parse_optimization_parameters ), dict( optimization_type="clustering", - optimization_target=32.0, + optimization_target=32, layers_to_optimize=None, ), ], - ), - ( - "pruning,clustering", - "0.5", - pytest.raises( - Exception, match="Wrong number of optimization targets and types" - ), - None, - ), - ( - "", - "0.5", - pytest.raises(Exception, match="Optimization type is not provided"), + ], + [ + False, + False, + 0.4, None, - ), - ( - "pruning,clustering", - "", - pytest.raises(Exception, match="Optimization target is not provided"), + does_not_raise(), + [ + dict( + optimization_type="pruning", + optimization_target=0.4, + layers_to_optimize=None, + ) + ], + ], + [ + False, + False, None, - ), - ( - "pruning,", - "0.5,abc", + 32, pytest.raises( - Exception, match="Non numeric value for the optimization target" + argparse.ArgumentError, + match="To enable clustering optimization you need to include " + "the `--clustering` flag in your command.", ), None, - ), + ], + [ + False, + True, + None, + 32.2, + does_not_raise(), + [ + dict( + optimization_type="clustering", + optimization_target=32.2, + layers_to_optimize=None, + ) + ], + ], ], ) def test_parse_optimization_parameters( - optimization_type: str, - optimization_target: str, + pruning: bool, + clustering: bool, + pruning_target: float | None, + clustering_target: int | None, expected_error: Any, expected_result: Any, ) -> None: """Test function parse_optimization_parameters.""" with expected_error: - result = parse_optimization_parameters(optimization_type, optimization_target) + result = parse_optimization_parameters( + pruning, clustering, pruning_target, clustering_target + ) assert result == expected_result @@ -155,28 +177,41 @@ def test_output_options(output_parameters: list[str], expected_path: str) -> Non add_output_options(parser) args = parser.parse_args(output_parameters) - assert args.output == expected_path + assert str(args.output) == expected_path @pytest.mark.parametrize( - "output_filename", + "path, json, expected_error, output", [ - "report.txt", - "report.TXT", - "report", - "report.pdf", + [ + None, + True, + pytest.raises( + argparse.ArgumentError, + match=r"To enable JSON output you need to specify the output path. " + r"\(e.g. --output out.json --json\)", + ), + None, + ], + [None, False, does_not_raise(), None], + [ + Path("test_path"), + False, + does_not_raise(), + FormattedFilePath(Path("test_path"), "plain_text"), + ], + [ + Path("test_path"), + True, + does_not_raise(), + FormattedFilePath(Path("test_path"), "json"), + ], ], ) -def test_output_options_bad_parameters( - output_filename: str, capsys: pytest.CaptureFixture +def test_parse_output_parameters( + path: Path | None, json: bool, expected_error: Any, output: FormattedFilePath | None ) -> None: - """Test that args parsing should fail if format is not supported.""" - parser = argparse.ArgumentParser() - add_output_options(parser) - - with pytest.raises(SystemExit): - parser.parse_args(["--output", output_filename]) - - err_output = capsys.readouterr().err - suffix = Path(output_filename).suffix[1:] - assert f"Unsupported format '{suffix}'" in err_output + """Test parsing for output parameters.""" + with expected_error: + formatted_output = parse_output_parameters(path, json) + assert formatted_output == output diff --git a/tests/test_core_advice_generation.py b/tests/test_core_advice_generation.py index 3d985eb..2e0038f 100644 --- a/tests/test_core_advice_generation.py +++ b/tests/test_core_advice_generation.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for module advice_generation.""" from __future__ import annotations @@ -35,17 +35,17 @@ def test_advice_generation() -> None: "category, expected_advice", [ [ - AdviceCategory.OPERATORS, + {AdviceCategory.COMPATIBILITY}, [Advice(["Good advice!"])], ], [ - AdviceCategory.PERFORMANCE, + {AdviceCategory.PERFORMANCE}, [], ], ], ) def test_advice_category_decorator( - category: AdviceCategory, + category: set[AdviceCategory], expected_advice: list[Advice], sample_context: Context, ) -> None: @@ -54,7 +54,7 @@ def test_advice_category_decorator( class SampleAdviceProducer(FactBasedAdviceProducer): """Sample advice producer.""" - @advice_category(AdviceCategory.OPERATORS) + @advice_category(AdviceCategory.COMPATIBILITY) def produce_advice(self, data_item: DataItem) -> None: """Produce the advice.""" self.add_advice(["Good advice!"]) diff --git a/tests/test_core_context.py b/tests/test_core_context.py index 44eb976..dcdbef3 100644 --- a/tests/test_core_context.py +++ b/tests/test_core_context.py @@ -1,17 +1,53 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the module context.""" +from __future__ import annotations + from pathlib import Path +import pytest + from mlia.core.common import AdviceCategory from mlia.core.context import ExecutionContext from mlia.core.events import DefaultEventPublisher +@pytest.mark.parametrize( + "context_advice_category, expected_enabled_categories", + [ + [ + { + AdviceCategory.COMPATIBILITY, + }, + [AdviceCategory.COMPATIBILITY], + ], + [ + { + AdviceCategory.PERFORMANCE, + }, + [AdviceCategory.PERFORMANCE], + ], + [ + {AdviceCategory.COMPATIBILITY, AdviceCategory.PERFORMANCE}, + [AdviceCategory.PERFORMANCE, AdviceCategory.COMPATIBILITY], + ], + ], +) +def test_execution_context_category_enabled( + context_advice_category: set[AdviceCategory], + expected_enabled_categories: list[AdviceCategory], +) -> None: + """Test category enabled method of execution context.""" + for category in expected_enabled_categories: + assert ExecutionContext( + advice_category=context_advice_category + ).category_enabled(category) + + def test_execution_context(tmpdir: str) -> None: """Test execution context.""" publisher = DefaultEventPublisher() - category = AdviceCategory.OPERATORS + category = {AdviceCategory.COMPATIBILITY} context = ExecutionContext( advice_category=category, @@ -35,13 +71,13 @@ def test_execution_context(tmpdir: str) -> None: assert str(context) == ( f"ExecutionContext: " f"working_dir={tmpdir}, " - "advice_category=OPERATORS, " + "advice_category={'COMPATIBILITY'}, " "config_parameters={'param': 'value'}, " "verbose=True" ) context_with_default_params = ExecutionContext(working_dir=tmpdir) - assert context_with_default_params.advice_category is AdviceCategory.ALL + assert context_with_default_params.advice_category == {AdviceCategory.COMPATIBILITY} assert context_with_default_params.config_parameters is None assert context_with_default_params.event_handlers is None assert isinstance( @@ -55,7 +91,7 @@ def test_execution_context(tmpdir: str) -> None: expected_str = ( f"ExecutionContext: working_dir={tmpdir}, " - "advice_category=ALL, " + "advice_category={'COMPATIBILITY'}, " "config_parameters=None, " "verbose=False" ) diff --git a/tests/test_core_helpers.py b/tests/test_core_helpers.py index 8577617..03ec3f0 100644 --- a/tests/test_core_helpers.py +++ b/tests/test_core_helpers.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the helper classes.""" from mlia.core.helpers import APIActionResolver @@ -10,7 +10,6 @@ def test_api_action_resolver() -> None: # pylint: disable=use-implicit-booleaness-not-comparison assert helper.apply_optimizations() == [] - assert helper.supported_operators_info() == [] assert helper.check_performance() == [] assert helper.check_operator_compatibility() == [] assert helper.operator_compatibility_details() == [] diff --git a/tests/test_core_mixins.py b/tests/test_core_mixins.py index 3834fb3..47ed815 100644 --- a/tests/test_core_mixins.py +++ b/tests/test_core_mixins.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the module mixins.""" import pytest @@ -36,7 +36,7 @@ class TestParameterResolverMixin: self.context = sample_context self.context.update( - advice_category=AdviceCategory.OPERATORS, + advice_category={AdviceCategory.COMPATIBILITY}, event_handlers=[], config_parameters={"section": {"param": 123}}, ) @@ -83,7 +83,7 @@ class TestParameterResolverMixin: """Init sample object.""" self.context = sample_context self.context.update( - advice_category=AdviceCategory.OPERATORS, + advice_category={AdviceCategory.COMPATIBILITY}, event_handlers=[], config_parameters={"section": ["param"]}, ) diff --git a/tests/test_core_reporting.py b/tests/test_core_reporting.py index feff5cc..7b26173 100644 --- a/tests/test_core_reporting.py +++ b/tests/test_core_reporting.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for reporting module.""" from __future__ import annotations @@ -13,11 +13,8 @@ from mlia.core.reporting import CyclesCell from mlia.core.reporting import Format from mlia.core.reporting import NestedReport from mlia.core.reporting import ReportItem -from mlia.core.reporting import resolve_output_format from mlia.core.reporting import SingleRow from mlia.core.reporting import Table -from mlia.core.typing import OutputFormat -from mlia.core.typing import PathOrFileLike from mlia.utils.console import remove_ascii_codes @@ -338,20 +335,3 @@ Single row example: alias="simple_row_example", ) wrong_single_row.to_plain_text() - - -@pytest.mark.parametrize( - "output, expected_output_format", - [ - [None, "plain_text"], - ["", "plain_text"], - ["some_file", "plain_text"], - ["some_format.some_ext", "plain_text"], - ["output.json", "json"], - ], -) -def test_resolve_output_format( - output: PathOrFileLike | None, expected_output_format: OutputFormat -) -> None: - """Test function resolve_output_format.""" - assert resolve_output_format(output) == expected_output_format diff --git a/tests/test_target_config.py b/tests/test_target_config.py index 66ebed6..48f0a58 100644 --- a/tests/test_target_config.py +++ b/tests/test_target_config.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the backend config module.""" from __future__ import annotations @@ -25,7 +25,7 @@ def test_ip_config() -> None: ( (None, False, True), (None, True, True), - (AdviceCategory.OPERATORS, True, True), + (AdviceCategory.COMPATIBILITY, True, True), (AdviceCategory.OPTIMIZATION, True, False), ), ) @@ -42,7 +42,7 @@ def test_target_info( backend_registry.register( "backend", BackendConfiguration( - [AdviceCategory.OPERATORS], + [AdviceCategory.COMPATIBILITY], [System.CURRENT], BackendType.BUILTIN, ), diff --git a/tests/test_target_cortex_a_advice_generation.py b/tests/test_target_cortex_a_advice_generation.py index 6effe4c..1997c52 100644 --- a/tests/test_target_cortex_a_advice_generation.py +++ b/tests/test_target_cortex_a_advice_generation.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for advice generation.""" from __future__ import annotations @@ -31,7 +31,7 @@ BACKEND_INFO = ( [ [ ModelIsNotCortexACompatible(BACKEND_INFO, {"UNSUPPORTED_OP"}, {}), - AdviceCategory.OPERATORS, + {AdviceCategory.COMPATIBILITY}, [ Advice( [ @@ -61,7 +61,7 @@ BACKEND_INFO = ( ) }, ), - AdviceCategory.OPERATORS, + {AdviceCategory.COMPATIBILITY}, [ Advice( [ @@ -93,7 +93,7 @@ BACKEND_INFO = ( ], [ ModelIsCortexACompatible(BACKEND_INFO), - AdviceCategory.OPERATORS, + {AdviceCategory.COMPATIBILITY}, [ Advice( [ @@ -108,7 +108,7 @@ BACKEND_INFO = ( flex_ops=["flex_op1", "flex_op2"], custom_ops=["custom_op1", "custom_op2"], ), - AdviceCategory.OPERATORS, + {AdviceCategory.COMPATIBILITY}, [ Advice( [ @@ -142,7 +142,7 @@ BACKEND_INFO = ( ], [ ModelIsNotTFLiteCompatible(), - AdviceCategory.OPERATORS, + {AdviceCategory.COMPATIBILITY}, [ Advice( [ @@ -154,7 +154,7 @@ BACKEND_INFO = ( ], [ ModelHasCustomOperators(), - AdviceCategory.OPERATORS, + {AdviceCategory.COMPATIBILITY}, [ Advice( [ @@ -166,7 +166,7 @@ BACKEND_INFO = ( ], [ TFLiteCompatibilityCheckFailed(), - AdviceCategory.OPERATORS, + {AdviceCategory.COMPATIBILITY}, [ Advice( [ @@ -181,7 +181,7 @@ BACKEND_INFO = ( def test_cortex_a_advice_producer( tmpdir: str, input_data: DataItem, - advice_category: AdviceCategory, + advice_category: set[AdviceCategory], expected_advice: list[Advice], ) -> None: """Test Cortex-A advice producer.""" diff --git a/tests/test_target_ethos_u_advice_generation.py b/tests/test_target_ethos_u_advice_generation.py index 1569592..e93eeba 100644 --- a/tests/test_target_ethos_u_advice_generation.py +++ b/tests/test_target_ethos_u_advice_generation.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for Ethos-U advice generation.""" from __future__ import annotations @@ -28,7 +28,7 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff [ [ AllOperatorsSupportedOnNPU(), - AdviceCategory.OPERATORS, + {AdviceCategory.COMPATIBILITY}, APIActionResolver(), [ Advice( @@ -41,7 +41,7 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff ], [ AllOperatorsSupportedOnNPU(), - AdviceCategory.OPERATORS, + {AdviceCategory.COMPATIBILITY}, CLIActionResolver( { "target_profile": "sample_target", @@ -55,15 +55,15 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff "run completely on NPU.", "Check the estimated performance by running the " "following command: ", - "mlia performance --target-profile sample_target " - "sample_model.tflite", + "mlia check sample_model.tflite --target-profile sample_target " + "--performance", ] ) ], ], [ HasCPUOnlyOperators(cpu_only_ops=["OP1", "OP2", "OP3"]), - AdviceCategory.OPERATORS, + {AdviceCategory.COMPATIBILITY}, APIActionResolver(), [ Advice( @@ -78,7 +78,7 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff ], [ HasCPUOnlyOperators(cpu_only_ops=["OP1", "OP2", "OP3"]), - AdviceCategory.OPERATORS, + {AdviceCategory.COMPATIBILITY}, CLIActionResolver({}), [ Advice( @@ -87,15 +87,13 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff "OP1,OP2,OP3.", "Using operators that are supported by the NPU will " "improve performance.", - "For guidance on supported operators, run: mlia operators " - "--supported-ops-report", ] ) ], ], [ HasUnsupportedOnNPUOperators(npu_unsupported_ratio=0.4), - AdviceCategory.OPERATORS, + {AdviceCategory.COMPATIBILITY}, APIActionResolver(), [ Advice( @@ -110,7 +108,7 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff ], [ HasUnsupportedOnNPUOperators(npu_unsupported_ratio=0.4), - AdviceCategory.OPERATORS, + {AdviceCategory.COMPATIBILITY}, CLIActionResolver({}), [ Advice( @@ -138,7 +136,7 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff ), ] ), - AdviceCategory.OPTIMIZATION, + {AdviceCategory.OPTIMIZATION}, APIActionResolver(), [ Advice( @@ -178,7 +176,7 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff ), ] ), - AdviceCategory.OPTIMIZATION, + {AdviceCategory.OPTIMIZATION}, CLIActionResolver({"model": "sample_model.h5"}), [ Advice( @@ -192,10 +190,10 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff "You can try to push the optimization target higher " "(e.g. pruning: 0.6) " "to check if those results can be further improved.", - "For more info: mlia optimization --help", + "For more info: mlia optimize --help", "Optimization command: " - "mlia optimization --optimization-type pruning " - "--optimization-target 0.6 sample_model.h5", + "mlia optimize sample_model.h5 --pruning " + "--pruning-target 0.6", ] ), Advice( @@ -225,7 +223,7 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff ), ] ), - AdviceCategory.OPTIMIZATION, + {AdviceCategory.OPTIMIZATION}, APIActionResolver(), [ Advice( @@ -267,7 +265,7 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff ), ] ), - AdviceCategory.OPTIMIZATION, + {AdviceCategory.OPTIMIZATION}, APIActionResolver(), [ Advice( @@ -304,7 +302,7 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff ), ] ), - AdviceCategory.OPTIMIZATION, + {AdviceCategory.OPTIMIZATION}, APIActionResolver(), [ Advice( @@ -354,7 +352,7 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff ), ] ), - AdviceCategory.OPTIMIZATION, + {AdviceCategory.OPTIMIZATION}, APIActionResolver(), [], # no advice for more than one optimization result ], @@ -364,7 +362,7 @@ def test_ethosu_advice_producer( tmpdir: str, input_data: DataItem, expected_advice: list[Advice], - advice_category: AdviceCategory, + advice_category: set[AdviceCategory] | None, action_resolver: ActionResolver, ) -> None: """Test Ethos-U Advice producer.""" @@ -386,17 +384,17 @@ def test_ethosu_advice_producer( "advice_category, action_resolver, expected_advice", [ [ - AdviceCategory.ALL, + {AdviceCategory.COMPATIBILITY, AdviceCategory.PERFORMANCE}, None, [], ], [ - AdviceCategory.OPERATORS, + {AdviceCategory.COMPATIBILITY}, None, [], ], [ - AdviceCategory.PERFORMANCE, + {AdviceCategory.PERFORMANCE}, APIActionResolver(), [ Advice( @@ -414,31 +412,33 @@ def test_ethosu_advice_producer( ], ], [ - AdviceCategory.PERFORMANCE, - CLIActionResolver({"model": "test_model.h5"}), + {AdviceCategory.PERFORMANCE}, + CLIActionResolver( + {"model": "test_model.h5", "target_profile": "sample_target"} + ), [ Advice( [ "You can improve the inference time by using only operators " "that are supported by the NPU.", "Try running the following command to verify that:", - "mlia operators test_model.h5", + "mlia check test_model.h5 --target-profile sample_target", ] ), Advice( [ "Check if you can improve the performance by applying " "tooling techniques to your model.", - "For example: mlia optimization --optimization-type " - "pruning,clustering --optimization-target 0.5,32 " - "test_model.h5", - "For more info: mlia optimization --help", + "For example: mlia optimize test_model.h5 " + "--pruning --clustering " + "--pruning-target 0.5 --clustering-target 32", + "For more info: mlia optimize --help", ] ), ], ], [ - AdviceCategory.OPTIMIZATION, + {AdviceCategory.OPTIMIZATION}, APIActionResolver(), [ Advice( @@ -450,14 +450,14 @@ def test_ethosu_advice_producer( ], ], [ - AdviceCategory.OPTIMIZATION, + {AdviceCategory.OPTIMIZATION}, CLIActionResolver({"model": "test_model.h5"}), [ Advice( [ "For better performance, make sure that all the operators " "of your final TensorFlow Lite model are supported by the NPU.", - "For more details, run: mlia operators --help", + "For more details, run: mlia check --help", ] ) ], @@ -466,7 +466,7 @@ def test_ethosu_advice_producer( ) def test_ethosu_static_advice_producer( tmpdir: str, - advice_category: AdviceCategory, + advice_category: set[AdviceCategory] | None, action_resolver: ActionResolver, expected_advice: list[Advice], ) -> None: diff --git a/tests/test_target_registry.py b/tests/test_target_registry.py index e6ee296..e6028a9 100644 --- a/tests/test_target_registry.py +++ b/tests/test_target_registry.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the target registry module.""" from __future__ import annotations @@ -26,11 +26,11 @@ def test_target_registry(expected_target: str) -> None: @pytest.mark.parametrize( ("target_name", "expected_advices"), ( - ("Cortex-A", [AdviceCategory.OPERATORS]), + ("Cortex-A", [AdviceCategory.COMPATIBILITY]), ( "Ethos-U55", [ - AdviceCategory.OPERATORS, + AdviceCategory.COMPATIBILITY, AdviceCategory.OPTIMIZATION, AdviceCategory.PERFORMANCE, ], @@ -38,12 +38,12 @@ def test_target_registry(expected_target: str) -> None: ( "Ethos-U65", [ - AdviceCategory.OPERATORS, + AdviceCategory.COMPATIBILITY, AdviceCategory.OPTIMIZATION, AdviceCategory.PERFORMANCE, ], ), - ("TOSA", [AdviceCategory.OPERATORS]), + ("TOSA", [AdviceCategory.COMPATIBILITY]), ), ) def test_supported_advice( @@ -72,7 +72,7 @@ def test_supported_backends(target_name: str, expected_backends: list[str]) -> N @pytest.mark.parametrize( ("advice", "expected_targets"), ( - (AdviceCategory.OPERATORS, ["Cortex-A", "Ethos-U55", "Ethos-U65", "TOSA"]), + (AdviceCategory.COMPATIBILITY, ["Cortex-A", "Ethos-U55", "Ethos-U65", "TOSA"]), (AdviceCategory.OPTIMIZATION, ["Ethos-U55", "Ethos-U65"]), (AdviceCategory.PERFORMANCE, ["Ethos-U55", "Ethos-U65"]), ), diff --git a/tests/test_target_tosa_advice_generation.py b/tests/test_target_tosa_advice_generation.py index e8e06f8..d5ebbd7 100644 --- a/tests/test_target_tosa_advice_generation.py +++ b/tests/test_target_tosa_advice_generation.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for advice generation.""" from __future__ import annotations @@ -19,7 +19,7 @@ from mlia.target.tosa.data_analysis import ModelIsTOSACompatible [ [ ModelIsNotTOSACompatible(), - AdviceCategory.OPERATORS, + {AdviceCategory.COMPATIBILITY}, [ Advice( [ @@ -31,7 +31,7 @@ from mlia.target.tosa.data_analysis import ModelIsTOSACompatible ], [ ModelIsTOSACompatible(), - AdviceCategory.OPERATORS, + {AdviceCategory.COMPATIBILITY}, [Advice(["Model is fully TOSA compatible."])], ], ], @@ -39,7 +39,7 @@ from mlia.target.tosa.data_analysis import ModelIsTOSACompatible def test_tosa_advice_producer( tmpdir: str, input_data: DataItem, - advice_category: AdviceCategory, + advice_category: set[AdviceCategory], expected_advice: list[Advice], ) -> None: """Test TOSA advice producer.""" |