diff options
author | Nathan Bailey <nathan.bailey@arm.com> | 2024-02-15 14:50:58 +0000 |
---|---|---|
committer | Nathan Bailey <nathan.bailey@arm.com> | 2024-03-14 15:45:40 +0000 |
commit | 0b552d2ae47da4fb9c16d2a59d6ebe12c8307771 (patch) | |
tree | 09b40b939acbe0bcf02dcc77a7ed7ce4aba94322 /tests | |
parent | 09b272be6e88d84a30cb89fb71f3fc3c64d20d2e (diff) | |
download | mlia-0b552d2ae47da4fb9c16d2a59d6ebe12c8307771.tar.gz |
feat: Enable rewrite parameterisation
Enables user to provide a toml or default profile to change training settings for rewrite optimization
Resolves: MLIA-1004
Signed-off-by: Nathan Bailey <nathan.bailey@arm.com>
Change-Id: I3bf9f44b9a2062fb71ef36eb32c9a69edcc48061
Diffstat (limited to 'tests')
-rw-r--r-- | tests/conftest.py | 18 | ||||
-rw-r--r-- | tests/test_cli_commands.py | 14 | ||||
-rw-r--r-- | tests/test_cli_helpers.py | 36 | ||||
-rw-r--r-- | tests/test_cli_main.py | 34 | ||||
-rw-r--r-- | tests/test_common_optimization.py | 106 | ||||
-rw-r--r-- | tests/test_nn_rewrite_core_rewrite.py | 32 | ||||
-rw-r--r-- | tests/test_nn_select.py | 69 | ||||
-rw-r--r-- | tests/test_target_config.py | 16 | ||||
-rw-r--r-- | tests/test_target_cortex_a_advisor.py | 5 | ||||
-rw-r--r-- | tests/test_target_registry.py | 29 | ||||
-rw-r--r-- | tests/test_target_tosa_advisor.py | 5 | ||||
-rw-r--r-- | tests/test_utils_filesystem.py | 8 |
12 files changed, 337 insertions, 35 deletions
diff --git a/tests/conftest.py b/tests/conftest.py index 1092979..53bfb0c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,6 +7,7 @@ from typing import Callable from typing import Generator from unittest.mock import MagicMock +import _pytest import numpy as np import pytest import tensorflow as tf @@ -256,11 +257,16 @@ def fixture_test_tfrecord_fp32( @pytest.fixture(scope="session", autouse=True) -def set_training_steps() -> Generator[None, None, None]: +def set_training_steps( + request: _pytest.fixtures.SubRequest, +) -> Generator[None, None, None]: """Speed up tests by using MockTrainingParameters.""" - with pytest.MonkeyPatch.context() as monkeypatch: - monkeypatch.setattr( - "mlia.nn.select._get_rewrite_train_params", - MagicMock(return_value=MockTrainingParameters()), - ) + if "set_training_steps" == request.fixturename: yield + else: + with pytest.MonkeyPatch.context() as monkeypatch: + monkeypatch.setattr( + "mlia.nn.select._get_rewrite_params", + MagicMock(return_value=[MockTrainingParameters(), None, None]), + ) + yield diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py index 1ce793f..1a9bbb8 100644 --- a/tests/test_cli_commands.py +++ b/tests/test_cli_commands.py @@ -52,13 +52,15 @@ def test_performance_unknown_target( @pytest.mark.parametrize( - "target_profile, pruning, clustering, pruning_target, clustering_target, " - "rewrite, rewrite_target, rewrite_start, rewrite_end, expected_error", + "target_profile, pruning, clustering, optimization_profile, pruning_target, " + "clustering_target, rewrite, rewrite_target, rewrite_start, rewrite_end ," + "expected_error", [ [ "ethos-u55-256", True, False, + None, 0.5, None, False, @@ -73,6 +75,7 @@ def test_performance_unknown_target( False, None, None, + None, True, "fully_connected", "sequential/flatten/Reshape", @@ -83,6 +86,7 @@ def test_performance_unknown_target( "ethos-u55-256", True, False, + None, 0.5, None, True, @@ -98,6 +102,7 @@ def test_performance_unknown_target( "ethos-u65-512", False, True, + None, 0.5, 32, False, @@ -110,6 +115,7 @@ def test_performance_unknown_target( "ethos-u55-256", False, False, + None, 0.5, None, True, @@ -128,6 +134,7 @@ def test_performance_unknown_target( "ethos-u55-256", False, False, + None, 0.5, None, True, @@ -146,6 +153,7 @@ def test_performance_unknown_target( "ethos-u55-256", False, False, + None, "invalid", None, True, @@ -169,6 +177,7 @@ def test_opt_valid_optimization_target( # pylint: disable=too-many-locals,too-m clustering: bool, pruning_target: float | None, clustering_target: int | None, + optimization_profile: str | None, rewrite: bool, rewrite_target: str | None, rewrite_start: str | None, @@ -192,6 +201,7 @@ def test_opt_valid_optimization_target( # pylint: disable=too-many-locals,too-m model=str(model_type), pruning=pruning, clustering=clustering, + optimization_profile=optimization_profile, pruning_target=pruning_target, clustering_target=clustering_target, rewrite=rewrite, diff --git a/tests/test_cli_helpers.py b/tests/test_cli_helpers.py index 494ed89..0e9f0d6 100644 --- a/tests/test_cli_helpers.py +++ b/tests/test_cli_helpers.py @@ -1,8 +1,9 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the helper classes.""" from __future__ import annotations +import re from pathlib import Path from typing import Any @@ -144,9 +145,38 @@ class TestCliActionResolver: def test_copy_profile_file_to_output_dir(tmp_path: Path) -> None: - """Test if the profile file is copied into the output directory.""" + """Test if the target profile file is copied into the output directory.""" test_target_profile_name = "ethos-u55-128" test_file_path = Path(f"{tmp_path}/{test_target_profile_name}.toml") - copy_profile_file_to_output_dir(test_target_profile_name, tmp_path) + copy_profile_file_to_output_dir( + test_target_profile_name, tmp_path, profile_to_copy="target_profile" + ) assert Path.is_file(test_file_path) + + +def test_copy_optimization_file_to_output_dir(tmp_path: Path) -> None: + """Test if the optimization profile file is copied into the output directory.""" + test_target_profile_name = "optimization" + test_file_path = Path(f"{tmp_path}/{test_target_profile_name}.toml") + + copy_profile_file_to_output_dir( + test_target_profile_name, tmp_path, profile_to_copy="optimization_profile" + ) + assert Path.is_file(test_file_path) + + +def test_copy_optimization_file_to_output_dir_error(tmp_path: Path) -> None: + """Test that the correct error is raised if the optimization + profile cannot be found.""" + test_target_profile_name = "wrong_file" + with pytest.raises( + RuntimeError, + match=re.escape( + "Failed to copy optimization_profile file: " + "[Errno 2] No such file or directory: '" + test_target_profile_name + "'" + ), + ): + copy_profile_file_to_output_dir( + test_target_profile_name, tmp_path, profile_to_copy="optimization_profile" + ) diff --git a/tests/test_cli_main.py b/tests/test_cli_main.py index e415284..564886b 100644 --- a/tests/test_cli_main.py +++ b/tests/test_cli_main.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for main module.""" from __future__ import annotations @@ -164,6 +164,7 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable: clustering=True, pruning_target=None, clustering_target=None, + optimization_profile="optimization", backend=None, rewrite=False, rewrite_target=None, @@ -194,6 +195,7 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable: pruning_target=0.5, clustering_target=32, backend=None, + optimization_profile="optimization", rewrite=False, rewrite_target=None, rewrite_start=None, @@ -219,6 +221,7 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable: clustering=False, pruning_target=None, clustering_target=None, + optimization_profile="optimization", backend=["some_backend"], rewrite=False, rewrite_target=None, @@ -244,6 +247,35 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable: backend=None, ), ], + [ + [ + "optimize", + "sample_model.h5", + "--target-profile", + "ethos-u55-256", + "--pruning", + "--backend", + "some_backend", + "--optimization-profile", + "optimization", + ], + call( + ctx=ANY, + target_profile="ethos-u55-256", + model="sample_model.h5", + pruning=True, + clustering=False, + pruning_target=None, + clustering_target=None, + backend=["some_backend"], + optimization_profile="optimization", + rewrite=False, + rewrite_target=None, + rewrite_start=None, + rewrite_end=None, + dataset=None, + ), + ], ], ) def test_commands_execution( diff --git a/tests/test_common_optimization.py b/tests/test_common_optimization.py index 599610d..05a5b55 100644 --- a/tests/test_common_optimization.py +++ b/tests/test_common_optimization.py @@ -1,15 +1,21 @@ -# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the common optimization module.""" +from contextlib import ExitStack as does_not_raises from pathlib import Path +from typing import Any from unittest.mock import MagicMock import pytest from mlia.core.context import ExecutionContext from mlia.nn.common import Optimizer +from mlia.nn.select import OptimizationSettings from mlia.nn.tensorflow.config import TFLiteModel +from mlia.target.common.optimization import _DEFAULT_OPTIMIZATION_TARGETS +from mlia.target.common.optimization import add_common_optimization_params from mlia.target.common.optimization import OptimizingDataCollector +from mlia.target.config import load_profile from mlia.target.config import TargetProfile @@ -46,8 +52,14 @@ def test_optimizing_data_collector( {"optimization_type": "fake", "optimization_target": 42}, ] ] + training_parameters = {"batch_size": 32, "show_progress": False} context = ExecutionContext( - config_parameters={"common_optimizations": {"optimizations": optimizations}} + config_parameters={ + "common_optimizations": { + "optimizations": optimizations, + "training_parameters": [training_parameters], + } + } ) target_profile = MagicMock(spec=TargetProfile) @@ -61,7 +73,95 @@ def test_optimizing_data_collector( collector = OptimizingDataCollector(test_keras_model, target_profile) + optimize_model_mock = MagicMock(side_effect=collector.optimize_model) + monkeypatch.setattr( + "mlia.target.common.optimization.OptimizingDataCollector.optimize_model", + optimize_model_mock, + ) + opt_settings = [ + [ + OptimizationSettings( + item.get("optimization_type"), # type: ignore + item.get("optimization_target"), # type: ignore + item.get("layers_to_optimize"), # type: ignore + item.get("dataset"), # type: ignore + ) + for item in opt_configuration + ] + for opt_configuration in optimizations + ] + collector.set_context(context) collector.collect_data() - + assert optimize_model_mock.call_args.args[0] == opt_settings[0] + assert optimize_model_mock.call_args.args[1] == [training_parameters] assert fake_optimizer.invocation_count == 1 + + +@pytest.mark.parametrize( + "extra_args, error_to_raise", + [ + ( + { + "optimization_targets": [ + { + "optimization_type": "pruning", + "optimization_target": 0.5, + "layers_to_optimize": None, + } + ], + }, + does_not_raises(), + ), + ( + { + "optimization_profile": load_profile( + "src/mlia/resources/optimization_profiles/optimization.toml" + ) + }, + does_not_raises(), + ), + ( + { + "optimization_targets": { + "optimization_type": "pruning", + "optimization_target": 0.5, + "layers_to_optimize": None, + }, + }, + pytest.raises( + TypeError, match="Optimization targets value has wrong format." + ), + ), + ( + {"optimization_profile": [32, 1e-3, True, 48000, "cosine", 1, 0]}, + pytest.raises( + TypeError, match="Training Parameter values has wrong format." + ), + ), + ], +) +def test_add_common_optimization_params(extra_args: dict, error_to_raise: Any) -> None: + """Test to check that optimization_targets and optimization_profiles are + correctly parsed.""" + advisor_parameters: dict = {} + + with error_to_raise: + add_common_optimization_params(advisor_parameters, extra_args) + if not extra_args.get("optimization_targets"): + assert advisor_parameters["common_optimizations"]["optimizations"] == [ + _DEFAULT_OPTIMIZATION_TARGETS + ] + else: + assert advisor_parameters["common_optimizations"]["optimizations"] == [ + extra_args["optimization_targets"] + ] + + if not extra_args.get("optimization_profile"): + assert advisor_parameters["common_optimizations"][ + "training_parameters" + ] == [None] + else: + assert advisor_parameters["common_optimizations"][ + "training_parameters" + ] == list(extra_args["optimization_profile"].values()) diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py index 487784d..363d614 100644 --- a/tests/test_nn_rewrite_core_rewrite.py +++ b/tests/test_nn_rewrite_core_rewrite.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for module mlia.nn.rewrite.core.rewrite.""" from __future__ import annotations @@ -7,6 +7,7 @@ from contextlib import ExitStack as does_not_raise from pathlib import Path from typing import Any from typing import cast +from unittest.mock import MagicMock import pytest @@ -16,6 +17,8 @@ from mlia.nn.rewrite.core.rewrite import RewriteCallable from mlia.nn.rewrite.core.rewrite import RewriteConfiguration from mlia.nn.rewrite.core.rewrite import RewriteRegistry from mlia.nn.rewrite.core.rewrite import RewritingOptimizer +from mlia.nn.rewrite.core.rewrite import TrainingParameters +from mlia.nn.rewrite.core.train import train_in_dir from mlia.nn.tensorflow.config import TFLiteModel from tests.utils.rewrite import MockTrainingParameters @@ -129,3 +132,30 @@ def test_rewrite_function_autoload_fail() -> None: "Unable to load rewrite function 'invalid_module.invalid_function'" " for 'mock_rewrite'." ) + + +def test_rewrite_configuration_train_params( + test_tflite_model_fp32: Path, + test_tfrecord_fp32: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test if we pass training parameters to the + rewrite configuration function they are passed to train_in_dir.""" + train_params = TrainingParameters( + batch_size=64, steps=24000, learning_rate=1e-5, show_progress=True + ) + + config_obj = RewriteConfiguration( + "fully_connected", + ["sequential/flatten/Reshape", "StatefulPartitionedCall:0"], + test_tfrecord_fp32, + train_params=train_params, + ) + + rewriter_obj = RewritingOptimizer(test_tflite_model_fp32, config_obj) + train_mock = MagicMock(side_effect=train_in_dir) + monkeypatch.setattr("mlia.nn.rewrite.core.train.train_in_dir", train_mock) + rewriter_obj.apply_optimization() + + train_mock.assert_called_once() + assert train_mock.call_args.kwargs["train_params"] == train_params diff --git a/tests/test_nn_select.py b/tests/test_nn_select.py index 31628d2..92b7a3d 100644 --- a/tests/test_nn_select.py +++ b/tests/test_nn_select.py @@ -1,16 +1,21 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for module select.""" from __future__ import annotations from contextlib import ExitStack as does_not_raise +from dataclasses import asdict from pathlib import Path from typing import Any +from typing import cast import pytest import tensorflow as tf from mlia.core.errors import ConfigurationError +from mlia.nn.rewrite.core.rewrite import RewriteConfiguration +from mlia.nn.rewrite.core.rewrite import RewritingOptimizer +from mlia.nn.rewrite.core.rewrite import TrainingParameters from mlia.nn.select import get_optimizer from mlia.nn.select import MultiStageOptimizer from mlia.nn.select import OptimizationSettings @@ -135,6 +140,23 @@ from mlia.nn.tensorflow.optimizations.pruning import PruningConfiguration MultiStageOptimizer, "pruning: 0.5 - clustering: 32", ), + ( + OptimizationSettings( + optimization_type="rewrite", + optimization_target="fully_connected", # type: ignore + layers_to_optimize=None, + dataset=None, + ), + does_not_raise(), + RewritingOptimizer, + "rewrite: fully_connected", + ), + ( + RewriteConfiguration("fully_connected"), + does_not_raise(), + RewritingOptimizer, + "rewrite: fully_connected", + ), ], ) def test_get_optimizer( @@ -143,17 +165,58 @@ def test_get_optimizer( expected_type: type, expected_config: str, test_keras_model: Path, + test_tflite_model: Path, ) -> None: """Test function get_optimzer.""" - model = tf.keras.models.load_model(str(test_keras_model)) - with expected_error: + if ( + isinstance(config, OptimizationSettings) + and config.optimization_type == "rewrite" + ) or isinstance(config, RewriteConfiguration): + model = test_tflite_model + else: + model = tf.keras.models.load_model(str(test_keras_model)) optimizer = get_optimizer(model, config) assert isinstance(optimizer, expected_type) assert optimizer.optimization_config() == expected_config @pytest.mark.parametrize( + "rewrite_parameters", + [[None], [{"batch_size": 64, "learning_rate": 0.003}]], +) +@pytest.mark.skip_set_training_steps +def test_get_optimizer_training_parameters( + rewrite_parameters: list[dict], test_tflite_model: Path +) -> None: + """Test function get_optimzer with various combinations of parameters.""" + config = OptimizationSettings( + optimization_type="rewrite", + optimization_target="fully_connected", # type: ignore + layers_to_optimize=None, + dataset=None, + ) + optimizer = cast( + RewritingOptimizer, + get_optimizer(test_tflite_model, config, list(rewrite_parameters)), + ) + + assert len(rewrite_parameters) == 1 + + assert isinstance( + optimizer.optimizer_configuration.train_params, TrainingParameters + ) + if not rewrite_parameters[0]: + assert asdict(TrainingParameters()) == asdict( + optimizer.optimizer_configuration.train_params + ) + else: + assert asdict(TrainingParameters()) | rewrite_parameters[0] == asdict( + optimizer.optimizer_configuration.train_params + ) + + +@pytest.mark.parametrize( "params, expected_result", [ ( diff --git a/tests/test_target_config.py b/tests/test_target_config.py index 8055af0..56e9f11 100644 --- a/tests/test_target_config.py +++ b/tests/test_target_config.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the backend config module.""" from __future__ import annotations @@ -10,9 +10,9 @@ from mlia.backend.config import BackendType from mlia.backend.config import System from mlia.core.common import AdviceCategory from mlia.target.config import BUILTIN_SUPPORTED_PROFILE_NAMES -from mlia.target.config import get_builtin_profile_path from mlia.target.config import get_builtin_supported_profile_names -from mlia.target.config import is_builtin_profile +from mlia.target.config import get_builtin_target_profile_path +from mlia.target.config import is_builtin_target_profile from mlia.target.config import load_profile from mlia.target.config import TargetInfo from mlia.target.config import TargetProfile @@ -33,23 +33,23 @@ def test_builtin_supported_profile_names() -> None: "tosa", ] for profile_name in BUILTIN_SUPPORTED_PROFILE_NAMES: - assert is_builtin_profile(profile_name) - profile_file = get_builtin_profile_path(profile_name) + assert is_builtin_target_profile(profile_name) + profile_file = get_builtin_target_profile_path(profile_name) assert profile_file.is_file() def test_builtin_profile_files() -> None: """Test function 'get_bulitin_profile_file'.""" - profile_file = get_builtin_profile_path("cortex-a") + profile_file = get_builtin_target_profile_path("cortex-a") assert profile_file.is_file() - profile_file = get_builtin_profile_path("UNKNOWN_FILE_THAT_DOES_NOT_EXIST") + profile_file = get_builtin_target_profile_path("UNKNOWN_FILE_THAT_DOES_NOT_EXIST") assert not profile_file.exists() def test_load_profile() -> None: """Test getting profile data.""" - profile_file = get_builtin_profile_path("ethos-u55-256") + profile_file = get_builtin_target_profile_path("ethos-u55-256") assert load_profile(profile_file) == { "target": "ethos-u55", "mac": 256, diff --git a/tests/test_target_cortex_a_advisor.py b/tests/test_target_cortex_a_advisor.py index 6e370d6..59d54b5 100644 --- a/tests/test_target_cortex_a_advisor.py +++ b/tests/test_target_cortex_a_advisor.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for Cortex-A MLIA module.""" from pathlib import Path @@ -46,7 +46,8 @@ def test_configure_and_get_cortex_a_advisor(test_tflite_model: Path) -> None: "optimization_type": "clustering", }, ] - ] + ], + "training_parameters": [None], }, } diff --git a/tests/test_target_registry.py b/tests/test_target_registry.py index ca1ad82..120d0f5 100644 --- a/tests/test_target_registry.py +++ b/tests/test_target_registry.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the target registry module.""" from __future__ import annotations @@ -6,9 +6,11 @@ from __future__ import annotations import pytest from mlia.core.common import AdviceCategory -from mlia.target.config import get_builtin_profile_path +from mlia.target.config import get_builtin_optimization_profile_path +from mlia.target.config import get_builtin_target_profile_path from mlia.target.registry import all_supported_backends from mlia.target.registry import default_backends +from mlia.target.registry import get_optimization_profile from mlia.target.registry import is_supported from mlia.target.registry import profile from mlia.target.registry import registry @@ -146,6 +148,27 @@ def test_profile(target_profile: str) -> None: assert target_profile.startswith(cfg.target) # Test loading the file directly - profile_file = get_builtin_profile_path(target_profile) + profile_file = get_builtin_target_profile_path(target_profile) cfg = profile(profile_file) assert target_profile.startswith(cfg.target) + + +@pytest.mark.parametrize("optimization_profile", ["optimization"]) +def test_optimization_profile(optimization_profile: str) -> None: + """Test function optimization_profile().""" + + get_optimization_profile(optimization_profile) + + profile_file = get_builtin_optimization_profile_path(optimization_profile) + get_optimization_profile(profile_file) + + +@pytest.mark.parametrize("optimization_profile", ["non_valid_file"]) +def test_optimization_profile_non_valid_file(optimization_profile: str) -> None: + """Test function optimization_profile().""" + with pytest.raises( + ValueError, + match=f"optimization Profile '{optimization_profile}' is neither " + "a valid built-in optimization profile name or a valid file path.", + ): + get_optimization_profile(optimization_profile) diff --git a/tests/test_target_tosa_advisor.py b/tests/test_target_tosa_advisor.py index 36e52e9..cc47321 100644 --- a/tests/test_target_tosa_advisor.py +++ b/tests/test_target_tosa_advisor.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for TOSA advisor.""" from pathlib import Path @@ -46,7 +46,8 @@ def test_configure_and_get_tosa_advisor( "optimization_type": "clustering", }, ] - ] + ], + "training_parameters": [None], }, "tosa_inference_advisor": { "model": str(test_tflite_model), diff --git a/tests/test_utils_filesystem.py b/tests/test_utils_filesystem.py index c1c9876..1ccbd1c 100644 --- a/tests/test_utils_filesystem.py +++ b/tests/test_utils_filesystem.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the filesystem module.""" import contextlib @@ -10,6 +10,7 @@ from mlia.utils.filesystem import all_files_exist from mlia.utils.filesystem import all_paths_valid from mlia.utils.filesystem import copy_all from mlia.utils.filesystem import get_mlia_resources +from mlia.utils.filesystem import get_mlia_target_optimization_dir from mlia.utils.filesystem import get_mlia_target_profiles_dir from mlia.utils.filesystem import get_vela_config from mlia.utils.filesystem import recreate_directory @@ -37,6 +38,11 @@ def test_get_mlia_target_profiles() -> None: assert get_mlia_target_profiles_dir().is_dir() +def test_get_mlia_target_optimizations() -> None: + """Test target profiles getter.""" + assert get_mlia_target_optimization_dir().is_dir() + + @pytest.mark.parametrize("raise_exception", [True, False]) def test_temp_file(raise_exception: bool) -> None: """Test temp_file context manager.""" |