aboutsummaryrefslogtreecommitdiff
path: root/tests/test_cli_commands.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_cli_commands.py')
-rw-r--r--tests/test_cli_commands.py97
1 files changed, 24 insertions, 73 deletions
diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py
index aed5c42..03ee9d2 100644
--- a/tests/test_cli_commands.py
+++ b/tests/test_cli_commands.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for cli.commands module."""
from __future__ import annotations
@@ -14,9 +14,8 @@ from mlia.backend.manager import DefaultInstallationManager
from mlia.cli.commands import backend_install
from mlia.cli.commands import backend_list
from mlia.cli.commands import backend_uninstall
-from mlia.cli.commands import operators
-from mlia.cli.commands import optimization
-from mlia.cli.commands import performance
+from mlia.cli.commands import check
+from mlia.cli.commands import optimize
from mlia.core.context import ExecutionContext
from mlia.target.ethos_u.config import EthosUConfiguration
from mlia.target.ethos_u.performance import MemoryUsage
@@ -27,7 +26,7 @@ from mlia.target.ethos_u.performance import PerformanceMetrics
def test_operators_expected_parameters(sample_context: ExecutionContext) -> None:
"""Test operators command wrong parameters."""
with pytest.raises(Exception, match="Model is not provided"):
- operators(sample_context, "ethos-u55-256")
+ check(sample_context, "ethos-u55-256")
def test_performance_unknown_target(
@@ -35,93 +34,45 @@ def test_performance_unknown_target(
) -> None:
"""Test that command should fail if unknown target passed."""
with pytest.raises(Exception, match="Unable to find target profile unknown"):
- performance(
- sample_context, model=str(test_tflite_model), target_profile="unknown"
+ check(
+ sample_context,
+ model=str(test_tflite_model),
+ target_profile="unknown",
+ performance=True,
)
@pytest.mark.parametrize(
- "target_profile, optimization_type, optimization_target, expected_error",
+ "target_profile, pruning, clustering, pruning_target, clustering_target",
[
- [
- "ethos-u55-256",
- None,
- "0.5",
- pytest.raises(Exception, match="Optimization type is not provided"),
- ],
- [
- "ethos-u65-512",
- "unknown",
- "16",
- pytest.raises(Exception, match="Unsupported optimization type: unknown"),
- ],
- [
- "ethos-u55-256",
- "pruning",
- None,
- pytest.raises(Exception, match="Optimization target is not provided"),
- ],
- [
- "ethos-u65-512",
- "clustering",
- None,
- pytest.raises(Exception, match="Optimization target is not provided"),
- ],
- [
- "unknown",
- "clustering",
- "16",
- pytest.raises(Exception, match="Unable to find target profile unknown"),
- ],
- ],
-)
-def test_opt_expected_parameters(
- sample_context: ExecutionContext,
- target_profile: str,
- monkeypatch: pytest.MonkeyPatch,
- optimization_type: str,
- optimization_target: str,
- expected_error: Any,
- test_keras_model: Path,
-) -> None:
- """Test that command should fail if no or unknown optimization type provided."""
- mock_performance_estimation(monkeypatch)
-
- with expected_error:
- optimization(
- ctx=sample_context,
- target_profile=target_profile,
- model=str(test_keras_model),
- optimization_type=optimization_type,
- optimization_target=optimization_target,
- )
-
-
-@pytest.mark.parametrize(
- "target_profile, optimization_type, optimization_target",
- [
- ["ethos-u55-256", "pruning", "0.5"],
- ["ethos-u65-512", "clustering", "32"],
- ["ethos-u55-256", "pruning,clustering", "0.5,32"],
+ ["ethos-u55-256", True, False, 0.5, None],
+ ["ethos-u65-512", False, True, 0.5, 32],
+ ["ethos-u55-256", True, True, 0.5, None],
+ ["ethos-u55-256", False, False, 0.5, None],
+ ["ethos-u55-256", False, True, "invalid", 32],
],
)
def test_opt_valid_optimization_target(
target_profile: str,
sample_context: ExecutionContext,
- optimization_type: str,
- optimization_target: str,
+ pruning: bool,
+ clustering: bool,
+ pruning_target: float | None,
+ clustering_target: int | None,
monkeypatch: pytest.MonkeyPatch,
test_keras_model: Path,
) -> None:
"""Test that command should not fail with valid optimization targets."""
mock_performance_estimation(monkeypatch)
- optimization(
+ optimize(
ctx=sample_context,
target_profile=target_profile,
model=str(test_keras_model),
- optimization_type=optimization_type,
- optimization_target=optimization_target,
+ pruning=pruning,
+ clustering=clustering,
+ pruning_target=pruning_target,
+ clustering_target=clustering_target,
)