From 0b552d2ae47da4fb9c16d2a59d6ebe12c8307771 Mon Sep 17 00:00:00 2001 From: Nathan Bailey Date: Thu, 15 Feb 2024 14:50:58 +0000 Subject: feat: Enable rewrite parameterisation Enables user to provide a toml or default profile to change training settings for rewrite optimization Resolves: MLIA-1004 Signed-off-by: Nathan Bailey Change-Id: I3bf9f44b9a2062fb71ef36eb32c9a69edcc48061 --- .pre-commit-config.yaml | 4 +- README.md | 43 ++++++++- src/mlia/api.py | 10 +- src/mlia/cli/commands.py | 4 +- src/mlia/cli/helpers.py | 17 ++-- src/mlia/cli/main.py | 20 ++-- src/mlia/cli/options.py | 13 ++- src/mlia/nn/select.py | 33 +++++-- .../optimization_profiles/optimization.toml | 11 +++ src/mlia/target/common/optimization.py | 36 ++++++- src/mlia/target/config.py | 28 +++++- src/mlia/target/registry.py | 39 +++++++- src/mlia/utils/filesystem.py | 5 + tests/conftest.py | 18 ++-- tests/test_cli_commands.py | 14 ++- tests/test_cli_helpers.py | 36 ++++++- tests/test_cli_main.py | 34 ++++++- tests/test_common_optimization.py | 106 ++++++++++++++++++++- tests/test_nn_rewrite_core_rewrite.py | 32 ++++++- tests/test_nn_select.py | 69 +++++++++++++- tests/test_target_config.py | 16 ++-- tests/test_target_cortex_a_advisor.py | 5 +- tests/test_target_registry.py | 29 +++++- tests/test_target_tosa_advisor.py | 5 +- tests/test_utils_filesystem.py | 8 +- 25 files changed, 557 insertions(+), 78 deletions(-) create mode 100644 src/mlia/resources/optimization_profiles/optimization.toml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 67bd90b..b601b03 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 # Pre-commit checks # @@ -70,7 +70,7 @@ repos: rev: v0.12.0 hooks: - id: markdownlint - args: ["-r", "~MD024,~MD002"] + args: ["-r", "~MD024,~MD002,~MD013"] - repo: https://github.com/ryanrhee/shellcheck-py rev: v0.9.0.5 diff --git a/README.md b/README.md index ee53247..e24dded 100644 --- a/README.md +++ b/README.md @@ -209,6 +209,44 @@ mlia optimize ~/models/ds_cnn_large_fp32.tflite \ --rewrite-end MobileNet/fc1/BiasAdd ``` +### optimization Profiles + +Training parameters for rewrites can be specified. + +There are a number of predefined profiles: + +| Name | Batch Size | LR | Show Progress | Steps | LR Schedule | Num Procs | Num Threads | Checkpoints | +| :----------: | :--------: | :--: | :-----------: | :---: | :---------: | :-------: | :---------: | :---------: | +| optimization | 32 | 1e-3 | True | 48000 | "cosine" | 1 | 0 | None | + +```bash +##### An example for using optimization Profiles +mlia optimize ~/models/ds_cnn_large_fp32.tflite \ + --target-profile ethos-u55-256 \ + --optimization-profile optimization \ + --rewrite \ + --dataset input.tfrec \ + --rewrite-target fully_connected \ + --rewrite-start MobileNet/avg_pool/AvgPool \ + --rewrite-end MobileNet/fc1/BiasAdd_ +``` + +#### Custom optimization Profiles + +For the _custom optimization profiles_, the configuration file for a custom +optimization profile is passed as path and needs to conform to the TOML file format. +Each optimization in MLIA has a pre-defined set of parameters which need to be present +in the config file. When using the built-in optimization profiles, the appropriate +toml file is copied to `mlia-output` and can be used to understand what parameters +apply for each optimization. + +*Example:* + +``` bash +# for custom profiles +mlia ops --optimization-profile ~/my_custom_optimization_profile.toml +``` + # Target profiles The targets currently supported are described in the sections below. @@ -275,8 +313,9 @@ For more information, see TOSA Checker's: For the _custom target profiles_, the configuration file for a custom target profile is passed as path and needs to conform to the TOML file format. Each target in MLIA has a pre-defined set of parameters which need to be present -in the config file. The built-in target profiles (in `src/mlia/resources/target_profiles`) -can be used to understand what parameters apply for each target. +in the config file. When using the built-in target profiles, the appropriate +toml file is copied to `mlia-output` and can be used to understand what parameters +apply for each target. *Example:* diff --git a/src/mlia/api.py b/src/mlia/api.py index 7adae48..3901f56 100644 --- a/src/mlia/api.py +++ b/src/mlia/api.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Module for the API functions.""" from __future__ import annotations @@ -10,6 +10,7 @@ from typing import Any from mlia.core.advisor import InferenceAdvisor from mlia.core.common import AdviceCategory from mlia.core.context import ExecutionContext +from mlia.target.registry import get_optimization_profile from mlia.target.registry import profile from mlia.target.registry import registry as target_registry @@ -20,6 +21,7 @@ def get_advice( target_profile: str, model: str | Path, category: set[str], + optimization_profile: str | None = None, optimization_targets: list[dict[str, Any]] | None = None, context: ExecutionContext | None = None, backends: list[str] | None = None, @@ -69,9 +71,9 @@ def get_advice( target_profile, model, optimization_targets=optimization_targets, + optimization_profile=optimization_profile, backends=backends, ) - advisor.run(context) @@ -82,6 +84,10 @@ def get_advisor( **extra_args: Any, ) -> InferenceAdvisor: """Find appropriate advisor for the target.""" + if extra_args.get("optimization_profile"): + extra_args["optimization_profile"] = get_optimization_profile( + extra_args["optimization_profile"] + ) target = profile(target_profile).target factory_function = target_registry.items[target].advisor_factory_func return factory_function( diff --git a/src/mlia/cli/commands.py b/src/mlia/cli/commands.py index 7af41d9..fcba302 100644 --- a/src/mlia/cli/commands.py +++ b/src/mlia/cli/commands.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """CLI commands module. @@ -104,6 +104,7 @@ def optimize( # pylint: disable=too-many-locals,too-many-arguments clustering: bool, pruning_target: float | None, clustering_target: int | None, + optimization_profile: str | None = None, rewrite: bool | None = None, rewrite_target: str | None = None, rewrite_start: str | None = None, @@ -166,6 +167,7 @@ def optimize( # pylint: disable=too-many-locals,too-many-arguments model, {"optimization"}, optimization_targets=opt_params, + optimization_profile=optimization_profile, context=ctx, backends=validated_backend, ) diff --git a/src/mlia/cli/helpers.py b/src/mlia/cli/helpers.py index 824db1b..7b38577 100644 --- a/src/mlia/cli/helpers.py +++ b/src/mlia/cli/helpers.py @@ -1,8 +1,9 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Module for various helpers.""" from __future__ import annotations +import importlib from pathlib import Path from shutil import copy from typing import Any @@ -12,8 +13,6 @@ from mlia.cli.options import get_target_profile_opts from mlia.core.helpers import ActionResolver from mlia.nn.select import OptimizationSettings from mlia.nn.tensorflow.utils import is_keras_model -from mlia.target.config import get_builtin_profile_path -from mlia.target.config import is_builtin_profile from mlia.utils.types import is_list_of @@ -120,12 +119,16 @@ class CLIActionResolver(ActionResolver): def copy_profile_file_to_output_dir( - target_profile: str | Path, output_dir: str | Path + target_profile: str | Path, output_dir: str | Path, profile_to_copy: str ) -> bool: """Copy the target profile file to the output directory.""" + get_func_name = "get_builtin_" + profile_to_copy + "_path" + get_func = getattr(importlib.import_module("mlia.target.config"), get_func_name) + is_func_name = "is_builtin_" + profile_to_copy + is_func = getattr(importlib.import_module("mlia.target.config"), is_func_name) profile_file_path = ( - get_builtin_profile_path(cast(str, target_profile)) - if is_builtin_profile(target_profile) + get_func(cast(str, target_profile)) + if is_func(target_profile) else Path(target_profile) ) output_file_path = f"{output_dir}/{profile_file_path.stem}.toml" @@ -133,4 +136,4 @@ def copy_profile_file_to_output_dir( copy(profile_file_path, output_file_path) return True except OSError as err: - raise RuntimeError("Failed to copy profile file:", err.strerror) from err + raise RuntimeError(f"Failed to copy {profile_to_copy} file: {err}") from err diff --git a/src/mlia/cli/main.py b/src/mlia/cli/main.py index 9e1b7cd..32d46a6 100644 --- a/src/mlia/cli/main.py +++ b/src/mlia/cli/main.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """CLI main entry point.""" from __future__ import annotations @@ -203,11 +203,17 @@ def run_command(args: argparse.Namespace) -> int: try: logger.info("ML Inference Advisor %s", __version__) - if copy_profile_file(ctx, func_args): + if copy_profile_file(ctx, func_args, "target_profile"): logger.info( "\nThe target profile (.toml) is copied to the output directory: %s", ctx.output_dir, ) + if copy_profile_file(ctx, func_args, "optimization_profile"): + logger.info( + "\nThe optimization profile (.toml) is copied to " + "the output directory: %s", + ctx.output_dir, + ) args.func(**func_args) return 0 except KeyboardInterrupt: @@ -278,11 +284,13 @@ def init_and_run(commands: list[CommandInfo], argv: list[str] | None = None) -> return run_command(args) -def copy_profile_file(ctx: ExecutionContext, func_args: dict) -> bool: - """If present, copy the target profile file to the output directory.""" - if func_args.get("target_profile"): +def copy_profile_file( + ctx: ExecutionContext, func_args: dict, profile_to_copy: str +) -> bool: + """If present, copy the selected profile file to the output directory.""" + if func_args.get(profile_to_copy): return copy_profile_file_to_output_dir( - func_args["target_profile"], ctx.output_dir + func_args[profile_to_copy], ctx.output_dir, profile_to_copy ) return False diff --git a/src/mlia/cli/options.py b/src/mlia/cli/options.py index 57f54dd..1c55fed 100644 --- a/src/mlia/cli/options.py +++ b/src/mlia/cli/options.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Module for the CLI options.""" from __future__ import annotations @@ -15,6 +15,7 @@ from mlia.core.common import AdviceCategory from mlia.core.errors import ConfigurationError from mlia.core.typing import OutputFormat from mlia.nn.rewrite.core.rewrite import RewritingOptimizer +from mlia.target.registry import builtin_optimization_names from mlia.target.registry import builtin_profile_names from mlia.target.registry import registry as target_registry @@ -130,6 +131,16 @@ def add_multi_optimization_options(parser: argparse.ArgumentParser) -> None: help="Ending node in the graph of the subgraph to be rewritten.", ) + optimization_profiles = builtin_optimization_names() + multi_optimization_group.add_argument( + "-o", + "--optimization-profile", + required=False, + default="optimization", + help="Built-in optimization profile or path to the custom profile. " + f"Built-in optimization profiles are {', '.join(optimization_profiles)}. ", + ) + def add_model_options(parser: argparse.ArgumentParser) -> None: """Add model specific options.""" diff --git a/src/mlia/nn/select.py b/src/mlia/nn/select.py index 6947206..20950cc 100644 --- a/src/mlia/nn/select.py +++ b/src/mlia/nn/select.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Module for optimization selection.""" from __future__ import annotations @@ -117,6 +117,7 @@ class MultiStageOptimizer(Optimizer): def get_optimizer( model: tf.keras.Model | KerasModel | TFLiteModel, config: OptimizerConfiguration | OptimizationSettings | list[OptimizationSettings], + training_parameters: list[dict | None] | None = None, ) -> Optimizer: """Get optimizer for provided configuration.""" if isinstance(model, KerasModel): @@ -135,10 +136,14 @@ def get_optimizer( return RewritingOptimizer(model, config) if isinstance(config, OptimizationSettings): - return _get_optimizer(model, cast(OptimizationSettings, config)) + return _get_optimizer( + model, cast(OptimizationSettings, config), training_parameters + ) if is_list_of(config, OptimizationSettings): - return _get_optimizer(model, cast(List[OptimizationSettings], config)) + return _get_optimizer( + model, cast(List[OptimizationSettings], config), training_parameters + ) raise ConfigurationError(f"Unknown optimization configuration {config}") @@ -146,16 +151,18 @@ def get_optimizer( def _get_optimizer( model: tf.keras.Model | Path, optimization_settings: OptimizationSettings | list[OptimizationSettings], + training_parameters: list[dict | None] | None = None, ) -> Optimizer: if isinstance(optimization_settings, OptimizationSettings): optimization_settings = [optimization_settings] optimizer_configs = [] + for opt_type, opt_target, layers_to_optimize, dataset in optimization_settings: _check_optimizer_params(opt_type, opt_target) opt_config = _get_optimizer_configuration( - opt_type, opt_target, layers_to_optimize, dataset + opt_type, opt_target, layers_to_optimize, dataset, training_parameters ) optimizer_configs.append(opt_config) @@ -165,13 +172,23 @@ def _get_optimizer( return MultiStageOptimizer(model, optimizer_configs) -def _get_rewrite_train_params() -> TrainingParameters: +def _get_rewrite_params( + training_parameters: list[dict | None] | None = None, +) -> list: """Get the rewrite TrainingParameters. Return the default constructed TrainingParameters() per default, but can be overwritten in the unit tests. """ - return TrainingParameters() + if training_parameters is None: + return [TrainingParameters()] + + if training_parameters[0] is None: + train_params = TrainingParameters() + else: + train_params = TrainingParameters(**training_parameters[0]) + + return [train_params] def _get_optimizer_configuration( @@ -179,6 +196,7 @@ def _get_optimizer_configuration( optimization_target: int | float | str, layers_to_optimize: list[str] | None = None, dataset: Path | None = None, + training_parameters: list[dict | None] | None = None, ) -> OptimizerConfiguration: """Get optimizer configuration for provided parameters.""" _check_optimizer_params(optimization_type, optimization_target) @@ -199,11 +217,12 @@ def _get_optimizer_configuration( if opt_type == "rewrite": if isinstance(optimization_target, str): + rewrite_params = _get_rewrite_params(training_parameters) return RewriteConfiguration( optimization_target=str(optimization_target), layers_to_optimize=layers_to_optimize, dataset=dataset, - train_params=_get_rewrite_train_params(), + train_params=rewrite_params[0], ) raise ConfigurationError( diff --git a/src/mlia/resources/optimization_profiles/optimization.toml b/src/mlia/resources/optimization_profiles/optimization.toml new file mode 100644 index 0000000..623a763 --- /dev/null +++ b/src/mlia/resources/optimization_profiles/optimization.toml @@ -0,0 +1,11 @@ +# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 + +[training] +batch_size = 32 +learning_rate = 1e-3 +show_progress = true +steps = 48000 +learning_rate_schedule = "cosine" +num_procs = 1 +num_threads = 0 diff --git a/src/mlia/target/common/optimization.py b/src/mlia/target/common/optimization.py index 5f359c5..8c5d184 100644 --- a/src/mlia/target/common/optimization.py +++ b/src/mlia/target/common/optimization.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Data collector support for performance optimizations.""" from __future__ import annotations @@ -50,6 +50,8 @@ class OptimizingDataCollector(ContextAwareDataCollector): optimizations = self._get_optimization_settings(self.context) + training_parameters = self._get_training_settings(self.context) + if not optimizations or optimizations == [[]]: raise FunctionalityNotSupportedError( reason="No optimization targets provided", @@ -75,17 +77,22 @@ class OptimizingDataCollector(ContextAwareDataCollector): model = self.model # type: ignore optimizers: list[Callable] = [ - partial(self.optimize_model, opts) for opts in opt_settings + partial(self.optimize_model, opts, training_parameters) + for opts in opt_settings ] return self.optimize_and_estimate_performance(model, optimizers, opt_settings) def optimize_model( - self, opt_settings: list[OptimizationSettings], model: KerasModel | TFLiteModel + self, + opt_settings: list[OptimizationSettings], + training_parameters: list[dict | None], + model: KerasModel | TFLiteModel, ) -> Any: """Run optimization.""" - optimizer = get_optimizer(model, opt_settings) - + optimizer = get_optimizer( + model, opt_settings, training_parameters=training_parameters + ) opts_as_str = ", ".join(str(opt) for opt in opt_settings) logger.info("Applying model optimizations - [%s]", opts_as_str) optimizer.apply_optimization() @@ -116,6 +123,16 @@ class OptimizingDataCollector(ContextAwareDataCollector): context=context, ) + def _get_training_settings(self, context: Context) -> list[dict]: + """Get optimization settings.""" + return self.get_parameter( # type: ignore + OptimizingDataCollector.name(), + "training_parameters", + expected_type=list, + expected=False, + context=context, + ) + @staticmethod def _parse_optimization_params( optimizations: list[list[dict]], @@ -210,10 +227,19 @@ def add_common_optimization_params(advisor_parameters: dict, extra_args: dict) - if not is_list_of(optimization_targets, dict): raise TypeError("Optimization targets value has wrong format.") + rewrite_parameters = extra_args.get("optimization_profile") + if not rewrite_parameters: + training_parameters = None + else: + if not isinstance(rewrite_parameters, dict): + raise TypeError("Training Parameter values has wrong format.") + training_parameters = extra_args["optimization_profile"].get("training") + advisor_parameters.update( { "common_optimizations": { "optimizations": [optimization_targets], + "training_parameters": [training_parameters], }, } ) diff --git a/src/mlia/target/config.py b/src/mlia/target/config.py index 3bc74fa..8ccdad8 100644 --- a/src/mlia/target/config.py +++ b/src/mlia/target/config.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Target configuration module.""" from __future__ import annotations @@ -22,9 +22,10 @@ from mlia.backend.registry import registry as backend_registry from mlia.core.common import AdviceCategory from mlia.core.advisor import InferenceAdvisor from mlia.utils.filesystem import get_mlia_target_profiles_dir +from mlia.utils.filesystem import get_mlia_target_optimization_dir -def get_builtin_profile_path(target_profile: str) -> Path: +def get_builtin_target_profile_path(target_profile: str) -> Path: """ Construct the path to the built-in target profile file. @@ -33,6 +34,15 @@ def get_builtin_profile_path(target_profile: str) -> Path: return get_mlia_target_profiles_dir() / f"{target_profile}.toml" +def get_builtin_optimization_profile_path(optimization_profile: str) -> Path: + """ + Construct the path to the built-in target profile file. + + No checks are performed. + """ + return get_mlia_target_optimization_dir() / f"{optimization_profile}.toml" + + @lru_cache def load_profile(path: str | Path) -> dict[str, Any]: """Get settings for the provided target profile.""" @@ -56,11 +66,19 @@ def get_builtin_supported_profile_names() -> list[str]: BUILTIN_SUPPORTED_PROFILE_NAMES = get_builtin_supported_profile_names() -def is_builtin_profile(profile_name: str | Path) -> bool: +def is_builtin_target_profile(profile_name: str | Path) -> bool: """Check if the given profile name belongs to a built-in profile.""" return profile_name in BUILTIN_SUPPORTED_PROFILE_NAMES +BUILTIN_SUPPORTED_OPTIMIZATION_NAMES = ["optimization"] + + +def is_builtin_optimization_profile(optimization_name: str | Path) -> bool: + """Check if the given optimization name belongs to a built-in optimization.""" + return optimization_name in BUILTIN_SUPPORTED_OPTIMIZATION_NAMES + + T = TypeVar("T", bound="TargetProfile") @@ -93,8 +111,8 @@ class TargetProfile(ABC): @classmethod def load_profile(cls: type[T], target_profile: str | Path) -> T: """Load a target profile from built-in target profile name or file path.""" - if is_builtin_profile(target_profile): - profile_file = get_builtin_profile_path(cast(str, target_profile)) + if is_builtin_target_profile(target_profile): + profile_file = get_builtin_target_profile_path(cast(str, target_profile)) else: profile_file = Path(target_profile) return cls.load(profile_file) diff --git a/src/mlia/target/registry.py b/src/mlia/target/registry.py index b7b6193..b850284 100644 --- a/src/mlia/target/registry.py +++ b/src/mlia/target/registry.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Target module.""" from __future__ import annotations @@ -13,9 +13,12 @@ from mlia.backend.registry import registry as backend_registry from mlia.core.common import AdviceCategory from mlia.core.reporting import Column from mlia.core.reporting import Table +from mlia.target.config import BUILTIN_SUPPORTED_OPTIMIZATION_NAMES from mlia.target.config import BUILTIN_SUPPORTED_PROFILE_NAMES -from mlia.target.config import get_builtin_profile_path -from mlia.target.config import is_builtin_profile +from mlia.target.config import get_builtin_optimization_profile_path +from mlia.target.config import get_builtin_target_profile_path +from mlia.target.config import is_builtin_optimization_profile +from mlia.target.config import is_builtin_target_profile from mlia.target.config import load_profile from mlia.target.config import TargetInfo from mlia.target.config import TargetProfile @@ -44,13 +47,18 @@ def builtin_profile_names() -> list[str]: return BUILTIN_SUPPORTED_PROFILE_NAMES +def builtin_optimization_names() -> list[str]: + """Return a list of built-in profile names (not file paths).""" + return BUILTIN_SUPPORTED_OPTIMIZATION_NAMES + + @lru_cache def profile(target_profile: str | Path) -> TargetProfile: """Get the target profile data (built-in or custom file).""" if not target_profile: raise ValueError("No valid target profile was provided.") - if is_builtin_profile(target_profile): - profile_file = get_builtin_profile_path(cast(str, target_profile)) + if is_builtin_target_profile(target_profile): + profile_file = get_builtin_target_profile_path(cast(str, target_profile)) profile_ = create_target_profile(profile_file) else: profile_file = Path(target_profile) @@ -65,6 +73,27 @@ def profile(target_profile: str | Path) -> TargetProfile: return profile_ +def get_optimization_profile(optimization_profile: str | Path) -> dict: + """Get the optimization profile data (built-in or custom file).""" + if not optimization_profile: + raise ValueError("No valid optimization profile was provided.") + if is_builtin_optimization_profile(optimization_profile): + profile_file = get_builtin_optimization_profile_path( + cast(str, optimization_profile) + ) + profile_dict = load_profile(profile_file) + else: + profile_file = Path(optimization_profile) + if profile_file.is_file(): + profile_dict = load_profile(profile_file) + else: + raise ValueError( + f"optimization Profile '{optimization_profile}' is neither a valid " + "built-in optimization profile name or a valid file path." + ) + return profile_dict + + def get_target(target_profile: str | Path) -> str: """Return target for the provided target_profile.""" return profile(target_profile).target diff --git a/src/mlia/utils/filesystem.py b/src/mlia/utils/filesystem.py index e3ef7db..5348d06 100644 --- a/src/mlia/utils/filesystem.py +++ b/src/mlia/utils/filesystem.py @@ -34,6 +34,11 @@ def get_mlia_target_profiles_dir() -> Path: return get_mlia_resources() / "target_profiles" +def get_mlia_target_optimization_dir() -> Path: + """Get the profiles file.""" + return get_mlia_resources() / "optimization_profiles" + + @contextmanager def temp_file(suffix: str | None = None) -> Generator[Path, None, None]: """Create temp file and remove it after.""" diff --git a/tests/conftest.py b/tests/conftest.py index 1092979..53bfb0c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,6 +7,7 @@ from typing import Callable from typing import Generator from unittest.mock import MagicMock +import _pytest import numpy as np import pytest import tensorflow as tf @@ -256,11 +257,16 @@ def fixture_test_tfrecord_fp32( @pytest.fixture(scope="session", autouse=True) -def set_training_steps() -> Generator[None, None, None]: +def set_training_steps( + request: _pytest.fixtures.SubRequest, +) -> Generator[None, None, None]: """Speed up tests by using MockTrainingParameters.""" - with pytest.MonkeyPatch.context() as monkeypatch: - monkeypatch.setattr( - "mlia.nn.select._get_rewrite_train_params", - MagicMock(return_value=MockTrainingParameters()), - ) + if "set_training_steps" == request.fixturename: yield + else: + with pytest.MonkeyPatch.context() as monkeypatch: + monkeypatch.setattr( + "mlia.nn.select._get_rewrite_params", + MagicMock(return_value=[MockTrainingParameters(), None, None]), + ) + yield diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py index 1ce793f..1a9bbb8 100644 --- a/tests/test_cli_commands.py +++ b/tests/test_cli_commands.py @@ -52,13 +52,15 @@ def test_performance_unknown_target( @pytest.mark.parametrize( - "target_profile, pruning, clustering, pruning_target, clustering_target, " - "rewrite, rewrite_target, rewrite_start, rewrite_end, expected_error", + "target_profile, pruning, clustering, optimization_profile, pruning_target, " + "clustering_target, rewrite, rewrite_target, rewrite_start, rewrite_end ," + "expected_error", [ [ "ethos-u55-256", True, False, + None, 0.5, None, False, @@ -73,6 +75,7 @@ def test_performance_unknown_target( False, None, None, + None, True, "fully_connected", "sequential/flatten/Reshape", @@ -83,6 +86,7 @@ def test_performance_unknown_target( "ethos-u55-256", True, False, + None, 0.5, None, True, @@ -98,6 +102,7 @@ def test_performance_unknown_target( "ethos-u65-512", False, True, + None, 0.5, 32, False, @@ -110,6 +115,7 @@ def test_performance_unknown_target( "ethos-u55-256", False, False, + None, 0.5, None, True, @@ -128,6 +134,7 @@ def test_performance_unknown_target( "ethos-u55-256", False, False, + None, 0.5, None, True, @@ -146,6 +153,7 @@ def test_performance_unknown_target( "ethos-u55-256", False, False, + None, "invalid", None, True, @@ -169,6 +177,7 @@ def test_opt_valid_optimization_target( # pylint: disable=too-many-locals,too-m clustering: bool, pruning_target: float | None, clustering_target: int | None, + optimization_profile: str | None, rewrite: bool, rewrite_target: str | None, rewrite_start: str | None, @@ -192,6 +201,7 @@ def test_opt_valid_optimization_target( # pylint: disable=too-many-locals,too-m model=str(model_type), pruning=pruning, clustering=clustering, + optimization_profile=optimization_profile, pruning_target=pruning_target, clustering_target=clustering_target, rewrite=rewrite, diff --git a/tests/test_cli_helpers.py b/tests/test_cli_helpers.py index 494ed89..0e9f0d6 100644 --- a/tests/test_cli_helpers.py +++ b/tests/test_cli_helpers.py @@ -1,8 +1,9 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the helper classes.""" from __future__ import annotations +import re from pathlib import Path from typing import Any @@ -144,9 +145,38 @@ class TestCliActionResolver: def test_copy_profile_file_to_output_dir(tmp_path: Path) -> None: - """Test if the profile file is copied into the output directory.""" + """Test if the target profile file is copied into the output directory.""" test_target_profile_name = "ethos-u55-128" test_file_path = Path(f"{tmp_path}/{test_target_profile_name}.toml") - copy_profile_file_to_output_dir(test_target_profile_name, tmp_path) + copy_profile_file_to_output_dir( + test_target_profile_name, tmp_path, profile_to_copy="target_profile" + ) assert Path.is_file(test_file_path) + + +def test_copy_optimization_file_to_output_dir(tmp_path: Path) -> None: + """Test if the optimization profile file is copied into the output directory.""" + test_target_profile_name = "optimization" + test_file_path = Path(f"{tmp_path}/{test_target_profile_name}.toml") + + copy_profile_file_to_output_dir( + test_target_profile_name, tmp_path, profile_to_copy="optimization_profile" + ) + assert Path.is_file(test_file_path) + + +def test_copy_optimization_file_to_output_dir_error(tmp_path: Path) -> None: + """Test that the correct error is raised if the optimization + profile cannot be found.""" + test_target_profile_name = "wrong_file" + with pytest.raises( + RuntimeError, + match=re.escape( + "Failed to copy optimization_profile file: " + "[Errno 2] No such file or directory: '" + test_target_profile_name + "'" + ), + ): + copy_profile_file_to_output_dir( + test_target_profile_name, tmp_path, profile_to_copy="optimization_profile" + ) diff --git a/tests/test_cli_main.py b/tests/test_cli_main.py index e415284..564886b 100644 --- a/tests/test_cli_main.py +++ b/tests/test_cli_main.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for main module.""" from __future__ import annotations @@ -164,6 +164,7 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable: clustering=True, pruning_target=None, clustering_target=None, + optimization_profile="optimization", backend=None, rewrite=False, rewrite_target=None, @@ -194,6 +195,7 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable: pruning_target=0.5, clustering_target=32, backend=None, + optimization_profile="optimization", rewrite=False, rewrite_target=None, rewrite_start=None, @@ -219,6 +221,7 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable: clustering=False, pruning_target=None, clustering_target=None, + optimization_profile="optimization", backend=["some_backend"], rewrite=False, rewrite_target=None, @@ -244,6 +247,35 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable: backend=None, ), ], + [ + [ + "optimize", + "sample_model.h5", + "--target-profile", + "ethos-u55-256", + "--pruning", + "--backend", + "some_backend", + "--optimization-profile", + "optimization", + ], + call( + ctx=ANY, + target_profile="ethos-u55-256", + model="sample_model.h5", + pruning=True, + clustering=False, + pruning_target=None, + clustering_target=None, + backend=["some_backend"], + optimization_profile="optimization", + rewrite=False, + rewrite_target=None, + rewrite_start=None, + rewrite_end=None, + dataset=None, + ), + ], ], ) def test_commands_execution( diff --git a/tests/test_common_optimization.py b/tests/test_common_optimization.py index 599610d..05a5b55 100644 --- a/tests/test_common_optimization.py +++ b/tests/test_common_optimization.py @@ -1,15 +1,21 @@ -# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the common optimization module.""" +from contextlib import ExitStack as does_not_raises from pathlib import Path +from typing import Any from unittest.mock import MagicMock import pytest from mlia.core.context import ExecutionContext from mlia.nn.common import Optimizer +from mlia.nn.select import OptimizationSettings from mlia.nn.tensorflow.config import TFLiteModel +from mlia.target.common.optimization import _DEFAULT_OPTIMIZATION_TARGETS +from mlia.target.common.optimization import add_common_optimization_params from mlia.target.common.optimization import OptimizingDataCollector +from mlia.target.config import load_profile from mlia.target.config import TargetProfile @@ -46,8 +52,14 @@ def test_optimizing_data_collector( {"optimization_type": "fake", "optimization_target": 42}, ] ] + training_parameters = {"batch_size": 32, "show_progress": False} context = ExecutionContext( - config_parameters={"common_optimizations": {"optimizations": optimizations}} + config_parameters={ + "common_optimizations": { + "optimizations": optimizations, + "training_parameters": [training_parameters], + } + } ) target_profile = MagicMock(spec=TargetProfile) @@ -61,7 +73,95 @@ def test_optimizing_data_collector( collector = OptimizingDataCollector(test_keras_model, target_profile) + optimize_model_mock = MagicMock(side_effect=collector.optimize_model) + monkeypatch.setattr( + "mlia.target.common.optimization.OptimizingDataCollector.optimize_model", + optimize_model_mock, + ) + opt_settings = [ + [ + OptimizationSettings( + item.get("optimization_type"), # type: ignore + item.get("optimization_target"), # type: ignore + item.get("layers_to_optimize"), # type: ignore + item.get("dataset"), # type: ignore + ) + for item in opt_configuration + ] + for opt_configuration in optimizations + ] + collector.set_context(context) collector.collect_data() - + assert optimize_model_mock.call_args.args[0] == opt_settings[0] + assert optimize_model_mock.call_args.args[1] == [training_parameters] assert fake_optimizer.invocation_count == 1 + + +@pytest.mark.parametrize( + "extra_args, error_to_raise", + [ + ( + { + "optimization_targets": [ + { + "optimization_type": "pruning", + "optimization_target": 0.5, + "layers_to_optimize": None, + } + ], + }, + does_not_raises(), + ), + ( + { + "optimization_profile": load_profile( + "src/mlia/resources/optimization_profiles/optimization.toml" + ) + }, + does_not_raises(), + ), + ( + { + "optimization_targets": { + "optimization_type": "pruning", + "optimization_target": 0.5, + "layers_to_optimize": None, + }, + }, + pytest.raises( + TypeError, match="Optimization targets value has wrong format." + ), + ), + ( + {"optimization_profile": [32, 1e-3, True, 48000, "cosine", 1, 0]}, + pytest.raises( + TypeError, match="Training Parameter values has wrong format." + ), + ), + ], +) +def test_add_common_optimization_params(extra_args: dict, error_to_raise: Any) -> None: + """Test to check that optimization_targets and optimization_profiles are + correctly parsed.""" + advisor_parameters: dict = {} + + with error_to_raise: + add_common_optimization_params(advisor_parameters, extra_args) + if not extra_args.get("optimization_targets"): + assert advisor_parameters["common_optimizations"]["optimizations"] == [ + _DEFAULT_OPTIMIZATION_TARGETS + ] + else: + assert advisor_parameters["common_optimizations"]["optimizations"] == [ + extra_args["optimization_targets"] + ] + + if not extra_args.get("optimization_profile"): + assert advisor_parameters["common_optimizations"][ + "training_parameters" + ] == [None] + else: + assert advisor_parameters["common_optimizations"][ + "training_parameters" + ] == list(extra_args["optimization_profile"].values()) diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py index 487784d..363d614 100644 --- a/tests/test_nn_rewrite_core_rewrite.py +++ b/tests/test_nn_rewrite_core_rewrite.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for module mlia.nn.rewrite.core.rewrite.""" from __future__ import annotations @@ -7,6 +7,7 @@ from contextlib import ExitStack as does_not_raise from pathlib import Path from typing import Any from typing import cast +from unittest.mock import MagicMock import pytest @@ -16,6 +17,8 @@ from mlia.nn.rewrite.core.rewrite import RewriteCallable from mlia.nn.rewrite.core.rewrite import RewriteConfiguration from mlia.nn.rewrite.core.rewrite import RewriteRegistry from mlia.nn.rewrite.core.rewrite import RewritingOptimizer +from mlia.nn.rewrite.core.rewrite import TrainingParameters +from mlia.nn.rewrite.core.train import train_in_dir from mlia.nn.tensorflow.config import TFLiteModel from tests.utils.rewrite import MockTrainingParameters @@ -129,3 +132,30 @@ def test_rewrite_function_autoload_fail() -> None: "Unable to load rewrite function 'invalid_module.invalid_function'" " for 'mock_rewrite'." ) + + +def test_rewrite_configuration_train_params( + test_tflite_model_fp32: Path, + test_tfrecord_fp32: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test if we pass training parameters to the + rewrite configuration function they are passed to train_in_dir.""" + train_params = TrainingParameters( + batch_size=64, steps=24000, learning_rate=1e-5, show_progress=True + ) + + config_obj = RewriteConfiguration( + "fully_connected", + ["sequential/flatten/Reshape", "StatefulPartitionedCall:0"], + test_tfrecord_fp32, + train_params=train_params, + ) + + rewriter_obj = RewritingOptimizer(test_tflite_model_fp32, config_obj) + train_mock = MagicMock(side_effect=train_in_dir) + monkeypatch.setattr("mlia.nn.rewrite.core.train.train_in_dir", train_mock) + rewriter_obj.apply_optimization() + + train_mock.assert_called_once() + assert train_mock.call_args.kwargs["train_params"] == train_params diff --git a/tests/test_nn_select.py b/tests/test_nn_select.py index 31628d2..92b7a3d 100644 --- a/tests/test_nn_select.py +++ b/tests/test_nn_select.py @@ -1,16 +1,21 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for module select.""" from __future__ import annotations from contextlib import ExitStack as does_not_raise +from dataclasses import asdict from pathlib import Path from typing import Any +from typing import cast import pytest import tensorflow as tf from mlia.core.errors import ConfigurationError +from mlia.nn.rewrite.core.rewrite import RewriteConfiguration +from mlia.nn.rewrite.core.rewrite import RewritingOptimizer +from mlia.nn.rewrite.core.rewrite import TrainingParameters from mlia.nn.select import get_optimizer from mlia.nn.select import MultiStageOptimizer from mlia.nn.select import OptimizationSettings @@ -135,6 +140,23 @@ from mlia.nn.tensorflow.optimizations.pruning import PruningConfiguration MultiStageOptimizer, "pruning: 0.5 - clustering: 32", ), + ( + OptimizationSettings( + optimization_type="rewrite", + optimization_target="fully_connected", # type: ignore + layers_to_optimize=None, + dataset=None, + ), + does_not_raise(), + RewritingOptimizer, + "rewrite: fully_connected", + ), + ( + RewriteConfiguration("fully_connected"), + does_not_raise(), + RewritingOptimizer, + "rewrite: fully_connected", + ), ], ) def test_get_optimizer( @@ -143,16 +165,57 @@ def test_get_optimizer( expected_type: type, expected_config: str, test_keras_model: Path, + test_tflite_model: Path, ) -> None: """Test function get_optimzer.""" - model = tf.keras.models.load_model(str(test_keras_model)) - with expected_error: + if ( + isinstance(config, OptimizationSettings) + and config.optimization_type == "rewrite" + ) or isinstance(config, RewriteConfiguration): + model = test_tflite_model + else: + model = tf.keras.models.load_model(str(test_keras_model)) optimizer = get_optimizer(model, config) assert isinstance(optimizer, expected_type) assert optimizer.optimization_config() == expected_config +@pytest.mark.parametrize( + "rewrite_parameters", + [[None], [{"batch_size": 64, "learning_rate": 0.003}]], +) +@pytest.mark.skip_set_training_steps +def test_get_optimizer_training_parameters( + rewrite_parameters: list[dict], test_tflite_model: Path +) -> None: + """Test function get_optimzer with various combinations of parameters.""" + config = OptimizationSettings( + optimization_type="rewrite", + optimization_target="fully_connected", # type: ignore + layers_to_optimize=None, + dataset=None, + ) + optimizer = cast( + RewritingOptimizer, + get_optimizer(test_tflite_model, config, list(rewrite_parameters)), + ) + + assert len(rewrite_parameters) == 1 + + assert isinstance( + optimizer.optimizer_configuration.train_params, TrainingParameters + ) + if not rewrite_parameters[0]: + assert asdict(TrainingParameters()) == asdict( + optimizer.optimizer_configuration.train_params + ) + else: + assert asdict(TrainingParameters()) | rewrite_parameters[0] == asdict( + optimizer.optimizer_configuration.train_params + ) + + @pytest.mark.parametrize( "params, expected_result", [ diff --git a/tests/test_target_config.py b/tests/test_target_config.py index 8055af0..56e9f11 100644 --- a/tests/test_target_config.py +++ b/tests/test_target_config.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the backend config module.""" from __future__ import annotations @@ -10,9 +10,9 @@ from mlia.backend.config import BackendType from mlia.backend.config import System from mlia.core.common import AdviceCategory from mlia.target.config import BUILTIN_SUPPORTED_PROFILE_NAMES -from mlia.target.config import get_builtin_profile_path from mlia.target.config import get_builtin_supported_profile_names -from mlia.target.config import is_builtin_profile +from mlia.target.config import get_builtin_target_profile_path +from mlia.target.config import is_builtin_target_profile from mlia.target.config import load_profile from mlia.target.config import TargetInfo from mlia.target.config import TargetProfile @@ -33,23 +33,23 @@ def test_builtin_supported_profile_names() -> None: "tosa", ] for profile_name in BUILTIN_SUPPORTED_PROFILE_NAMES: - assert is_builtin_profile(profile_name) - profile_file = get_builtin_profile_path(profile_name) + assert is_builtin_target_profile(profile_name) + profile_file = get_builtin_target_profile_path(profile_name) assert profile_file.is_file() def test_builtin_profile_files() -> None: """Test function 'get_bulitin_profile_file'.""" - profile_file = get_builtin_profile_path("cortex-a") + profile_file = get_builtin_target_profile_path("cortex-a") assert profile_file.is_file() - profile_file = get_builtin_profile_path("UNKNOWN_FILE_THAT_DOES_NOT_EXIST") + profile_file = get_builtin_target_profile_path("UNKNOWN_FILE_THAT_DOES_NOT_EXIST") assert not profile_file.exists() def test_load_profile() -> None: """Test getting profile data.""" - profile_file = get_builtin_profile_path("ethos-u55-256") + profile_file = get_builtin_target_profile_path("ethos-u55-256") assert load_profile(profile_file) == { "target": "ethos-u55", "mac": 256, diff --git a/tests/test_target_cortex_a_advisor.py b/tests/test_target_cortex_a_advisor.py index 6e370d6..59d54b5 100644 --- a/tests/test_target_cortex_a_advisor.py +++ b/tests/test_target_cortex_a_advisor.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for Cortex-A MLIA module.""" from pathlib import Path @@ -46,7 +46,8 @@ def test_configure_and_get_cortex_a_advisor(test_tflite_model: Path) -> None: "optimization_type": "clustering", }, ] - ] + ], + "training_parameters": [None], }, } diff --git a/tests/test_target_registry.py b/tests/test_target_registry.py index ca1ad82..120d0f5 100644 --- a/tests/test_target_registry.py +++ b/tests/test_target_registry.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the target registry module.""" from __future__ import annotations @@ -6,9 +6,11 @@ from __future__ import annotations import pytest from mlia.core.common import AdviceCategory -from mlia.target.config import get_builtin_profile_path +from mlia.target.config import get_builtin_optimization_profile_path +from mlia.target.config import get_builtin_target_profile_path from mlia.target.registry import all_supported_backends from mlia.target.registry import default_backends +from mlia.target.registry import get_optimization_profile from mlia.target.registry import is_supported from mlia.target.registry import profile from mlia.target.registry import registry @@ -146,6 +148,27 @@ def test_profile(target_profile: str) -> None: assert target_profile.startswith(cfg.target) # Test loading the file directly - profile_file = get_builtin_profile_path(target_profile) + profile_file = get_builtin_target_profile_path(target_profile) cfg = profile(profile_file) assert target_profile.startswith(cfg.target) + + +@pytest.mark.parametrize("optimization_profile", ["optimization"]) +def test_optimization_profile(optimization_profile: str) -> None: + """Test function optimization_profile().""" + + get_optimization_profile(optimization_profile) + + profile_file = get_builtin_optimization_profile_path(optimization_profile) + get_optimization_profile(profile_file) + + +@pytest.mark.parametrize("optimization_profile", ["non_valid_file"]) +def test_optimization_profile_non_valid_file(optimization_profile: str) -> None: + """Test function optimization_profile().""" + with pytest.raises( + ValueError, + match=f"optimization Profile '{optimization_profile}' is neither " + "a valid built-in optimization profile name or a valid file path.", + ): + get_optimization_profile(optimization_profile) diff --git a/tests/test_target_tosa_advisor.py b/tests/test_target_tosa_advisor.py index 36e52e9..cc47321 100644 --- a/tests/test_target_tosa_advisor.py +++ b/tests/test_target_tosa_advisor.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for TOSA advisor.""" from pathlib import Path @@ -46,7 +46,8 @@ def test_configure_and_get_tosa_advisor( "optimization_type": "clustering", }, ] - ] + ], + "training_parameters": [None], }, "tosa_inference_advisor": { "model": str(test_tflite_model), diff --git a/tests/test_utils_filesystem.py b/tests/test_utils_filesystem.py index c1c9876..1ccbd1c 100644 --- a/tests/test_utils_filesystem.py +++ b/tests/test_utils_filesystem.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the filesystem module.""" import contextlib @@ -10,6 +10,7 @@ from mlia.utils.filesystem import all_files_exist from mlia.utils.filesystem import all_paths_valid from mlia.utils.filesystem import copy_all from mlia.utils.filesystem import get_mlia_resources +from mlia.utils.filesystem import get_mlia_target_optimization_dir from mlia.utils.filesystem import get_mlia_target_profiles_dir from mlia.utils.filesystem import get_vela_config from mlia.utils.filesystem import recreate_directory @@ -37,6 +38,11 @@ def test_get_mlia_target_profiles() -> None: assert get_mlia_target_profiles_dir().is_dir() +def test_get_mlia_target_optimizations() -> None: + """Test target profiles getter.""" + assert get_mlia_target_optimization_dir().is_dir() + + @pytest.mark.parametrize("raise_exception", [True, False]) def test_temp_file(raise_exception: bool) -> None: """Test temp_file context manager.""" -- cgit v1.2.1