diff options
Diffstat (limited to 'tests/test_cli_commands.py')
-rw-r--r-- | tests/test_cli_commands.py | 97 |
1 files changed, 24 insertions, 73 deletions
diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py index aed5c42..03ee9d2 100644 --- a/tests/test_cli_commands.py +++ b/tests/test_cli_commands.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for cli.commands module.""" from __future__ import annotations @@ -14,9 +14,8 @@ from mlia.backend.manager import DefaultInstallationManager from mlia.cli.commands import backend_install from mlia.cli.commands import backend_list from mlia.cli.commands import backend_uninstall -from mlia.cli.commands import operators -from mlia.cli.commands import optimization -from mlia.cli.commands import performance +from mlia.cli.commands import check +from mlia.cli.commands import optimize from mlia.core.context import ExecutionContext from mlia.target.ethos_u.config import EthosUConfiguration from mlia.target.ethos_u.performance import MemoryUsage @@ -27,7 +26,7 @@ from mlia.target.ethos_u.performance import PerformanceMetrics def test_operators_expected_parameters(sample_context: ExecutionContext) -> None: """Test operators command wrong parameters.""" with pytest.raises(Exception, match="Model is not provided"): - operators(sample_context, "ethos-u55-256") + check(sample_context, "ethos-u55-256") def test_performance_unknown_target( @@ -35,93 +34,45 @@ def test_performance_unknown_target( ) -> None: """Test that command should fail if unknown target passed.""" with pytest.raises(Exception, match="Unable to find target profile unknown"): - performance( - sample_context, model=str(test_tflite_model), target_profile="unknown" + check( + sample_context, + model=str(test_tflite_model), + target_profile="unknown", + performance=True, ) @pytest.mark.parametrize( - "target_profile, optimization_type, optimization_target, expected_error", + "target_profile, pruning, clustering, pruning_target, clustering_target", [ - [ - "ethos-u55-256", - None, - "0.5", - pytest.raises(Exception, match="Optimization type is not provided"), - ], - [ - "ethos-u65-512", - "unknown", - "16", - pytest.raises(Exception, match="Unsupported optimization type: unknown"), - ], - [ - "ethos-u55-256", - "pruning", - None, - pytest.raises(Exception, match="Optimization target is not provided"), - ], - [ - "ethos-u65-512", - "clustering", - None, - pytest.raises(Exception, match="Optimization target is not provided"), - ], - [ - "unknown", - "clustering", - "16", - pytest.raises(Exception, match="Unable to find target profile unknown"), - ], - ], -) -def test_opt_expected_parameters( - sample_context: ExecutionContext, - target_profile: str, - monkeypatch: pytest.MonkeyPatch, - optimization_type: str, - optimization_target: str, - expected_error: Any, - test_keras_model: Path, -) -> None: - """Test that command should fail if no or unknown optimization type provided.""" - mock_performance_estimation(monkeypatch) - - with expected_error: - optimization( - ctx=sample_context, - target_profile=target_profile, - model=str(test_keras_model), - optimization_type=optimization_type, - optimization_target=optimization_target, - ) - - -@pytest.mark.parametrize( - "target_profile, optimization_type, optimization_target", - [ - ["ethos-u55-256", "pruning", "0.5"], - ["ethos-u65-512", "clustering", "32"], - ["ethos-u55-256", "pruning,clustering", "0.5,32"], + ["ethos-u55-256", True, False, 0.5, None], + ["ethos-u65-512", False, True, 0.5, 32], + ["ethos-u55-256", True, True, 0.5, None], + ["ethos-u55-256", False, False, 0.5, None], + ["ethos-u55-256", False, True, "invalid", 32], ], ) def test_opt_valid_optimization_target( target_profile: str, sample_context: ExecutionContext, - optimization_type: str, - optimization_target: str, + pruning: bool, + clustering: bool, + pruning_target: float | None, + clustering_target: int | None, monkeypatch: pytest.MonkeyPatch, test_keras_model: Path, ) -> None: """Test that command should not fail with valid optimization targets.""" mock_performance_estimation(monkeypatch) - optimization( + optimize( ctx=sample_context, target_profile=target_profile, model=str(test_keras_model), - optimization_type=optimization_type, - optimization_target=optimization_target, + pruning=pruning, + clustering=clustering, + pruning_target=pruning_target, + clustering_target=clustering_target, ) |