aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/target/common/optimization.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/target/common/optimization.py')
-rw-r--r--src/mlia/target/common/optimization.py36
1 files changed, 31 insertions, 5 deletions
diff --git a/src/mlia/target/common/optimization.py b/src/mlia/target/common/optimization.py
index 5f359c5..8c5d184 100644
--- a/src/mlia/target/common/optimization.py
+++ b/src/mlia/target/common/optimization.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Data collector support for performance optimizations."""
from __future__ import annotations
@@ -50,6 +50,8 @@ class OptimizingDataCollector(ContextAwareDataCollector):
optimizations = self._get_optimization_settings(self.context)
+ training_parameters = self._get_training_settings(self.context)
+
if not optimizations or optimizations == [[]]:
raise FunctionalityNotSupportedError(
reason="No optimization targets provided",
@@ -75,17 +77,22 @@ class OptimizingDataCollector(ContextAwareDataCollector):
model = self.model # type: ignore
optimizers: list[Callable] = [
- partial(self.optimize_model, opts) for opts in opt_settings
+ partial(self.optimize_model, opts, training_parameters)
+ for opts in opt_settings
]
return self.optimize_and_estimate_performance(model, optimizers, opt_settings)
def optimize_model(
- self, opt_settings: list[OptimizationSettings], model: KerasModel | TFLiteModel
+ self,
+ opt_settings: list[OptimizationSettings],
+ training_parameters: list[dict | None],
+ model: KerasModel | TFLiteModel,
) -> Any:
"""Run optimization."""
- optimizer = get_optimizer(model, opt_settings)
-
+ optimizer = get_optimizer(
+ model, opt_settings, training_parameters=training_parameters
+ )
opts_as_str = ", ".join(str(opt) for opt in opt_settings)
logger.info("Applying model optimizations - [%s]", opts_as_str)
optimizer.apply_optimization()
@@ -116,6 +123,16 @@ class OptimizingDataCollector(ContextAwareDataCollector):
context=context,
)
+ def _get_training_settings(self, context: Context) -> list[dict]:
+ """Get optimization settings."""
+ return self.get_parameter( # type: ignore
+ OptimizingDataCollector.name(),
+ "training_parameters",
+ expected_type=list,
+ expected=False,
+ context=context,
+ )
+
@staticmethod
def _parse_optimization_params(
optimizations: list[list[dict]],
@@ -210,10 +227,19 @@ def add_common_optimization_params(advisor_parameters: dict, extra_args: dict) -
if not is_list_of(optimization_targets, dict):
raise TypeError("Optimization targets value has wrong format.")
+ rewrite_parameters = extra_args.get("optimization_profile")
+ if not rewrite_parameters:
+ training_parameters = None
+ else:
+ if not isinstance(rewrite_parameters, dict):
+ raise TypeError("Training Parameter values has wrong format.")
+ training_parameters = extra_args["optimization_profile"].get("training")
+
advisor_parameters.update(
{
"common_optimizations": {
"optimizations": [optimization_targets],
+ "training_parameters": [training_parameters],
},
}
)