aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorNathan Bailey <nathan.bailey@arm.com>2024-02-15 14:50:58 +0000
committerNathan Bailey <nathan.bailey@arm.com>2024-03-14 15:45:40 +0000
commit0b552d2ae47da4fb9c16d2a59d6ebe12c8307771 (patch)
tree09b40b939acbe0bcf02dcc77a7ed7ce4aba94322 /src
parent09b272be6e88d84a30cb89fb71f3fc3c64d20d2e (diff)
downloadmlia-0b552d2ae47da4fb9c16d2a59d6ebe12c8307771.tar.gz
feat: Enable rewrite parameterisation
Enables user to provide a toml or default profile to change training settings for rewrite optimization Resolves: MLIA-1004 Signed-off-by: Nathan Bailey <nathan.bailey@arm.com> Change-Id: I3bf9f44b9a2062fb71ef36eb32c9a69edcc48061
Diffstat (limited to 'src')
-rw-r--r--src/mlia/api.py10
-rw-r--r--src/mlia/cli/commands.py4
-rw-r--r--src/mlia/cli/helpers.py17
-rw-r--r--src/mlia/cli/main.py20
-rw-r--r--src/mlia/cli/options.py13
-rw-r--r--src/mlia/nn/select.py33
-rw-r--r--src/mlia/resources/optimization_profiles/optimization.toml11
-rw-r--r--src/mlia/target/common/optimization.py36
-rw-r--r--src/mlia/target/config.py28
-rw-r--r--src/mlia/target/registry.py39
-rw-r--r--src/mlia/utils/filesystem.py5
11 files changed, 177 insertions, 39 deletions
diff --git a/src/mlia/api.py b/src/mlia/api.py
index 7adae48..3901f56 100644
--- a/src/mlia/api.py
+++ b/src/mlia/api.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for the API functions."""
from __future__ import annotations
@@ -10,6 +10,7 @@ from typing import Any
from mlia.core.advisor import InferenceAdvisor
from mlia.core.common import AdviceCategory
from mlia.core.context import ExecutionContext
+from mlia.target.registry import get_optimization_profile
from mlia.target.registry import profile
from mlia.target.registry import registry as target_registry
@@ -20,6 +21,7 @@ def get_advice(
target_profile: str,
model: str | Path,
category: set[str],
+ optimization_profile: str | None = None,
optimization_targets: list[dict[str, Any]] | None = None,
context: ExecutionContext | None = None,
backends: list[str] | None = None,
@@ -69,9 +71,9 @@ def get_advice(
target_profile,
model,
optimization_targets=optimization_targets,
+ optimization_profile=optimization_profile,
backends=backends,
)
-
advisor.run(context)
@@ -82,6 +84,10 @@ def get_advisor(
**extra_args: Any,
) -> InferenceAdvisor:
"""Find appropriate advisor for the target."""
+ if extra_args.get("optimization_profile"):
+ extra_args["optimization_profile"] = get_optimization_profile(
+ extra_args["optimization_profile"]
+ )
target = profile(target_profile).target
factory_function = target_registry.items[target].advisor_factory_func
return factory_function(
diff --git a/src/mlia/cli/commands.py b/src/mlia/cli/commands.py
index 7af41d9..fcba302 100644
--- a/src/mlia/cli/commands.py
+++ b/src/mlia/cli/commands.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""CLI commands module.
@@ -104,6 +104,7 @@ def optimize( # pylint: disable=too-many-locals,too-many-arguments
clustering: bool,
pruning_target: float | None,
clustering_target: int | None,
+ optimization_profile: str | None = None,
rewrite: bool | None = None,
rewrite_target: str | None = None,
rewrite_start: str | None = None,
@@ -166,6 +167,7 @@ def optimize( # pylint: disable=too-many-locals,too-many-arguments
model,
{"optimization"},
optimization_targets=opt_params,
+ optimization_profile=optimization_profile,
context=ctx,
backends=validated_backend,
)
diff --git a/src/mlia/cli/helpers.py b/src/mlia/cli/helpers.py
index 824db1b..7b38577 100644
--- a/src/mlia/cli/helpers.py
+++ b/src/mlia/cli/helpers.py
@@ -1,8 +1,9 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for various helpers."""
from __future__ import annotations
+import importlib
from pathlib import Path
from shutil import copy
from typing import Any
@@ -12,8 +13,6 @@ from mlia.cli.options import get_target_profile_opts
from mlia.core.helpers import ActionResolver
from mlia.nn.select import OptimizationSettings
from mlia.nn.tensorflow.utils import is_keras_model
-from mlia.target.config import get_builtin_profile_path
-from mlia.target.config import is_builtin_profile
from mlia.utils.types import is_list_of
@@ -120,12 +119,16 @@ class CLIActionResolver(ActionResolver):
def copy_profile_file_to_output_dir(
- target_profile: str | Path, output_dir: str | Path
+ target_profile: str | Path, output_dir: str | Path, profile_to_copy: str
) -> bool:
"""Copy the target profile file to the output directory."""
+ get_func_name = "get_builtin_" + profile_to_copy + "_path"
+ get_func = getattr(importlib.import_module("mlia.target.config"), get_func_name)
+ is_func_name = "is_builtin_" + profile_to_copy
+ is_func = getattr(importlib.import_module("mlia.target.config"), is_func_name)
profile_file_path = (
- get_builtin_profile_path(cast(str, target_profile))
- if is_builtin_profile(target_profile)
+ get_func(cast(str, target_profile))
+ if is_func(target_profile)
else Path(target_profile)
)
output_file_path = f"{output_dir}/{profile_file_path.stem}.toml"
@@ -133,4 +136,4 @@ def copy_profile_file_to_output_dir(
copy(profile_file_path, output_file_path)
return True
except OSError as err:
- raise RuntimeError("Failed to copy profile file:", err.strerror) from err
+ raise RuntimeError(f"Failed to copy {profile_to_copy} file: {err}") from err
diff --git a/src/mlia/cli/main.py b/src/mlia/cli/main.py
index 9e1b7cd..32d46a6 100644
--- a/src/mlia/cli/main.py
+++ b/src/mlia/cli/main.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""CLI main entry point."""
from __future__ import annotations
@@ -203,11 +203,17 @@ def run_command(args: argparse.Namespace) -> int:
try:
logger.info("ML Inference Advisor %s", __version__)
- if copy_profile_file(ctx, func_args):
+ if copy_profile_file(ctx, func_args, "target_profile"):
logger.info(
"\nThe target profile (.toml) is copied to the output directory: %s",
ctx.output_dir,
)
+ if copy_profile_file(ctx, func_args, "optimization_profile"):
+ logger.info(
+ "\nThe optimization profile (.toml) is copied to "
+ "the output directory: %s",
+ ctx.output_dir,
+ )
args.func(**func_args)
return 0
except KeyboardInterrupt:
@@ -278,11 +284,13 @@ def init_and_run(commands: list[CommandInfo], argv: list[str] | None = None) ->
return run_command(args)
-def copy_profile_file(ctx: ExecutionContext, func_args: dict) -> bool:
- """If present, copy the target profile file to the output directory."""
- if func_args.get("target_profile"):
+def copy_profile_file(
+ ctx: ExecutionContext, func_args: dict, profile_to_copy: str
+) -> bool:
+ """If present, copy the selected profile file to the output directory."""
+ if func_args.get(profile_to_copy):
return copy_profile_file_to_output_dir(
- func_args["target_profile"], ctx.output_dir
+ func_args[profile_to_copy], ctx.output_dir, profile_to_copy
)
return False
diff --git a/src/mlia/cli/options.py b/src/mlia/cli/options.py
index 57f54dd..1c55fed 100644
--- a/src/mlia/cli/options.py
+++ b/src/mlia/cli/options.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for the CLI options."""
from __future__ import annotations
@@ -15,6 +15,7 @@ from mlia.core.common import AdviceCategory
from mlia.core.errors import ConfigurationError
from mlia.core.typing import OutputFormat
from mlia.nn.rewrite.core.rewrite import RewritingOptimizer
+from mlia.target.registry import builtin_optimization_names
from mlia.target.registry import builtin_profile_names
from mlia.target.registry import registry as target_registry
@@ -130,6 +131,16 @@ def add_multi_optimization_options(parser: argparse.ArgumentParser) -> None:
help="Ending node in the graph of the subgraph to be rewritten.",
)
+ optimization_profiles = builtin_optimization_names()
+ multi_optimization_group.add_argument(
+ "-o",
+ "--optimization-profile",
+ required=False,
+ default="optimization",
+ help="Built-in optimization profile or path to the custom profile. "
+ f"Built-in optimization profiles are {', '.join(optimization_profiles)}. ",
+ )
+
def add_model_options(parser: argparse.ArgumentParser) -> None:
"""Add model specific options."""
diff --git a/src/mlia/nn/select.py b/src/mlia/nn/select.py
index 6947206..20950cc 100644
--- a/src/mlia/nn/select.py
+++ b/src/mlia/nn/select.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for optimization selection."""
from __future__ import annotations
@@ -117,6 +117,7 @@ class MultiStageOptimizer(Optimizer):
def get_optimizer(
model: tf.keras.Model | KerasModel | TFLiteModel,
config: OptimizerConfiguration | OptimizationSettings | list[OptimizationSettings],
+ training_parameters: list[dict | None] | None = None,
) -> Optimizer:
"""Get optimizer for provided configuration."""
if isinstance(model, KerasModel):
@@ -135,10 +136,14 @@ def get_optimizer(
return RewritingOptimizer(model, config)
if isinstance(config, OptimizationSettings):
- return _get_optimizer(model, cast(OptimizationSettings, config))
+ return _get_optimizer(
+ model, cast(OptimizationSettings, config), training_parameters
+ )
if is_list_of(config, OptimizationSettings):
- return _get_optimizer(model, cast(List[OptimizationSettings], config))
+ return _get_optimizer(
+ model, cast(List[OptimizationSettings], config), training_parameters
+ )
raise ConfigurationError(f"Unknown optimization configuration {config}")
@@ -146,16 +151,18 @@ def get_optimizer(
def _get_optimizer(
model: tf.keras.Model | Path,
optimization_settings: OptimizationSettings | list[OptimizationSettings],
+ training_parameters: list[dict | None] | None = None,
) -> Optimizer:
if isinstance(optimization_settings, OptimizationSettings):
optimization_settings = [optimization_settings]
optimizer_configs = []
+
for opt_type, opt_target, layers_to_optimize, dataset in optimization_settings:
_check_optimizer_params(opt_type, opt_target)
opt_config = _get_optimizer_configuration(
- opt_type, opt_target, layers_to_optimize, dataset
+ opt_type, opt_target, layers_to_optimize, dataset, training_parameters
)
optimizer_configs.append(opt_config)
@@ -165,13 +172,23 @@ def _get_optimizer(
return MultiStageOptimizer(model, optimizer_configs)
-def _get_rewrite_train_params() -> TrainingParameters:
+def _get_rewrite_params(
+ training_parameters: list[dict | None] | None = None,
+) -> list:
"""Get the rewrite TrainingParameters.
Return the default constructed TrainingParameters() per default, but can be
overwritten in the unit tests.
"""
- return TrainingParameters()
+ if training_parameters is None:
+ return [TrainingParameters()]
+
+ if training_parameters[0] is None:
+ train_params = TrainingParameters()
+ else:
+ train_params = TrainingParameters(**training_parameters[0])
+
+ return [train_params]
def _get_optimizer_configuration(
@@ -179,6 +196,7 @@ def _get_optimizer_configuration(
optimization_target: int | float | str,
layers_to_optimize: list[str] | None = None,
dataset: Path | None = None,
+ training_parameters: list[dict | None] | None = None,
) -> OptimizerConfiguration:
"""Get optimizer configuration for provided parameters."""
_check_optimizer_params(optimization_type, optimization_target)
@@ -199,11 +217,12 @@ def _get_optimizer_configuration(
if opt_type == "rewrite":
if isinstance(optimization_target, str):
+ rewrite_params = _get_rewrite_params(training_parameters)
return RewriteConfiguration(
optimization_target=str(optimization_target),
layers_to_optimize=layers_to_optimize,
dataset=dataset,
- train_params=_get_rewrite_train_params(),
+ train_params=rewrite_params[0],
)
raise ConfigurationError(
diff --git a/src/mlia/resources/optimization_profiles/optimization.toml b/src/mlia/resources/optimization_profiles/optimization.toml
new file mode 100644
index 0000000..623a763
--- /dev/null
+++ b/src/mlia/resources/optimization_profiles/optimization.toml
@@ -0,0 +1,11 @@
+# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+
+[training]
+batch_size = 32
+learning_rate = 1e-3
+show_progress = true
+steps = 48000
+learning_rate_schedule = "cosine"
+num_procs = 1
+num_threads = 0
diff --git a/src/mlia/target/common/optimization.py b/src/mlia/target/common/optimization.py
index 5f359c5..8c5d184 100644
--- a/src/mlia/target/common/optimization.py
+++ b/src/mlia/target/common/optimization.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Data collector support for performance optimizations."""
from __future__ import annotations
@@ -50,6 +50,8 @@ class OptimizingDataCollector(ContextAwareDataCollector):
optimizations = self._get_optimization_settings(self.context)
+ training_parameters = self._get_training_settings(self.context)
+
if not optimizations or optimizations == [[]]:
raise FunctionalityNotSupportedError(
reason="No optimization targets provided",
@@ -75,17 +77,22 @@ class OptimizingDataCollector(ContextAwareDataCollector):
model = self.model # type: ignore
optimizers: list[Callable] = [
- partial(self.optimize_model, opts) for opts in opt_settings
+ partial(self.optimize_model, opts, training_parameters)
+ for opts in opt_settings
]
return self.optimize_and_estimate_performance(model, optimizers, opt_settings)
def optimize_model(
- self, opt_settings: list[OptimizationSettings], model: KerasModel | TFLiteModel
+ self,
+ opt_settings: list[OptimizationSettings],
+ training_parameters: list[dict | None],
+ model: KerasModel | TFLiteModel,
) -> Any:
"""Run optimization."""
- optimizer = get_optimizer(model, opt_settings)
-
+ optimizer = get_optimizer(
+ model, opt_settings, training_parameters=training_parameters
+ )
opts_as_str = ", ".join(str(opt) for opt in opt_settings)
logger.info("Applying model optimizations - [%s]", opts_as_str)
optimizer.apply_optimization()
@@ -116,6 +123,16 @@ class OptimizingDataCollector(ContextAwareDataCollector):
context=context,
)
+ def _get_training_settings(self, context: Context) -> list[dict]:
+ """Get optimization settings."""
+ return self.get_parameter( # type: ignore
+ OptimizingDataCollector.name(),
+ "training_parameters",
+ expected_type=list,
+ expected=False,
+ context=context,
+ )
+
@staticmethod
def _parse_optimization_params(
optimizations: list[list[dict]],
@@ -210,10 +227,19 @@ def add_common_optimization_params(advisor_parameters: dict, extra_args: dict) -
if not is_list_of(optimization_targets, dict):
raise TypeError("Optimization targets value has wrong format.")
+ rewrite_parameters = extra_args.get("optimization_profile")
+ if not rewrite_parameters:
+ training_parameters = None
+ else:
+ if not isinstance(rewrite_parameters, dict):
+ raise TypeError("Training Parameter values has wrong format.")
+ training_parameters = extra_args["optimization_profile"].get("training")
+
advisor_parameters.update(
{
"common_optimizations": {
"optimizations": [optimization_targets],
+ "training_parameters": [training_parameters],
},
}
)
diff --git a/src/mlia/target/config.py b/src/mlia/target/config.py
index 3bc74fa..8ccdad8 100644
--- a/src/mlia/target/config.py
+++ b/src/mlia/target/config.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Target configuration module."""
from __future__ import annotations
@@ -22,9 +22,10 @@ from mlia.backend.registry import registry as backend_registry
from mlia.core.common import AdviceCategory
from mlia.core.advisor import InferenceAdvisor
from mlia.utils.filesystem import get_mlia_target_profiles_dir
+from mlia.utils.filesystem import get_mlia_target_optimization_dir
-def get_builtin_profile_path(target_profile: str) -> Path:
+def get_builtin_target_profile_path(target_profile: str) -> Path:
"""
Construct the path to the built-in target profile file.
@@ -33,6 +34,15 @@ def get_builtin_profile_path(target_profile: str) -> Path:
return get_mlia_target_profiles_dir() / f"{target_profile}.toml"
+def get_builtin_optimization_profile_path(optimization_profile: str) -> Path:
+ """
+ Construct the path to the built-in optimization profile file.
+
+ No checks are performed.
+ """
+ return get_mlia_target_optimization_dir() / f"{optimization_profile}.toml"
+
+
@lru_cache
def load_profile(path: str | Path) -> dict[str, Any]:
"""Get settings for the provided target profile."""
@@ -56,11 +66,19 @@ def get_builtin_supported_profile_names() -> list[str]:
BUILTIN_SUPPORTED_PROFILE_NAMES = get_builtin_supported_profile_names()
-def is_builtin_profile(profile_name: str | Path) -> bool:
+def is_builtin_target_profile(profile_name: str | Path) -> bool:
"""Check if the given profile name belongs to a built-in profile."""
return profile_name in BUILTIN_SUPPORTED_PROFILE_NAMES
+BUILTIN_SUPPORTED_OPTIMIZATION_NAMES = ["optimization"]
+
+
+def is_builtin_optimization_profile(optimization_name: str | Path) -> bool:
+ """Check if the given optimization name belongs to a built-in optimization."""
+ return optimization_name in BUILTIN_SUPPORTED_OPTIMIZATION_NAMES
+
+
T = TypeVar("T", bound="TargetProfile")
@@ -93,8 +111,8 @@ class TargetProfile(ABC):
@classmethod
def load_profile(cls: type[T], target_profile: str | Path) -> T:
"""Load a target profile from built-in target profile name or file path."""
- if is_builtin_profile(target_profile):
- profile_file = get_builtin_profile_path(cast(str, target_profile))
+ if is_builtin_target_profile(target_profile):
+ profile_file = get_builtin_target_profile_path(cast(str, target_profile))
else:
profile_file = Path(target_profile)
return cls.load(profile_file)
diff --git a/src/mlia/target/registry.py b/src/mlia/target/registry.py
index b7b6193..b850284 100644
--- a/src/mlia/target/registry.py
+++ b/src/mlia/target/registry.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Target module."""
from __future__ import annotations
@@ -13,9 +13,12 @@ from mlia.backend.registry import registry as backend_registry
from mlia.core.common import AdviceCategory
from mlia.core.reporting import Column
from mlia.core.reporting import Table
+from mlia.target.config import BUILTIN_SUPPORTED_OPTIMIZATION_NAMES
from mlia.target.config import BUILTIN_SUPPORTED_PROFILE_NAMES
-from mlia.target.config import get_builtin_profile_path
-from mlia.target.config import is_builtin_profile
+from mlia.target.config import get_builtin_optimization_profile_path
+from mlia.target.config import get_builtin_target_profile_path
+from mlia.target.config import is_builtin_optimization_profile
+from mlia.target.config import is_builtin_target_profile
from mlia.target.config import load_profile
from mlia.target.config import TargetInfo
from mlia.target.config import TargetProfile
@@ -44,13 +47,18 @@ def builtin_profile_names() -> list[str]:
return BUILTIN_SUPPORTED_PROFILE_NAMES
+def builtin_optimization_names() -> list[str]:
+ """Return a list of built-in profile names (not file paths)."""
+ return BUILTIN_SUPPORTED_OPTIMIZATION_NAMES
+
+
@lru_cache
def profile(target_profile: str | Path) -> TargetProfile:
"""Get the target profile data (built-in or custom file)."""
if not target_profile:
raise ValueError("No valid target profile was provided.")
- if is_builtin_profile(target_profile):
- profile_file = get_builtin_profile_path(cast(str, target_profile))
+ if is_builtin_target_profile(target_profile):
+ profile_file = get_builtin_target_profile_path(cast(str, target_profile))
profile_ = create_target_profile(profile_file)
else:
profile_file = Path(target_profile)
@@ -65,6 +73,27 @@ def profile(target_profile: str | Path) -> TargetProfile:
return profile_
+def get_optimization_profile(optimization_profile: str | Path) -> dict:
+ """Get the optimization profile data (built-in or custom file)."""
+ if not optimization_profile:
+ raise ValueError("No valid optimization profile was provided.")
+ if is_builtin_optimization_profile(optimization_profile):
+ profile_file = get_builtin_optimization_profile_path(
+ cast(str, optimization_profile)
+ )
+ profile_dict = load_profile(profile_file)
+ else:
+ profile_file = Path(optimization_profile)
+ if profile_file.is_file():
+ profile_dict = load_profile(profile_file)
+ else:
+ raise ValueError(
+ f"optimization Profile '{optimization_profile}' is neither a valid "
+ "built-in optimization profile name or a valid file path."
+ )
+ return profile_dict
+
+
def get_target(target_profile: str | Path) -> str:
"""Return target for the provided target_profile."""
return profile(target_profile).target
diff --git a/src/mlia/utils/filesystem.py b/src/mlia/utils/filesystem.py
index e3ef7db..5348d06 100644
--- a/src/mlia/utils/filesystem.py
+++ b/src/mlia/utils/filesystem.py
@@ -34,6 +34,11 @@ def get_mlia_target_profiles_dir() -> Path:
return get_mlia_resources() / "target_profiles"
+def get_mlia_target_optimization_dir() -> Path:
+ """Get the profiles file."""
+ return get_mlia_resources() / "optimization_profiles"
+
+
@contextmanager
def temp_file(suffix: str | None = None) -> Generator[Path, None, None]:
"""Create temp file and remove it after."""