aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorNathan Bailey <nathan.bailey@arm.com>2024-02-15 14:50:58 +0000
committerNathan Bailey <nathan.bailey@arm.com>2024-03-14 15:45:40 +0000
commit0b552d2ae47da4fb9c16d2a59d6ebe12c8307771 (patch)
tree09b40b939acbe0bcf02dcc77a7ed7ce4aba94322 /tests
parent09b272be6e88d84a30cb89fb71f3fc3c64d20d2e (diff)
downloadmlia-0b552d2ae47da4fb9c16d2a59d6ebe12c8307771.tar.gz
feat: Enable rewrite parameterisation
Enables user to provide a toml or default profile to change training settings for rewrite optimization Resolves: MLIA-1004 Signed-off-by: Nathan Bailey <nathan.bailey@arm.com> Change-Id: I3bf9f44b9a2062fb71ef36eb32c9a69edcc48061
Diffstat (limited to 'tests')
-rw-r--r--tests/conftest.py18
-rw-r--r--tests/test_cli_commands.py14
-rw-r--r--tests/test_cli_helpers.py36
-rw-r--r--tests/test_cli_main.py34
-rw-r--r--tests/test_common_optimization.py106
-rw-r--r--tests/test_nn_rewrite_core_rewrite.py32
-rw-r--r--tests/test_nn_select.py69
-rw-r--r--tests/test_target_config.py16
-rw-r--r--tests/test_target_cortex_a_advisor.py5
-rw-r--r--tests/test_target_registry.py29
-rw-r--r--tests/test_target_tosa_advisor.py5
-rw-r--r--tests/test_utils_filesystem.py8
12 files changed, 337 insertions, 35 deletions
diff --git a/tests/conftest.py b/tests/conftest.py
index 1092979..53bfb0c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -7,6 +7,7 @@ from typing import Callable
from typing import Generator
from unittest.mock import MagicMock
+import _pytest
import numpy as np
import pytest
import tensorflow as tf
@@ -256,11 +257,16 @@ def fixture_test_tfrecord_fp32(
@pytest.fixture(scope="session", autouse=True)
-def set_training_steps() -> Generator[None, None, None]:
+def set_training_steps(
+ request: _pytest.fixtures.SubRequest,
+) -> Generator[None, None, None]:
"""Speed up tests by using MockTrainingParameters."""
- with pytest.MonkeyPatch.context() as monkeypatch:
- monkeypatch.setattr(
- "mlia.nn.select._get_rewrite_train_params",
- MagicMock(return_value=MockTrainingParameters()),
- )
+ if "set_training_steps" == request.fixturename:
yield
+ else:
+ with pytest.MonkeyPatch.context() as monkeypatch:
+ monkeypatch.setattr(
+ "mlia.nn.select._get_rewrite_params",
+ MagicMock(return_value=[MockTrainingParameters(), None, None]),
+ )
+ yield
diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py
index 1ce793f..1a9bbb8 100644
--- a/tests/test_cli_commands.py
+++ b/tests/test_cli_commands.py
@@ -52,13 +52,15 @@ def test_performance_unknown_target(
@pytest.mark.parametrize(
- "target_profile, pruning, clustering, pruning_target, clustering_target, "
- "rewrite, rewrite_target, rewrite_start, rewrite_end, expected_error",
+ "target_profile, pruning, clustering, optimization_profile, pruning_target, "
+    "clustering_target, rewrite, rewrite_target, rewrite_start, rewrite_end, "
+ "expected_error",
[
[
"ethos-u55-256",
True,
False,
+ None,
0.5,
None,
False,
@@ -73,6 +75,7 @@ def test_performance_unknown_target(
False,
None,
None,
+ None,
True,
"fully_connected",
"sequential/flatten/Reshape",
@@ -83,6 +86,7 @@ def test_performance_unknown_target(
"ethos-u55-256",
True,
False,
+ None,
0.5,
None,
True,
@@ -98,6 +102,7 @@ def test_performance_unknown_target(
"ethos-u65-512",
False,
True,
+ None,
0.5,
32,
False,
@@ -110,6 +115,7 @@ def test_performance_unknown_target(
"ethos-u55-256",
False,
False,
+ None,
0.5,
None,
True,
@@ -128,6 +134,7 @@ def test_performance_unknown_target(
"ethos-u55-256",
False,
False,
+ None,
0.5,
None,
True,
@@ -146,6 +153,7 @@ def test_performance_unknown_target(
"ethos-u55-256",
False,
False,
+ None,
"invalid",
None,
True,
@@ -169,6 +177,7 @@ def test_opt_valid_optimization_target( # pylint: disable=too-many-locals,too-m
clustering: bool,
pruning_target: float | None,
clustering_target: int | None,
+ optimization_profile: str | None,
rewrite: bool,
rewrite_target: str | None,
rewrite_start: str | None,
@@ -192,6 +201,7 @@ def test_opt_valid_optimization_target( # pylint: disable=too-many-locals,too-m
model=str(model_type),
pruning=pruning,
clustering=clustering,
+ optimization_profile=optimization_profile,
pruning_target=pruning_target,
clustering_target=clustering_target,
rewrite=rewrite,
diff --git a/tests/test_cli_helpers.py b/tests/test_cli_helpers.py
index 494ed89..0e9f0d6 100644
--- a/tests/test_cli_helpers.py
+++ b/tests/test_cli_helpers.py
@@ -1,8 +1,9 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the helper classes."""
from __future__ import annotations
+import re
from pathlib import Path
from typing import Any
@@ -144,9 +145,38 @@ class TestCliActionResolver:
def test_copy_profile_file_to_output_dir(tmp_path: Path) -> None:
- """Test if the profile file is copied into the output directory."""
+ """Test if the target profile file is copied into the output directory."""
test_target_profile_name = "ethos-u55-128"
test_file_path = Path(f"{tmp_path}/{test_target_profile_name}.toml")
- copy_profile_file_to_output_dir(test_target_profile_name, tmp_path)
+ copy_profile_file_to_output_dir(
+ test_target_profile_name, tmp_path, profile_to_copy="target_profile"
+ )
assert Path.is_file(test_file_path)
+
+
+def test_copy_optimization_file_to_output_dir(tmp_path: Path) -> None:
+ """Test if the optimization profile file is copied into the output directory."""
+ test_target_profile_name = "optimization"
+ test_file_path = Path(f"{tmp_path}/{test_target_profile_name}.toml")
+
+ copy_profile_file_to_output_dir(
+ test_target_profile_name, tmp_path, profile_to_copy="optimization_profile"
+ )
+ assert Path.is_file(test_file_path)
+
+
+def test_copy_optimization_file_to_output_dir_error(tmp_path: Path) -> None:
+ """Test that the correct error is raised if the optimization
+ profile cannot be found."""
+ test_target_profile_name = "wrong_file"
+ with pytest.raises(
+ RuntimeError,
+ match=re.escape(
+ "Failed to copy optimization_profile file: "
+ "[Errno 2] No such file or directory: '" + test_target_profile_name + "'"
+ ),
+ ):
+ copy_profile_file_to_output_dir(
+ test_target_profile_name, tmp_path, profile_to_copy="optimization_profile"
+ )
diff --git a/tests/test_cli_main.py b/tests/test_cli_main.py
index e415284..564886b 100644
--- a/tests/test_cli_main.py
+++ b/tests/test_cli_main.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for main module."""
from __future__ import annotations
@@ -164,6 +164,7 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
clustering=True,
pruning_target=None,
clustering_target=None,
+ optimization_profile="optimization",
backend=None,
rewrite=False,
rewrite_target=None,
@@ -194,6 +195,7 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
pruning_target=0.5,
clustering_target=32,
backend=None,
+ optimization_profile="optimization",
rewrite=False,
rewrite_target=None,
rewrite_start=None,
@@ -219,6 +221,7 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
clustering=False,
pruning_target=None,
clustering_target=None,
+ optimization_profile="optimization",
backend=["some_backend"],
rewrite=False,
rewrite_target=None,
@@ -244,6 +247,35 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
backend=None,
),
],
+ [
+ [
+ "optimize",
+ "sample_model.h5",
+ "--target-profile",
+ "ethos-u55-256",
+ "--pruning",
+ "--backend",
+ "some_backend",
+ "--optimization-profile",
+ "optimization",
+ ],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-256",
+ model="sample_model.h5",
+ pruning=True,
+ clustering=False,
+ pruning_target=None,
+ clustering_target=None,
+ backend=["some_backend"],
+ optimization_profile="optimization",
+ rewrite=False,
+ rewrite_target=None,
+ rewrite_start=None,
+ rewrite_end=None,
+ dataset=None,
+ ),
+ ],
],
)
def test_commands_execution(
diff --git a/tests/test_common_optimization.py b/tests/test_common_optimization.py
index 599610d..05a5b55 100644
--- a/tests/test_common_optimization.py
+++ b/tests/test_common_optimization.py
@@ -1,15 +1,21 @@
-# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the common optimization module."""
+from contextlib import ExitStack as does_not_raises
from pathlib import Path
+from typing import Any
from unittest.mock import MagicMock
import pytest
from mlia.core.context import ExecutionContext
from mlia.nn.common import Optimizer
+from mlia.nn.select import OptimizationSettings
from mlia.nn.tensorflow.config import TFLiteModel
+from mlia.target.common.optimization import _DEFAULT_OPTIMIZATION_TARGETS
+from mlia.target.common.optimization import add_common_optimization_params
from mlia.target.common.optimization import OptimizingDataCollector
+from mlia.target.config import load_profile
from mlia.target.config import TargetProfile
@@ -46,8 +52,14 @@ def test_optimizing_data_collector(
{"optimization_type": "fake", "optimization_target": 42},
]
]
+ training_parameters = {"batch_size": 32, "show_progress": False}
context = ExecutionContext(
- config_parameters={"common_optimizations": {"optimizations": optimizations}}
+ config_parameters={
+ "common_optimizations": {
+ "optimizations": optimizations,
+ "training_parameters": [training_parameters],
+ }
+ }
)
target_profile = MagicMock(spec=TargetProfile)
@@ -61,7 +73,95 @@ def test_optimizing_data_collector(
collector = OptimizingDataCollector(test_keras_model, target_profile)
+ optimize_model_mock = MagicMock(side_effect=collector.optimize_model)
+ monkeypatch.setattr(
+ "mlia.target.common.optimization.OptimizingDataCollector.optimize_model",
+ optimize_model_mock,
+ )
+ opt_settings = [
+ [
+ OptimizationSettings(
+ item.get("optimization_type"), # type: ignore
+ item.get("optimization_target"), # type: ignore
+ item.get("layers_to_optimize"), # type: ignore
+ item.get("dataset"), # type: ignore
+ )
+ for item in opt_configuration
+ ]
+ for opt_configuration in optimizations
+ ]
+
collector.set_context(context)
collector.collect_data()
-
+ assert optimize_model_mock.call_args.args[0] == opt_settings[0]
+ assert optimize_model_mock.call_args.args[1] == [training_parameters]
assert fake_optimizer.invocation_count == 1
+
+
+@pytest.mark.parametrize(
+ "extra_args, error_to_raise",
+ [
+ (
+ {
+ "optimization_targets": [
+ {
+ "optimization_type": "pruning",
+ "optimization_target": 0.5,
+ "layers_to_optimize": None,
+ }
+ ],
+ },
+ does_not_raises(),
+ ),
+ (
+ {
+ "optimization_profile": load_profile(
+ "src/mlia/resources/optimization_profiles/optimization.toml"
+ )
+ },
+ does_not_raises(),
+ ),
+ (
+ {
+ "optimization_targets": {
+ "optimization_type": "pruning",
+ "optimization_target": 0.5,
+ "layers_to_optimize": None,
+ },
+ },
+ pytest.raises(
+ TypeError, match="Optimization targets value has wrong format."
+ ),
+ ),
+ (
+ {"optimization_profile": [32, 1e-3, True, 48000, "cosine", 1, 0]},
+ pytest.raises(
+ TypeError, match="Training Parameter values has wrong format."
+ ),
+ ),
+ ],
+)
+def test_add_common_optimization_params(extra_args: dict, error_to_raise: Any) -> None:
+ """Test to check that optimization_targets and optimization_profiles are
+ correctly parsed."""
+ advisor_parameters: dict = {}
+
+ with error_to_raise:
+ add_common_optimization_params(advisor_parameters, extra_args)
+ if not extra_args.get("optimization_targets"):
+ assert advisor_parameters["common_optimizations"]["optimizations"] == [
+ _DEFAULT_OPTIMIZATION_TARGETS
+ ]
+ else:
+ assert advisor_parameters["common_optimizations"]["optimizations"] == [
+ extra_args["optimization_targets"]
+ ]
+
+ if not extra_args.get("optimization_profile"):
+ assert advisor_parameters["common_optimizations"][
+ "training_parameters"
+ ] == [None]
+ else:
+ assert advisor_parameters["common_optimizations"][
+ "training_parameters"
+ ] == list(extra_args["optimization_profile"].values())
diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py
index 487784d..363d614 100644
--- a/tests/test_nn_rewrite_core_rewrite.py
+++ b/tests/test_nn_rewrite_core_rewrite.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for module mlia.nn.rewrite.core.rewrite."""
from __future__ import annotations
@@ -7,6 +7,7 @@ from contextlib import ExitStack as does_not_raise
from pathlib import Path
from typing import Any
from typing import cast
+from unittest.mock import MagicMock
import pytest
@@ -16,6 +17,8 @@ from mlia.nn.rewrite.core.rewrite import RewriteCallable
from mlia.nn.rewrite.core.rewrite import RewriteConfiguration
from mlia.nn.rewrite.core.rewrite import RewriteRegistry
from mlia.nn.rewrite.core.rewrite import RewritingOptimizer
+from mlia.nn.rewrite.core.rewrite import TrainingParameters
+from mlia.nn.rewrite.core.train import train_in_dir
from mlia.nn.tensorflow.config import TFLiteModel
from tests.utils.rewrite import MockTrainingParameters
@@ -129,3 +132,30 @@ def test_rewrite_function_autoload_fail() -> None:
"Unable to load rewrite function 'invalid_module.invalid_function'"
" for 'mock_rewrite'."
)
+
+
+def test_rewrite_configuration_train_params(
+ test_tflite_model_fp32: Path,
+ test_tfrecord_fp32: Path,
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    """Test that training parameters passed via the
+    rewrite configuration are forwarded to train_in_dir."""
+ train_params = TrainingParameters(
+ batch_size=64, steps=24000, learning_rate=1e-5, show_progress=True
+ )
+
+ config_obj = RewriteConfiguration(
+ "fully_connected",
+ ["sequential/flatten/Reshape", "StatefulPartitionedCall:0"],
+ test_tfrecord_fp32,
+ train_params=train_params,
+ )
+
+ rewriter_obj = RewritingOptimizer(test_tflite_model_fp32, config_obj)
+ train_mock = MagicMock(side_effect=train_in_dir)
+ monkeypatch.setattr("mlia.nn.rewrite.core.train.train_in_dir", train_mock)
+ rewriter_obj.apply_optimization()
+
+ train_mock.assert_called_once()
+ assert train_mock.call_args.kwargs["train_params"] == train_params
diff --git a/tests/test_nn_select.py b/tests/test_nn_select.py
index 31628d2..92b7a3d 100644
--- a/tests/test_nn_select.py
+++ b/tests/test_nn_select.py
@@ -1,16 +1,21 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for module select."""
from __future__ import annotations
from contextlib import ExitStack as does_not_raise
+from dataclasses import asdict
from pathlib import Path
from typing import Any
+from typing import cast
import pytest
import tensorflow as tf
from mlia.core.errors import ConfigurationError
+from mlia.nn.rewrite.core.rewrite import RewriteConfiguration
+from mlia.nn.rewrite.core.rewrite import RewritingOptimizer
+from mlia.nn.rewrite.core.rewrite import TrainingParameters
from mlia.nn.select import get_optimizer
from mlia.nn.select import MultiStageOptimizer
from mlia.nn.select import OptimizationSettings
@@ -135,6 +140,23 @@ from mlia.nn.tensorflow.optimizations.pruning import PruningConfiguration
MultiStageOptimizer,
"pruning: 0.5 - clustering: 32",
),
+ (
+ OptimizationSettings(
+ optimization_type="rewrite",
+ optimization_target="fully_connected", # type: ignore
+ layers_to_optimize=None,
+ dataset=None,
+ ),
+ does_not_raise(),
+ RewritingOptimizer,
+ "rewrite: fully_connected",
+ ),
+ (
+ RewriteConfiguration("fully_connected"),
+ does_not_raise(),
+ RewritingOptimizer,
+ "rewrite: fully_connected",
+ ),
],
)
def test_get_optimizer(
@@ -143,17 +165,58 @@ def test_get_optimizer(
expected_type: type,
expected_config: str,
test_keras_model: Path,
+ test_tflite_model: Path,
) -> None:
"""Test function get_optimzer."""
- model = tf.keras.models.load_model(str(test_keras_model))
-
with expected_error:
+ if (
+ isinstance(config, OptimizationSettings)
+ and config.optimization_type == "rewrite"
+ ) or isinstance(config, RewriteConfiguration):
+ model = test_tflite_model
+ else:
+ model = tf.keras.models.load_model(str(test_keras_model))
optimizer = get_optimizer(model, config)
assert isinstance(optimizer, expected_type)
assert optimizer.optimization_config() == expected_config
@pytest.mark.parametrize(
+ "rewrite_parameters",
+ [[None], [{"batch_size": 64, "learning_rate": 0.003}]],
+)
+@pytest.mark.skip_set_training_steps
+def test_get_optimizer_training_parameters(
+ rewrite_parameters: list[dict], test_tflite_model: Path
+) -> None:
+    """Test function get_optimizer with various combinations of parameters."""
+ config = OptimizationSettings(
+ optimization_type="rewrite",
+ optimization_target="fully_connected", # type: ignore
+ layers_to_optimize=None,
+ dataset=None,
+ )
+ optimizer = cast(
+ RewritingOptimizer,
+ get_optimizer(test_tflite_model, config, list(rewrite_parameters)),
+ )
+
+ assert len(rewrite_parameters) == 1
+
+ assert isinstance(
+ optimizer.optimizer_configuration.train_params, TrainingParameters
+ )
+ if not rewrite_parameters[0]:
+ assert asdict(TrainingParameters()) == asdict(
+ optimizer.optimizer_configuration.train_params
+ )
+ else:
+ assert asdict(TrainingParameters()) | rewrite_parameters[0] == asdict(
+ optimizer.optimizer_configuration.train_params
+ )
+
+
+@pytest.mark.parametrize(
"params, expected_result",
[
(
diff --git a/tests/test_target_config.py b/tests/test_target_config.py
index 8055af0..56e9f11 100644
--- a/tests/test_target_config.py
+++ b/tests/test_target_config.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the backend config module."""
from __future__ import annotations
@@ -10,9 +10,9 @@ from mlia.backend.config import BackendType
from mlia.backend.config import System
from mlia.core.common import AdviceCategory
from mlia.target.config import BUILTIN_SUPPORTED_PROFILE_NAMES
-from mlia.target.config import get_builtin_profile_path
from mlia.target.config import get_builtin_supported_profile_names
-from mlia.target.config import is_builtin_profile
+from mlia.target.config import get_builtin_target_profile_path
+from mlia.target.config import is_builtin_target_profile
from mlia.target.config import load_profile
from mlia.target.config import TargetInfo
from mlia.target.config import TargetProfile
@@ -33,23 +33,23 @@ def test_builtin_supported_profile_names() -> None:
"tosa",
]
for profile_name in BUILTIN_SUPPORTED_PROFILE_NAMES:
- assert is_builtin_profile(profile_name)
- profile_file = get_builtin_profile_path(profile_name)
+ assert is_builtin_target_profile(profile_name)
+ profile_file = get_builtin_target_profile_path(profile_name)
assert profile_file.is_file()
def test_builtin_profile_files() -> None:
"""Test function 'get_bulitin_profile_file'."""
- profile_file = get_builtin_profile_path("cortex-a")
+ profile_file = get_builtin_target_profile_path("cortex-a")
assert profile_file.is_file()
- profile_file = get_builtin_profile_path("UNKNOWN_FILE_THAT_DOES_NOT_EXIST")
+ profile_file = get_builtin_target_profile_path("UNKNOWN_FILE_THAT_DOES_NOT_EXIST")
assert not profile_file.exists()
def test_load_profile() -> None:
"""Test getting profile data."""
- profile_file = get_builtin_profile_path("ethos-u55-256")
+ profile_file = get_builtin_target_profile_path("ethos-u55-256")
assert load_profile(profile_file) == {
"target": "ethos-u55",
"mac": 256,
diff --git a/tests/test_target_cortex_a_advisor.py b/tests/test_target_cortex_a_advisor.py
index 6e370d6..59d54b5 100644
--- a/tests/test_target_cortex_a_advisor.py
+++ b/tests/test_target_cortex_a_advisor.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for Cortex-A MLIA module."""
from pathlib import Path
@@ -46,7 +46,8 @@ def test_configure_and_get_cortex_a_advisor(test_tflite_model: Path) -> None:
"optimization_type": "clustering",
},
]
- ]
+ ],
+ "training_parameters": [None],
},
}
diff --git a/tests/test_target_registry.py b/tests/test_target_registry.py
index ca1ad82..120d0f5 100644
--- a/tests/test_target_registry.py
+++ b/tests/test_target_registry.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the target registry module."""
from __future__ import annotations
@@ -6,9 +6,11 @@ from __future__ import annotations
import pytest
from mlia.core.common import AdviceCategory
-from mlia.target.config import get_builtin_profile_path
+from mlia.target.config import get_builtin_optimization_profile_path
+from mlia.target.config import get_builtin_target_profile_path
from mlia.target.registry import all_supported_backends
from mlia.target.registry import default_backends
+from mlia.target.registry import get_optimization_profile
from mlia.target.registry import is_supported
from mlia.target.registry import profile
from mlia.target.registry import registry
@@ -146,6 +148,27 @@ def test_profile(target_profile: str) -> None:
assert target_profile.startswith(cfg.target)
# Test loading the file directly
- profile_file = get_builtin_profile_path(target_profile)
+ profile_file = get_builtin_target_profile_path(target_profile)
cfg = profile(profile_file)
assert target_profile.startswith(cfg.target)
+
+
+@pytest.mark.parametrize("optimization_profile", ["optimization"])
+def test_optimization_profile(optimization_profile: str) -> None:
+    """Test function get_optimization_profile()."""
+
+ get_optimization_profile(optimization_profile)
+
+ profile_file = get_builtin_optimization_profile_path(optimization_profile)
+ get_optimization_profile(profile_file)
+
+
+@pytest.mark.parametrize("optimization_profile", ["non_valid_file"])
+def test_optimization_profile_non_valid_file(optimization_profile: str) -> None:
+    """Test that get_optimization_profile() rejects an invalid profile file."""
+ with pytest.raises(
+ ValueError,
+ match=f"optimization Profile '{optimization_profile}' is neither "
+ "a valid built-in optimization profile name or a valid file path.",
+ ):
+ get_optimization_profile(optimization_profile)
diff --git a/tests/test_target_tosa_advisor.py b/tests/test_target_tosa_advisor.py
index 36e52e9..cc47321 100644
--- a/tests/test_target_tosa_advisor.py
+++ b/tests/test_target_tosa_advisor.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for TOSA advisor."""
from pathlib import Path
@@ -46,7 +46,8 @@ def test_configure_and_get_tosa_advisor(
"optimization_type": "clustering",
},
]
- ]
+ ],
+ "training_parameters": [None],
},
"tosa_inference_advisor": {
"model": str(test_tflite_model),
diff --git a/tests/test_utils_filesystem.py b/tests/test_utils_filesystem.py
index c1c9876..1ccbd1c 100644
--- a/tests/test_utils_filesystem.py
+++ b/tests/test_utils_filesystem.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the filesystem module."""
import contextlib
@@ -10,6 +10,7 @@ from mlia.utils.filesystem import all_files_exist
from mlia.utils.filesystem import all_paths_valid
from mlia.utils.filesystem import copy_all
from mlia.utils.filesystem import get_mlia_resources
+from mlia.utils.filesystem import get_mlia_target_optimization_dir
from mlia.utils.filesystem import get_mlia_target_profiles_dir
from mlia.utils.filesystem import get_vela_config
from mlia.utils.filesystem import recreate_directory
@@ -37,6 +38,11 @@ def test_get_mlia_target_profiles() -> None:
assert get_mlia_target_profiles_dir().is_dir()
+def test_get_mlia_target_optimizations() -> None:
+    """Test target optimization directory getter."""
+ assert get_mlia_target_optimization_dir().is_dir()
+
+
@pytest.mark.parametrize("raise_exception", [True, False])
def test_temp_file(raise_exception: bool) -> None:
"""Test temp_file context manager."""