author    Nathan Bailey <nathan.bailey@arm.com>  2024-02-15 14:50:58 +0000
committer Nathan Bailey <nathan.bailey@arm.com>  2024-03-14 15:45:40 +0000
commit    0b552d2ae47da4fb9c16d2a59d6ebe12c8307771 (patch)
tree      09b40b939acbe0bcf02dcc77a7ed7ce4aba94322
parent    09b272be6e88d84a30cb89fb71f3fc3c64d20d2e (diff)
download  mlia-0b552d2ae47da4fb9c16d2a59d6ebe12c8307771.tar.gz
feat: Enable rewrite parameterisation
Enables the user to provide a TOML file or a default profile to change
training settings for rewrite optimization.

Resolves: MLIA-1004

Signed-off-by: Nathan Bailey <nathan.bailey@arm.com>
Change-Id: I3bf9f44b9a2062fb71ef36eb32c9a69edcc48061
-rw-r--r--.pre-commit-config.yaml4
-rw-r--r--README.md43
-rw-r--r--src/mlia/api.py10
-rw-r--r--src/mlia/cli/commands.py4
-rw-r--r--src/mlia/cli/helpers.py17
-rw-r--r--src/mlia/cli/main.py20
-rw-r--r--src/mlia/cli/options.py13
-rw-r--r--src/mlia/nn/select.py33
-rw-r--r--src/mlia/resources/optimization_profiles/optimization.toml11
-rw-r--r--src/mlia/target/common/optimization.py36
-rw-r--r--src/mlia/target/config.py28
-rw-r--r--src/mlia/target/registry.py39
-rw-r--r--src/mlia/utils/filesystem.py5
-rw-r--r--tests/conftest.py18
-rw-r--r--tests/test_cli_commands.py14
-rw-r--r--tests/test_cli_helpers.py36
-rw-r--r--tests/test_cli_main.py34
-rw-r--r--tests/test_common_optimization.py106
-rw-r--r--tests/test_nn_rewrite_core_rewrite.py32
-rw-r--r--tests/test_nn_select.py69
-rw-r--r--tests/test_target_config.py16
-rw-r--r--tests/test_target_cortex_a_advisor.py5
-rw-r--r--tests/test_target_registry.py29
-rw-r--r--tests/test_target_tosa_advisor.py5
-rw-r--r--tests/test_utils_filesystem.py8
25 files changed, 557 insertions, 78 deletions
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 67bd90b..b601b03 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
# Pre-commit checks
#
@@ -70,7 +70,7 @@ repos:
rev: v0.12.0
hooks:
- id: markdownlint
- args: ["-r", "~MD024,~MD002"]
+ args: ["-r", "~MD024,~MD002,~MD013"]
- repo: https://github.com/ryanrhee/shellcheck-py
rev: v0.9.0.5
diff --git a/README.md b/README.md
index ee53247..e24dded 100644
--- a/README.md
+++ b/README.md
@@ -209,6 +209,44 @@ mlia optimize ~/models/ds_cnn_large_fp32.tflite \
--rewrite-end MobileNet/fc1/BiasAdd
```
+### Optimization profiles
+
+Training parameters for the rewrite optimization can be specified via an
+optimization profile, either a built-in one or a custom TOML file.
+
+The following predefined profiles are available:
+
+| Name | Batch Size | LR | Show Progress | Steps | LR Schedule | Num Procs | Num Threads | Checkpoints |
+| :----------: | :--------: | :--: | :-----------: | :---: | :---------: | :-------: | :---------: | :---------: |
+| optimization | 32 | 1e-3 | True | 48000 | "cosine" | 1 | 0 | None |
+
+```bash
+##### An example of using optimization profiles
+mlia optimize ~/models/ds_cnn_large_fp32.tflite \
+ --target-profile ethos-u55-256 \
+ --optimization-profile optimization \
+ --rewrite \
+ --dataset input.tfrec \
+ --rewrite-target fully_connected \
+ --rewrite-start MobileNet/avg_pool/AvgPool \
+ --rewrite-end MobileNet/fc1/BiasAdd
+```
+
+#### Custom optimization profiles
+
+For _custom optimization profiles_, the configuration file is passed as a
+path and needs to conform to the TOML file format. Each optimization in MLIA
+has a pre-defined set of parameters which need to be present in the config
+file. When using the built-in optimization profiles, the corresponding TOML
+file is copied to `mlia-output` and can be used to understand which
+parameters apply to each optimization.
+
+*Example:*
+
+``` bash
+# for custom profiles
+mlia optimize --optimization-profile ~/my_custom_optimization_profile.toml
+```
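+
+For example, a custom optimization profile might look like this (the values
+are illustrative; the parameter names mirror the built-in
+`optimization.toml`):
+
+``` toml
+# my_custom_optimization_profile.toml
+[training]
+batch_size = 64
+learning_rate = 1e-4
+show_progress = true
+steps = 12000
+learning_rate_schedule = "cosine"
+num_procs = 1
+num_threads = 0
+```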
+
# Target profiles
The targets currently supported are described in the sections below.
@@ -275,8 +313,9 @@ For more information, see TOSA Checker's:
For the _custom target profiles_, the configuration file for a custom
target profile is passed as path and needs to conform to the TOML file format.
Each target in MLIA has a pre-defined set of parameters which need to be present
-in the config file. The built-in target profiles (in `src/mlia/resources/target_profiles`)
-can be used to understand what parameters apply for each target.
+in the config file. When using the built-in target profiles, the corresponding
+TOML file is copied to `mlia-output` and can be used to understand which
+parameters apply to each target.
*Example:*
diff --git a/src/mlia/api.py b/src/mlia/api.py
index 7adae48..3901f56 100644
--- a/src/mlia/api.py
+++ b/src/mlia/api.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for the API functions."""
from __future__ import annotations
@@ -10,6 +10,7 @@ from typing import Any
from mlia.core.advisor import InferenceAdvisor
from mlia.core.common import AdviceCategory
from mlia.core.context import ExecutionContext
+from mlia.target.registry import get_optimization_profile
from mlia.target.registry import profile
from mlia.target.registry import registry as target_registry
@@ -20,6 +21,7 @@ def get_advice(
target_profile: str,
model: str | Path,
category: set[str],
+ optimization_profile: str | None = None,
optimization_targets: list[dict[str, Any]] | None = None,
context: ExecutionContext | None = None,
backends: list[str] | None = None,
@@ -69,9 +71,9 @@ def get_advice(
target_profile,
model,
optimization_targets=optimization_targets,
+ optimization_profile=optimization_profile,
backends=backends,
)
-
advisor.run(context)
@@ -82,6 +84,10 @@ def get_advisor(
**extra_args: Any,
) -> InferenceAdvisor:
"""Find appropriate advisor for the target."""
+ if extra_args.get("optimization_profile"):
+ extra_args["optimization_profile"] = get_optimization_profile(
+ extra_args["optimization_profile"]
+ )
target = profile(target_profile).target
factory_function = target_registry.items[target].advisor_factory_func
return factory_function(
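Taken together, the API changes let callers thread an optimization profile through `get_advice`; it is resolved to a dict by `get_optimization_profile` before the advisor factory runs. A minimal sketch (argument values are illustrative, and it assumes `get_advice` creates a default context when none is passed):

```python
# Sketch: requesting optimization advice with an optimization profile.
# A path to a custom .toml can be passed instead of the built-in name.
from mlia.api import get_advice

get_advice(
    target_profile="ethos-u55-256",
    model="sample_model.tflite",
    category={"optimization"},
    optimization_profile="optimization",
)
```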
diff --git a/src/mlia/cli/commands.py b/src/mlia/cli/commands.py
index 7af41d9..fcba302 100644
--- a/src/mlia/cli/commands.py
+++ b/src/mlia/cli/commands.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""CLI commands module.
@@ -104,6 +104,7 @@ def optimize( # pylint: disable=too-many-locals,too-many-arguments
clustering: bool,
pruning_target: float | None,
clustering_target: int | None,
+ optimization_profile: str | None = None,
rewrite: bool | None = None,
rewrite_target: str | None = None,
rewrite_start: str | None = None,
@@ -166,6 +167,7 @@ def optimize( # pylint: disable=too-many-locals,too-many-arguments
model,
{"optimization"},
optimization_targets=opt_params,
+ optimization_profile=optimization_profile,
context=ctx,
backends=validated_backend,
)
diff --git a/src/mlia/cli/helpers.py b/src/mlia/cli/helpers.py
index 824db1b..7b38577 100644
--- a/src/mlia/cli/helpers.py
+++ b/src/mlia/cli/helpers.py
@@ -1,8 +1,9 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for various helpers."""
from __future__ import annotations
+import importlib
from pathlib import Path
from shutil import copy
from typing import Any
@@ -12,8 +13,6 @@ from mlia.cli.options import get_target_profile_opts
from mlia.core.helpers import ActionResolver
from mlia.nn.select import OptimizationSettings
from mlia.nn.tensorflow.utils import is_keras_model
-from mlia.target.config import get_builtin_profile_path
-from mlia.target.config import is_builtin_profile
from mlia.utils.types import is_list_of
@@ -120,12 +119,16 @@ class CLIActionResolver(ActionResolver):
def copy_profile_file_to_output_dir(
- target_profile: str | Path, output_dir: str | Path
+ target_profile: str | Path, output_dir: str | Path, profile_to_copy: str
) -> bool:
"""Copy the target profile file to the output directory."""
+ get_func_name = "get_builtin_" + profile_to_copy + "_path"
+ get_func = getattr(importlib.import_module("mlia.target.config"), get_func_name)
+ is_func_name = "is_builtin_" + profile_to_copy
+ is_func = getattr(importlib.import_module("mlia.target.config"), is_func_name)
profile_file_path = (
- get_builtin_profile_path(cast(str, target_profile))
- if is_builtin_profile(target_profile)
+ get_func(cast(str, target_profile))
+ if is_func(target_profile)
else Path(target_profile)
)
output_file_path = f"{output_dir}/{profile_file_path.stem}.toml"
@@ -133,4 +136,4 @@ def copy_profile_file_to_output_dir(
copy(profile_file_path, output_file_path)
return True
except OSError as err:
- raise RuntimeError("Failed to copy profile file:", err.strerror) from err
+ raise RuntimeError(f"Failed to copy {profile_to_copy} file: {err}") from err
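The helper now composes the names of its lookup functions from `profile_to_copy` and resolves them on `mlia.target.config` at runtime. A small sketch of what the dispatch resolves to for the two values used by the CLI:

```python
import importlib

# The helper builds attribute names from profile_to_copy and looks them
# up on mlia.target.config at runtime.
config = importlib.import_module("mlia.target.config")

for kind in ("target_profile", "optimization_profile"):
    get_func = getattr(config, f"get_builtin_{kind}_path")  # e.g. get_builtin_target_profile_path
    is_func = getattr(config, f"is_builtin_{kind}")         # e.g. is_builtin_target_profile
    print(kind, "->", get_func.__name__, is_func.__name__)
```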
diff --git a/src/mlia/cli/main.py b/src/mlia/cli/main.py
index 9e1b7cd..32d46a6 100644
--- a/src/mlia/cli/main.py
+++ b/src/mlia/cli/main.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""CLI main entry point."""
from __future__ import annotations
@@ -203,11 +203,17 @@ def run_command(args: argparse.Namespace) -> int:
try:
logger.info("ML Inference Advisor %s", __version__)
- if copy_profile_file(ctx, func_args):
+ if copy_profile_file(ctx, func_args, "target_profile"):
logger.info(
"\nThe target profile (.toml) is copied to the output directory: %s",
ctx.output_dir,
)
+ if copy_profile_file(ctx, func_args, "optimization_profile"):
+ logger.info(
+ "\nThe optimization profile (.toml) is copied to "
+ "the output directory: %s",
+ ctx.output_dir,
+ )
args.func(**func_args)
return 0
except KeyboardInterrupt:
@@ -278,11 +284,13 @@ def init_and_run(commands: list[CommandInfo], argv: list[str] | None = None) ->
return run_command(args)
-def copy_profile_file(ctx: ExecutionContext, func_args: dict) -> bool:
- """If present, copy the target profile file to the output directory."""
- if func_args.get("target_profile"):
+def copy_profile_file(
+ ctx: ExecutionContext, func_args: dict, profile_to_copy: str
+) -> bool:
+ """If present, copy the selected profile file to the output directory."""
+ if func_args.get(profile_to_copy):
return copy_profile_file_to_output_dir(
- func_args["target_profile"], ctx.output_dir
+ func_args[profile_to_copy], ctx.output_dir, profile_to_copy
)
return False
diff --git a/src/mlia/cli/options.py b/src/mlia/cli/options.py
index 57f54dd..1c55fed 100644
--- a/src/mlia/cli/options.py
+++ b/src/mlia/cli/options.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for the CLI options."""
from __future__ import annotations
@@ -15,6 +15,7 @@ from mlia.core.common import AdviceCategory
from mlia.core.errors import ConfigurationError
from mlia.core.typing import OutputFormat
from mlia.nn.rewrite.core.rewrite import RewritingOptimizer
+from mlia.target.registry import builtin_optimization_names
from mlia.target.registry import builtin_profile_names
from mlia.target.registry import registry as target_registry
@@ -130,6 +131,16 @@ def add_multi_optimization_options(parser: argparse.ArgumentParser) -> None:
help="Ending node in the graph of the subgraph to be rewritten.",
)
+ optimization_profiles = builtin_optimization_names()
+ multi_optimization_group.add_argument(
+ "-o",
+ "--optimization-profile",
+ required=False,
+ default="optimization",
+ help="Built-in optimization profile or path to the custom profile. "
+ f"Built-in optimization profiles are {', '.join(optimization_profiles)}. ",
+ )
+
def add_model_options(parser: argparse.ArgumentParser) -> None:
"""Add model specific options."""
diff --git a/src/mlia/nn/select.py b/src/mlia/nn/select.py
index 6947206..20950cc 100644
--- a/src/mlia/nn/select.py
+++ b/src/mlia/nn/select.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for optimization selection."""
from __future__ import annotations
@@ -117,6 +117,7 @@ class MultiStageOptimizer(Optimizer):
def get_optimizer(
model: tf.keras.Model | KerasModel | TFLiteModel,
config: OptimizerConfiguration | OptimizationSettings | list[OptimizationSettings],
+ training_parameters: list[dict | None] | None = None,
) -> Optimizer:
"""Get optimizer for provided configuration."""
if isinstance(model, KerasModel):
@@ -135,10 +136,14 @@ def get_optimizer(
return RewritingOptimizer(model, config)
if isinstance(config, OptimizationSettings):
- return _get_optimizer(model, cast(OptimizationSettings, config))
+ return _get_optimizer(
+ model, cast(OptimizationSettings, config), training_parameters
+ )
if is_list_of(config, OptimizationSettings):
- return _get_optimizer(model, cast(List[OptimizationSettings], config))
+ return _get_optimizer(
+ model, cast(List[OptimizationSettings], config), training_parameters
+ )
raise ConfigurationError(f"Unknown optimization configuration {config}")
@@ -146,16 +151,18 @@ def get_optimizer(
def _get_optimizer(
model: tf.keras.Model | Path,
optimization_settings: OptimizationSettings | list[OptimizationSettings],
+ training_parameters: list[dict | None] | None = None,
) -> Optimizer:
if isinstance(optimization_settings, OptimizationSettings):
optimization_settings = [optimization_settings]
optimizer_configs = []
+
for opt_type, opt_target, layers_to_optimize, dataset in optimization_settings:
_check_optimizer_params(opt_type, opt_target)
opt_config = _get_optimizer_configuration(
- opt_type, opt_target, layers_to_optimize, dataset
+ opt_type, opt_target, layers_to_optimize, dataset, training_parameters
)
optimizer_configs.append(opt_config)
@@ -165,13 +172,23 @@ def _get_optimizer(
return MultiStageOptimizer(model, optimizer_configs)
-def _get_rewrite_train_params() -> TrainingParameters:
+def _get_rewrite_params(
+ training_parameters: list[dict | None] | None = None,
+) -> list:
"""Get the rewrite TrainingParameters.
Return the default constructed TrainingParameters() per default, but can be
overwritten in the unit tests.
"""
- return TrainingParameters()
+ if training_parameters is None:
+ return [TrainingParameters()]
+
+ if training_parameters[0] is None:
+ train_params = TrainingParameters()
+ else:
+ train_params = TrainingParameters(**training_parameters[0])
+
+ return [train_params]
def _get_optimizer_configuration(
@@ -179,6 +196,7 @@ def _get_optimizer_configuration(
optimization_target: int | float | str,
layers_to_optimize: list[str] | None = None,
dataset: Path | None = None,
+ training_parameters: list[dict | None] | None = None,
) -> OptimizerConfiguration:
"""Get optimizer configuration for provided parameters."""
_check_optimizer_params(optimization_type, optimization_target)
@@ -199,11 +217,12 @@ def _get_optimizer_configuration(
if opt_type == "rewrite":
if isinstance(optimization_target, str):
+ rewrite_params = _get_rewrite_params(training_parameters)
return RewriteConfiguration(
optimization_target=str(optimization_target),
layers_to_optimize=layers_to_optimize,
dataset=dataset,
- train_params=_get_rewrite_train_params(),
+ train_params=rewrite_params[0],
)
raise ConfigurationError(
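The behaviour of `_get_rewrite_params` reduces to: no profile (or a `None` entry) yields the library defaults, otherwise the profile's `[training]` table is unpacked as keyword overrides. A sketch, with the import path and field names taken from the tests later in this patch:

```python
from mlia.nn.rewrite.core.rewrite import TrainingParameters

# training_parameters is None (or its first entry is None): defaults.
default_params = TrainingParameters()

# Otherwise the profile's [training] table is unpacked as overrides.
custom_params = TrainingParameters(batch_size=64, steps=24000, learning_rate=1e-5)
```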
diff --git a/src/mlia/resources/optimization_profiles/optimization.toml b/src/mlia/resources/optimization_profiles/optimization.toml
new file mode 100644
index 0000000..623a763
--- /dev/null
+++ b/src/mlia/resources/optimization_profiles/optimization.toml
@@ -0,0 +1,11 @@
+# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+
+[training]
+batch_size = 32
+learning_rate = 1e-3
+show_progress = true
+steps = 48000
+learning_rate_schedule = "cosine"
+num_procs = 1
+num_threads = 0
diff --git a/src/mlia/target/common/optimization.py b/src/mlia/target/common/optimization.py
index 5f359c5..8c5d184 100644
--- a/src/mlia/target/common/optimization.py
+++ b/src/mlia/target/common/optimization.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Data collector support for performance optimizations."""
from __future__ import annotations
@@ -50,6 +50,8 @@ class OptimizingDataCollector(ContextAwareDataCollector):
optimizations = self._get_optimization_settings(self.context)
+ training_parameters = self._get_training_settings(self.context)
+
if not optimizations or optimizations == [[]]:
raise FunctionalityNotSupportedError(
reason="No optimization targets provided",
@@ -75,17 +77,22 @@ class OptimizingDataCollector(ContextAwareDataCollector):
model = self.model # type: ignore
optimizers: list[Callable] = [
- partial(self.optimize_model, opts) for opts in opt_settings
+ partial(self.optimize_model, opts, training_parameters)
+ for opts in opt_settings
]
return self.optimize_and_estimate_performance(model, optimizers, opt_settings)
def optimize_model(
- self, opt_settings: list[OptimizationSettings], model: KerasModel | TFLiteModel
+ self,
+ opt_settings: list[OptimizationSettings],
+ training_parameters: list[dict | None],
+ model: KerasModel | TFLiteModel,
) -> Any:
"""Run optimization."""
- optimizer = get_optimizer(model, opt_settings)
-
+ optimizer = get_optimizer(
+ model, opt_settings, training_parameters=training_parameters
+ )
opts_as_str = ", ".join(str(opt) for opt in opt_settings)
logger.info("Applying model optimizations - [%s]", opts_as_str)
optimizer.apply_optimization()
@@ -116,6 +123,16 @@ class OptimizingDataCollector(ContextAwareDataCollector):
context=context,
)
+ def _get_training_settings(self, context: Context) -> list[dict]:
+ """Get optimization settings."""
+ return self.get_parameter( # type: ignore
+ OptimizingDataCollector.name(),
+ "training_parameters",
+ expected_type=list,
+ expected=False,
+ context=context,
+ )
+
@staticmethod
def _parse_optimization_params(
optimizations: list[list[dict]],
@@ -210,10 +227,19 @@ def add_common_optimization_params(advisor_parameters: dict, extra_args: dict) -
if not is_list_of(optimization_targets, dict):
raise TypeError("Optimization targets value has wrong format.")
+ rewrite_parameters = extra_args.get("optimization_profile")
+ if not rewrite_parameters:
+ training_parameters = None
+ else:
+ if not isinstance(rewrite_parameters, dict):
+ raise TypeError("Training Parameter values has wrong format.")
+ training_parameters = extra_args["optimization_profile"].get("training")
+
advisor_parameters.update(
{
"common_optimizations": {
"optimizations": [optimization_targets],
+ "training_parameters": [training_parameters],
},
}
)
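After `add_common_optimization_params` runs, the advisor configuration carries the optimization targets and the training parameters side by side. A sketch of the resulting structure, assuming one pruning target and a profile whose `[training]` table sets a batch size and learning rate:

```python
# Shape of advisor_parameters after add_common_optimization_params (sketch).
advisor_parameters = {
    "common_optimizations": {
        "optimizations": [
            [{"optimization_type": "pruning", "optimization_target": 0.5}]
        ],
        # [None] when no optimization profile was supplied.
        "training_parameters": [{"batch_size": 32, "learning_rate": 1e-3}],
    },
}
```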
diff --git a/src/mlia/target/config.py b/src/mlia/target/config.py
index 3bc74fa..8ccdad8 100644
--- a/src/mlia/target/config.py
+++ b/src/mlia/target/config.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Target configuration module."""
from __future__ import annotations
@@ -22,9 +22,10 @@ from mlia.backend.registry import registry as backend_registry
from mlia.core.common import AdviceCategory
from mlia.core.advisor import InferenceAdvisor
from mlia.utils.filesystem import get_mlia_target_profiles_dir
+from mlia.utils.filesystem import get_mlia_target_optimization_dir
-def get_builtin_profile_path(target_profile: str) -> Path:
+def get_builtin_target_profile_path(target_profile: str) -> Path:
"""
Construct the path to the built-in target profile file.
@@ -33,6 +34,15 @@ def get_builtin_profile_path(target_profile: str) -> Path:
return get_mlia_target_profiles_dir() / f"{target_profile}.toml"
+def get_builtin_optimization_profile_path(optimization_profile: str) -> Path:
+ """
+ Construct the path to the built-in optimization profile file.
+
+ No checks are performed.
+ """
+ return get_mlia_target_optimization_dir() / f"{optimization_profile}.toml"
+
+
@lru_cache
def load_profile(path: str | Path) -> dict[str, Any]:
"""Get settings for the provided target profile."""
@@ -56,11 +66,19 @@ def get_builtin_supported_profile_names() -> list[str]:
BUILTIN_SUPPORTED_PROFILE_NAMES = get_builtin_supported_profile_names()
-def is_builtin_profile(profile_name: str | Path) -> bool:
+def is_builtin_target_profile(profile_name: str | Path) -> bool:
"""Check if the given profile name belongs to a built-in profile."""
return profile_name in BUILTIN_SUPPORTED_PROFILE_NAMES
+BUILTIN_SUPPORTED_OPTIMIZATION_NAMES = ["optimization"]
+
+
+def is_builtin_optimization_profile(optimization_name: str | Path) -> bool:
+ """Check if the given optimization name belongs to a built-in optimization."""
+ return optimization_name in BUILTIN_SUPPORTED_OPTIMIZATION_NAMES
+
+
T = TypeVar("T", bound="TargetProfile")
@@ -93,8 +111,8 @@ class TargetProfile(ABC):
@classmethod
def load_profile(cls: type[T], target_profile: str | Path) -> T:
"""Load a target profile from built-in target profile name or file path."""
- if is_builtin_profile(target_profile):
- profile_file = get_builtin_profile_path(cast(str, target_profile))
+ if is_builtin_target_profile(target_profile):
+ profile_file = get_builtin_target_profile_path(cast(str, target_profile))
else:
profile_file = Path(target_profile)
return cls.load(profile_file)
diff --git a/src/mlia/target/registry.py b/src/mlia/target/registry.py
index b7b6193..b850284 100644
--- a/src/mlia/target/registry.py
+++ b/src/mlia/target/registry.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Target module."""
from __future__ import annotations
@@ -13,9 +13,12 @@ from mlia.backend.registry import registry as backend_registry
from mlia.core.common import AdviceCategory
from mlia.core.reporting import Column
from mlia.core.reporting import Table
+from mlia.target.config import BUILTIN_SUPPORTED_OPTIMIZATION_NAMES
from mlia.target.config import BUILTIN_SUPPORTED_PROFILE_NAMES
-from mlia.target.config import get_builtin_profile_path
-from mlia.target.config import is_builtin_profile
+from mlia.target.config import get_builtin_optimization_profile_path
+from mlia.target.config import get_builtin_target_profile_path
+from mlia.target.config import is_builtin_optimization_profile
+from mlia.target.config import is_builtin_target_profile
from mlia.target.config import load_profile
from mlia.target.config import TargetInfo
from mlia.target.config import TargetProfile
@@ -44,13 +47,18 @@ def builtin_profile_names() -> list[str]:
return BUILTIN_SUPPORTED_PROFILE_NAMES
+def builtin_optimization_names() -> list[str]:
+ """Return a list of built-in profile names (not file paths)."""
+ return BUILTIN_SUPPORTED_OPTIMIZATION_NAMES
+
+
@lru_cache
def profile(target_profile: str | Path) -> TargetProfile:
"""Get the target profile data (built-in or custom file)."""
if not target_profile:
raise ValueError("No valid target profile was provided.")
- if is_builtin_profile(target_profile):
- profile_file = get_builtin_profile_path(cast(str, target_profile))
+ if is_builtin_target_profile(target_profile):
+ profile_file = get_builtin_target_profile_path(cast(str, target_profile))
profile_ = create_target_profile(profile_file)
else:
profile_file = Path(target_profile)
@@ -65,6 +73,27 @@ def profile(target_profile: str | Path) -> TargetProfile:
return profile_
+def get_optimization_profile(optimization_profile: str | Path) -> dict:
+ """Get the optimization profile data (built-in or custom file)."""
+ if not optimization_profile:
+ raise ValueError("No valid optimization profile was provided.")
+ if is_builtin_optimization_profile(optimization_profile):
+ profile_file = get_builtin_optimization_profile_path(
+ cast(str, optimization_profile)
+ )
+ profile_dict = load_profile(profile_file)
+ else:
+ profile_file = Path(optimization_profile)
+ if profile_file.is_file():
+ profile_dict = load_profile(profile_file)
+ else:
+ raise ValueError(
+ f"optimization Profile '{optimization_profile}' is neither a valid "
+ "built-in optimization profile name or a valid file path."
+ )
+ return profile_dict
+
+
def get_target(target_profile: str | Path) -> str:
"""Return target for the provided target_profile."""
return profile(target_profile).target
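`get_optimization_profile` mirrors `profile()` but returns the raw TOML contents as a dict rather than a `TargetProfile`. A short usage sketch (the custom file path is illustrative):

```python
from mlia.target.registry import get_optimization_profile

# Built-in name: resolved against the bundled optimization_profiles directory.
training_config = get_optimization_profile("optimization")
assert "training" in training_config

# Anything else is treated as a file path; a missing file raises ValueError.
training_config = get_optimization_profile("my_custom_optimization_profile.toml")
```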
diff --git a/src/mlia/utils/filesystem.py b/src/mlia/utils/filesystem.py
index e3ef7db..5348d06 100644
--- a/src/mlia/utils/filesystem.py
+++ b/src/mlia/utils/filesystem.py
@@ -34,6 +34,11 @@ def get_mlia_target_profiles_dir() -> Path:
return get_mlia_resources() / "target_profiles"
+def get_mlia_target_optimization_dir() -> Path:
+ """Get the profiles file."""
+ return get_mlia_resources() / "optimization_profiles"
+
+
@contextmanager
def temp_file(suffix: str | None = None) -> Generator[Path, None, None]:
"""Create temp file and remove it after."""
diff --git a/tests/conftest.py b/tests/conftest.py
index 1092979..53bfb0c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -7,6 +7,7 @@ from typing import Callable
from typing import Generator
from unittest.mock import MagicMock
+import _pytest
import numpy as np
import pytest
import tensorflow as tf
@@ -256,11 +257,16 @@ def fixture_test_tfrecord_fp32(
@pytest.fixture(scope="session", autouse=True)
-def set_training_steps() -> Generator[None, None, None]:
+def set_training_steps(
+ request: _pytest.fixtures.SubRequest,
+) -> Generator[None, None, None]:
"""Speed up tests by using MockTrainingParameters."""
- with pytest.MonkeyPatch.context() as monkeypatch:
- monkeypatch.setattr(
- "mlia.nn.select._get_rewrite_train_params",
- MagicMock(return_value=MockTrainingParameters()),
- )
+ if "set_training_steps" == request.fixturename:
yield
+ else:
+ with pytest.MonkeyPatch.context() as monkeypatch:
+ monkeypatch.setattr(
+ "mlia.nn.select._get_rewrite_params",
+ MagicMock(return_value=[MockTrainingParameters(), None, None]),
+ )
+ yield
diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py
index 1ce793f..1a9bbb8 100644
--- a/tests/test_cli_commands.py
+++ b/tests/test_cli_commands.py
@@ -52,13 +52,15 @@ def test_performance_unknown_target(
@pytest.mark.parametrize(
- "target_profile, pruning, clustering, pruning_target, clustering_target, "
- "rewrite, rewrite_target, rewrite_start, rewrite_end, expected_error",
+ "target_profile, pruning, clustering, optimization_profile, pruning_target, "
+ "clustering_target, rewrite, rewrite_target, rewrite_start, rewrite_end ,"
+ "expected_error",
[
[
"ethos-u55-256",
True,
False,
+ None,
0.5,
None,
False,
@@ -73,6 +75,7 @@ def test_performance_unknown_target(
False,
None,
None,
+ None,
True,
"fully_connected",
"sequential/flatten/Reshape",
@@ -83,6 +86,7 @@ def test_performance_unknown_target(
"ethos-u55-256",
True,
False,
+ None,
0.5,
None,
True,
@@ -98,6 +102,7 @@ def test_performance_unknown_target(
"ethos-u65-512",
False,
True,
+ None,
0.5,
32,
False,
@@ -110,6 +115,7 @@ def test_performance_unknown_target(
"ethos-u55-256",
False,
False,
+ None,
0.5,
None,
True,
@@ -128,6 +134,7 @@ def test_performance_unknown_target(
"ethos-u55-256",
False,
False,
+ None,
0.5,
None,
True,
@@ -146,6 +153,7 @@ def test_performance_unknown_target(
"ethos-u55-256",
False,
False,
+ None,
"invalid",
None,
True,
@@ -169,6 +177,7 @@ def test_opt_valid_optimization_target( # pylint: disable=too-many-locals,too-m
clustering: bool,
pruning_target: float | None,
clustering_target: int | None,
+ optimization_profile: str | None,
rewrite: bool,
rewrite_target: str | None,
rewrite_start: str | None,
@@ -192,6 +201,7 @@ def test_opt_valid_optimization_target( # pylint: disable=too-many-locals,too-m
model=str(model_type),
pruning=pruning,
clustering=clustering,
+ optimization_profile=optimization_profile,
pruning_target=pruning_target,
clustering_target=clustering_target,
rewrite=rewrite,
diff --git a/tests/test_cli_helpers.py b/tests/test_cli_helpers.py
index 494ed89..0e9f0d6 100644
--- a/tests/test_cli_helpers.py
+++ b/tests/test_cli_helpers.py
@@ -1,8 +1,9 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the helper classes."""
from __future__ import annotations
+import re
from pathlib import Path
from typing import Any
@@ -144,9 +145,38 @@ class TestCliActionResolver:
def test_copy_profile_file_to_output_dir(tmp_path: Path) -> None:
- """Test if the profile file is copied into the output directory."""
+ """Test if the target profile file is copied into the output directory."""
test_target_profile_name = "ethos-u55-128"
test_file_path = Path(f"{tmp_path}/{test_target_profile_name}.toml")
- copy_profile_file_to_output_dir(test_target_profile_name, tmp_path)
+ copy_profile_file_to_output_dir(
+ test_target_profile_name, tmp_path, profile_to_copy="target_profile"
+ )
assert Path.is_file(test_file_path)
+
+
+def test_copy_optimization_file_to_output_dir(tmp_path: Path) -> None:
+ """Test if the optimization profile file is copied into the output directory."""
+ test_target_profile_name = "optimization"
+ test_file_path = Path(f"{tmp_path}/{test_target_profile_name}.toml")
+
+ copy_profile_file_to_output_dir(
+ test_target_profile_name, tmp_path, profile_to_copy="optimization_profile"
+ )
+ assert Path.is_file(test_file_path)
+
+
+def test_copy_optimization_file_to_output_dir_error(tmp_path: Path) -> None:
+ """Test that the correct error is raised if the optimization
+ profile cannot be found."""
+ test_target_profile_name = "wrong_file"
+ with pytest.raises(
+ RuntimeError,
+ match=re.escape(
+ "Failed to copy optimization_profile file: "
+ "[Errno 2] No such file or directory: '" + test_target_profile_name + "'"
+ ),
+ ):
+ copy_profile_file_to_output_dir(
+ test_target_profile_name, tmp_path, profile_to_copy="optimization_profile"
+ )
diff --git a/tests/test_cli_main.py b/tests/test_cli_main.py
index e415284..564886b 100644
--- a/tests/test_cli_main.py
+++ b/tests/test_cli_main.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for main module."""
from __future__ import annotations
@@ -164,6 +164,7 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
clustering=True,
pruning_target=None,
clustering_target=None,
+ optimization_profile="optimization",
backend=None,
rewrite=False,
rewrite_target=None,
@@ -194,6 +195,7 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
pruning_target=0.5,
clustering_target=32,
backend=None,
+ optimization_profile="optimization",
rewrite=False,
rewrite_target=None,
rewrite_start=None,
@@ -219,6 +221,7 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
clustering=False,
pruning_target=None,
clustering_target=None,
+ optimization_profile="optimization",
backend=["some_backend"],
rewrite=False,
rewrite_target=None,
@@ -244,6 +247,35 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
backend=None,
),
],
+ [
+ [
+ "optimize",
+ "sample_model.h5",
+ "--target-profile",
+ "ethos-u55-256",
+ "--pruning",
+ "--backend",
+ "some_backend",
+ "--optimization-profile",
+ "optimization",
+ ],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-256",
+ model="sample_model.h5",
+ pruning=True,
+ clustering=False,
+ pruning_target=None,
+ clustering_target=None,
+ backend=["some_backend"],
+ optimization_profile="optimization",
+ rewrite=False,
+ rewrite_target=None,
+ rewrite_start=None,
+ rewrite_end=None,
+ dataset=None,
+ ),
+ ],
],
)
def test_commands_execution(
diff --git a/tests/test_common_optimization.py b/tests/test_common_optimization.py
index 599610d..05a5b55 100644
--- a/tests/test_common_optimization.py
+++ b/tests/test_common_optimization.py
@@ -1,15 +1,21 @@
-# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the common optimization module."""
+from contextlib import ExitStack as does_not_raises
from pathlib import Path
+from typing import Any
from unittest.mock import MagicMock
import pytest
from mlia.core.context import ExecutionContext
from mlia.nn.common import Optimizer
+from mlia.nn.select import OptimizationSettings
from mlia.nn.tensorflow.config import TFLiteModel
+from mlia.target.common.optimization import _DEFAULT_OPTIMIZATION_TARGETS
+from mlia.target.common.optimization import add_common_optimization_params
from mlia.target.common.optimization import OptimizingDataCollector
+from mlia.target.config import load_profile
from mlia.target.config import TargetProfile
@@ -46,8 +52,14 @@ def test_optimizing_data_collector(
{"optimization_type": "fake", "optimization_target": 42},
]
]
+ training_parameters = {"batch_size": 32, "show_progress": False}
context = ExecutionContext(
- config_parameters={"common_optimizations": {"optimizations": optimizations}}
+ config_parameters={
+ "common_optimizations": {
+ "optimizations": optimizations,
+ "training_parameters": [training_parameters],
+ }
+ }
)
target_profile = MagicMock(spec=TargetProfile)
@@ -61,7 +73,95 @@ def test_optimizing_data_collector(
collector = OptimizingDataCollector(test_keras_model, target_profile)
+ optimize_model_mock = MagicMock(side_effect=collector.optimize_model)
+ monkeypatch.setattr(
+ "mlia.target.common.optimization.OptimizingDataCollector.optimize_model",
+ optimize_model_mock,
+ )
+ opt_settings = [
+ [
+ OptimizationSettings(
+ item.get("optimization_type"), # type: ignore
+ item.get("optimization_target"), # type: ignore
+ item.get("layers_to_optimize"), # type: ignore
+ item.get("dataset"), # type: ignore
+ )
+ for item in opt_configuration
+ ]
+ for opt_configuration in optimizations
+ ]
+
collector.set_context(context)
collector.collect_data()
-
+ assert optimize_model_mock.call_args.args[0] == opt_settings[0]
+ assert optimize_model_mock.call_args.args[1] == [training_parameters]
assert fake_optimizer.invocation_count == 1
+
+
+@pytest.mark.parametrize(
+ "extra_args, error_to_raise",
+ [
+ (
+ {
+ "optimization_targets": [
+ {
+ "optimization_type": "pruning",
+ "optimization_target": 0.5,
+ "layers_to_optimize": None,
+ }
+ ],
+ },
+ does_not_raises(),
+ ),
+ (
+ {
+ "optimization_profile": load_profile(
+ "src/mlia/resources/optimization_profiles/optimization.toml"
+ )
+ },
+ does_not_raises(),
+ ),
+ (
+ {
+ "optimization_targets": {
+ "optimization_type": "pruning",
+ "optimization_target": 0.5,
+ "layers_to_optimize": None,
+ },
+ },
+ pytest.raises(
+ TypeError, match="Optimization targets value has wrong format."
+ ),
+ ),
+ (
+ {"optimization_profile": [32, 1e-3, True, 48000, "cosine", 1, 0]},
+ pytest.raises(
+ TypeError, match="Training Parameter values has wrong format."
+ ),
+ ),
+ ],
+)
+def test_add_common_optimization_params(extra_args: dict, error_to_raise: Any) -> None:
+ """Test to check that optimization_targets and optimization_profiles are
+ correctly parsed."""
+ advisor_parameters: dict = {}
+
+ with error_to_raise:
+ add_common_optimization_params(advisor_parameters, extra_args)
+ if not extra_args.get("optimization_targets"):
+ assert advisor_parameters["common_optimizations"]["optimizations"] == [
+ _DEFAULT_OPTIMIZATION_TARGETS
+ ]
+ else:
+ assert advisor_parameters["common_optimizations"]["optimizations"] == [
+ extra_args["optimization_targets"]
+ ]
+
+ if not extra_args.get("optimization_profile"):
+ assert advisor_parameters["common_optimizations"][
+ "training_parameters"
+ ] == [None]
+ else:
+ assert advisor_parameters["common_optimizations"][
+ "training_parameters"
+ ] == list(extra_args["optimization_profile"].values())
diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py
index 487784d..363d614 100644
--- a/tests/test_nn_rewrite_core_rewrite.py
+++ b/tests/test_nn_rewrite_core_rewrite.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for module mlia.nn.rewrite.core.rewrite."""
from __future__ import annotations
@@ -7,6 +7,7 @@ from contextlib import ExitStack as does_not_raise
from pathlib import Path
from typing import Any
from typing import cast
+from unittest.mock import MagicMock
import pytest
@@ -16,6 +17,8 @@ from mlia.nn.rewrite.core.rewrite import RewriteCallable
from mlia.nn.rewrite.core.rewrite import RewriteConfiguration
from mlia.nn.rewrite.core.rewrite import RewriteRegistry
from mlia.nn.rewrite.core.rewrite import RewritingOptimizer
+from mlia.nn.rewrite.core.rewrite import TrainingParameters
+from mlia.nn.rewrite.core.train import train_in_dir
from mlia.nn.tensorflow.config import TFLiteModel
from tests.utils.rewrite import MockTrainingParameters
@@ -129,3 +132,30 @@ def test_rewrite_function_autoload_fail() -> None:
"Unable to load rewrite function 'invalid_module.invalid_function'"
" for 'mock_rewrite'."
)
+
+
+def test_rewrite_configuration_train_params(
+ test_tflite_model_fp32: Path,
+ test_tfrecord_fp32: Path,
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """Test if we pass training parameters to the
+ rewrite configuration function they are passed to train_in_dir."""
+ train_params = TrainingParameters(
+ batch_size=64, steps=24000, learning_rate=1e-5, show_progress=True
+ )
+
+ config_obj = RewriteConfiguration(
+ "fully_connected",
+ ["sequential/flatten/Reshape", "StatefulPartitionedCall:0"],
+ test_tfrecord_fp32,
+ train_params=train_params,
+ )
+
+ rewriter_obj = RewritingOptimizer(test_tflite_model_fp32, config_obj)
+ train_mock = MagicMock(side_effect=train_in_dir)
+ monkeypatch.setattr("mlia.nn.rewrite.core.train.train_in_dir", train_mock)
+ rewriter_obj.apply_optimization()
+
+ train_mock.assert_called_once()
+ assert train_mock.call_args.kwargs["train_params"] == train_params
diff --git a/tests/test_nn_select.py b/tests/test_nn_select.py
index 31628d2..92b7a3d 100644
--- a/tests/test_nn_select.py
+++ b/tests/test_nn_select.py
@@ -1,16 +1,21 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for module select."""
from __future__ import annotations
from contextlib import ExitStack as does_not_raise
+from dataclasses import asdict
from pathlib import Path
from typing import Any
+from typing import cast
import pytest
import tensorflow as tf
from mlia.core.errors import ConfigurationError
+from mlia.nn.rewrite.core.rewrite import RewriteConfiguration
+from mlia.nn.rewrite.core.rewrite import RewritingOptimizer
+from mlia.nn.rewrite.core.rewrite import TrainingParameters
from mlia.nn.select import get_optimizer
from mlia.nn.select import MultiStageOptimizer
from mlia.nn.select import OptimizationSettings
@@ -135,6 +140,23 @@ from mlia.nn.tensorflow.optimizations.pruning import PruningConfiguration
MultiStageOptimizer,
"pruning: 0.5 - clustering: 32",
),
+ (
+ OptimizationSettings(
+ optimization_type="rewrite",
+ optimization_target="fully_connected", # type: ignore
+ layers_to_optimize=None,
+ dataset=None,
+ ),
+ does_not_raise(),
+ RewritingOptimizer,
+ "rewrite: fully_connected",
+ ),
+ (
+ RewriteConfiguration("fully_connected"),
+ does_not_raise(),
+ RewritingOptimizer,
+ "rewrite: fully_connected",
+ ),
],
)
def test_get_optimizer(
@@ -143,17 +165,58 @@ def test_get_optimizer(
expected_type: type,
expected_config: str,
test_keras_model: Path,
+ test_tflite_model: Path,
) -> None:
"""Test function get_optimzer."""
- model = tf.keras.models.load_model(str(test_keras_model))
-
with expected_error:
+ if (
+ isinstance(config, OptimizationSettings)
+ and config.optimization_type == "rewrite"
+ ) or isinstance(config, RewriteConfiguration):
+ model = test_tflite_model
+ else:
+ model = tf.keras.models.load_model(str(test_keras_model))
optimizer = get_optimizer(model, config)
assert isinstance(optimizer, expected_type)
assert optimizer.optimization_config() == expected_config
@pytest.mark.parametrize(
+ "rewrite_parameters",
+ [[None], [{"batch_size": 64, "learning_rate": 0.003}]],
+)
+@pytest.mark.skip_set_training_steps
+def test_get_optimizer_training_parameters(
+ rewrite_parameters: list[dict], test_tflite_model: Path
+) -> None:
+ """Test function get_optimzer with various combinations of parameters."""
+ config = OptimizationSettings(
+ optimization_type="rewrite",
+ optimization_target="fully_connected", # type: ignore
+ layers_to_optimize=None,
+ dataset=None,
+ )
+ optimizer = cast(
+ RewritingOptimizer,
+ get_optimizer(test_tflite_model, config, list(rewrite_parameters)),
+ )
+
+ assert len(rewrite_parameters) == 1
+
+ assert isinstance(
+ optimizer.optimizer_configuration.train_params, TrainingParameters
+ )
+ if not rewrite_parameters[0]:
+ assert asdict(TrainingParameters()) == asdict(
+ optimizer.optimizer_configuration.train_params
+ )
+ else:
+ assert asdict(TrainingParameters()) | rewrite_parameters[0] == asdict(
+ optimizer.optimizer_configuration.train_params
+ )
+
+
+@pytest.mark.parametrize(
"params, expected_result",
[
(
diff --git a/tests/test_target_config.py b/tests/test_target_config.py
index 8055af0..56e9f11 100644
--- a/tests/test_target_config.py
+++ b/tests/test_target_config.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the backend config module."""
from __future__ import annotations
@@ -10,9 +10,9 @@ from mlia.backend.config import BackendType
from mlia.backend.config import System
from mlia.core.common import AdviceCategory
from mlia.target.config import BUILTIN_SUPPORTED_PROFILE_NAMES
-from mlia.target.config import get_builtin_profile_path
from mlia.target.config import get_builtin_supported_profile_names
-from mlia.target.config import is_builtin_profile
+from mlia.target.config import get_builtin_target_profile_path
+from mlia.target.config import is_builtin_target_profile
from mlia.target.config import load_profile
from mlia.target.config import TargetInfo
from mlia.target.config import TargetProfile
@@ -33,23 +33,23 @@ def test_builtin_supported_profile_names() -> None:
"tosa",
]
for profile_name in BUILTIN_SUPPORTED_PROFILE_NAMES:
- assert is_builtin_profile(profile_name)
- profile_file = get_builtin_profile_path(profile_name)
+ assert is_builtin_target_profile(profile_name)
+ profile_file = get_builtin_target_profile_path(profile_name)
assert profile_file.is_file()
def test_builtin_profile_files() -> None:
"""Test function 'get_bulitin_profile_file'."""
- profile_file = get_builtin_profile_path("cortex-a")
+ profile_file = get_builtin_target_profile_path("cortex-a")
assert profile_file.is_file()
- profile_file = get_builtin_profile_path("UNKNOWN_FILE_THAT_DOES_NOT_EXIST")
+ profile_file = get_builtin_target_profile_path("UNKNOWN_FILE_THAT_DOES_NOT_EXIST")
assert not profile_file.exists()
def test_load_profile() -> None:
"""Test getting profile data."""
- profile_file = get_builtin_profile_path("ethos-u55-256")
+ profile_file = get_builtin_target_profile_path("ethos-u55-256")
assert load_profile(profile_file) == {
"target": "ethos-u55",
"mac": 256,
diff --git a/tests/test_target_cortex_a_advisor.py b/tests/test_target_cortex_a_advisor.py
index 6e370d6..59d54b5 100644
--- a/tests/test_target_cortex_a_advisor.py
+++ b/tests/test_target_cortex_a_advisor.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for Cortex-A MLIA module."""
from pathlib import Path
@@ -46,7 +46,8 @@ def test_configure_and_get_cortex_a_advisor(test_tflite_model: Path) -> None:
"optimization_type": "clustering",
},
]
- ]
+ ],
+ "training_parameters": [None],
},
}
diff --git a/tests/test_target_registry.py b/tests/test_target_registry.py
index ca1ad82..120d0f5 100644
--- a/tests/test_target_registry.py
+++ b/tests/test_target_registry.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the target registry module."""
from __future__ import annotations
@@ -6,9 +6,11 @@ from __future__ import annotations
import pytest
from mlia.core.common import AdviceCategory
-from mlia.target.config import get_builtin_profile_path
+from mlia.target.config import get_builtin_optimization_profile_path
+from mlia.target.config import get_builtin_target_profile_path
from mlia.target.registry import all_supported_backends
from mlia.target.registry import default_backends
+from mlia.target.registry import get_optimization_profile
from mlia.target.registry import is_supported
from mlia.target.registry import profile
from mlia.target.registry import registry
@@ -146,6 +148,27 @@ def test_profile(target_profile: str) -> None:
assert target_profile.startswith(cfg.target)
# Test loading the file directly
- profile_file = get_builtin_profile_path(target_profile)
+ profile_file = get_builtin_target_profile_path(target_profile)
cfg = profile(profile_file)
assert target_profile.startswith(cfg.target)
+
+
+@pytest.mark.parametrize("optimization_profile", ["optimization"])
+def test_optimization_profile(optimization_profile: str) -> None:
+ """Test function optimization_profile()."""
+
+ get_optimization_profile(optimization_profile)
+
+ profile_file = get_builtin_optimization_profile_path(optimization_profile)
+ get_optimization_profile(profile_file)
+
+
+@pytest.mark.parametrize("optimization_profile", ["non_valid_file"])
+def test_optimization_profile_non_valid_file(optimization_profile: str) -> None:
+ """Test function optimization_profile()."""
+ with pytest.raises(
+ ValueError,
+ match=f"optimization Profile '{optimization_profile}' is neither "
+ "a valid built-in optimization profile name or a valid file path.",
+ ):
+ get_optimization_profile(optimization_profile)
diff --git a/tests/test_target_tosa_advisor.py b/tests/test_target_tosa_advisor.py
index 36e52e9..cc47321 100644
--- a/tests/test_target_tosa_advisor.py
+++ b/tests/test_target_tosa_advisor.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for TOSA advisor."""
from pathlib import Path
@@ -46,7 +46,8 @@ def test_configure_and_get_tosa_advisor(
"optimization_type": "clustering",
},
]
- ]
+ ],
+ "training_parameters": [None],
},
"tosa_inference_advisor": {
"model": str(test_tflite_model),
diff --git a/tests/test_utils_filesystem.py b/tests/test_utils_filesystem.py
index c1c9876..1ccbd1c 100644
--- a/tests/test_utils_filesystem.py
+++ b/tests/test_utils_filesystem.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the filesystem module."""
import contextlib
@@ -10,6 +10,7 @@ from mlia.utils.filesystem import all_files_exist
from mlia.utils.filesystem import all_paths_valid
from mlia.utils.filesystem import copy_all
from mlia.utils.filesystem import get_mlia_resources
+from mlia.utils.filesystem import get_mlia_target_optimization_dir
from mlia.utils.filesystem import get_mlia_target_profiles_dir
from mlia.utils.filesystem import get_vela_config
from mlia.utils.filesystem import recreate_directory
@@ -37,6 +38,11 @@ def test_get_mlia_target_profiles() -> None:
assert get_mlia_target_profiles_dir().is_dir()
+def test_get_mlia_target_optimizations() -> None:
+ """Test target profiles getter."""
+ assert get_mlia_target_optimization_dir().is_dir()
+
+
@pytest.mark.parametrize("raise_exception", [True, False])
def test_temp_file(raise_exception: bool) -> None:
"""Test temp_file context manager."""