author     Nathan Bailey <nathan.bailey@arm.com>   2024-03-25 13:05:32 +0000
committer  Nathan Bailey <nathan.bailey@arm.com>   2024-04-12 14:08:07 +0000
commit     ec59b3c95106daebe2ce0e57592b2bf9e6562f54 (patch)
tree       3a3861c73d5963cb8ef1d21dd6929e24123fc898
parent     4de782fde8e38ec92bb5bc60e156de027f13bfba (diff)
fix: Change training_parameters to return an empty list instead of a list of None where needed

Extension to MLIA-1004

Signed-off-by: Nathan Bailey <nathan.bailey@arm.com>
Change-Id: Ib40c2e5932c1210a1d141200815a76e33f5ab078
-rw-r--r--   src/mlia/nn/select.py                    23
-rw-r--r--   src/mlia/target/common/optimization.py   13
-rw-r--r--   tests/test_common_optimization.py        18
-rw-r--r--   tests/test_nn_select.py                  12
-rw-r--r--   tests/test_target_cortex_a_advisor.py     2
-rw-r--r--   tests/test_target_tosa_advisor.py         2
6 files changed, 32 insertions(+), 38 deletions(-)
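Taken together, the patch changes the shape of training_parameters end to end: it now travels through the advisor as a plain dict (or None) rather than a single-element list. A minimal before/after sketch, using the parameter values exercised in the tests below:

    # Before: the training overrides were wrapped in a one-element list
    "training_parameters": [{"batch_size": 64, "learning_rate": 0.003}]   # or [None]

    # After: the dict (or None) is stored and passed directly
    "training_parameters": {"batch_size": 64, "learning_rate": 0.003}     # or None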
diff --git a/src/mlia/nn/select.py b/src/mlia/nn/select.py
index 81a614f..b61e713 100644
--- a/src/mlia/nn/select.py
+++ b/src/mlia/nn/select.py
@@ -117,7 +117,7 @@ class MultiStageOptimizer(Optimizer):
 def get_optimizer(
     model: keras.Model | KerasModel | TFLiteModel,
     config: OptimizerConfiguration | OptimizationSettings | list[OptimizationSettings],
-    training_parameters: list[dict | None] | None = None,
+    training_parameters: dict | None = None,
 ) -> Optimizer:
     """Get optimizer for provided configuration."""
     if isinstance(model, KerasModel):
@@ -151,7 +151,7 @@ def get_optimizer(
 def _get_optimizer(
     model: keras.Model | Path,
     optimization_settings: OptimizationSettings | list[OptimizationSettings],
-    training_parameters: list[dict | None] | None = None,
+    training_parameters: dict | None = None,
 ) -> Optimizer:
     if isinstance(optimization_settings, OptimizationSettings):
         optimization_settings = [optimization_settings]
@@ -173,22 +173,17 @@ def _get_optimizer(
 def _get_rewrite_params(
-    training_parameters: list[dict | None] | None = None,
-) -> list:
+    training_parameters: dict | None = None,
+) -> TrainingParameters:
     """Get the rewrite TrainingParameters.

     Return the default constructed TrainingParameters() per default, but can be
     overwritten in the unit tests.
     """
-    if training_parameters is None:
-        return [TrainingParameters()]
+    if not training_parameters:
+        return TrainingParameters()

-    if training_parameters[0] is None:
-        train_params = TrainingParameters()
-    else:
-        train_params = TrainingParameters(**training_parameters[0])
-
-    return [train_params]
+    return TrainingParameters(**training_parameters)


 def _get_optimizer_configuration(
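Note that the rewritten _get_rewrite_params also tightens the semantics slightly: the `not training_parameters` check treats an empty dict the same as None, so both fall back to the defaults. A quick sketch of the three cases, per the code above:

    _get_rewrite_params(None)                  # -> TrainingParameters() with all defaults
    _get_rewrite_params({})                    # -> TrainingParameters() as well, since {} is falsy
    _get_rewrite_params({"batch_size": 64})    # -> TrainingParameters(batch_size=64), rest default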
@@ -196,7 +191,7 @@ def _get_optimizer_configuration(
     optimization_target: int | float | str,
     layers_to_optimize: list[str] | None = None,
     dataset: Path | None = None,
-    training_parameters: list[dict | None] | None = None,
+    training_parameters: dict | None = None,
 ) -> OptimizerConfiguration:
     """Get optimizer configuration for provided parameters."""
     _check_optimizer_params(optimization_type, optimization_target)
@@ -222,7 +217,7 @@ def _get_optimizer_configuration(
             optimization_target=str(optimization_target),
             layers_to_optimize=layers_to_optimize,
             dataset=dataset,
-            train_params=rewrite_params[0],
+            train_params=rewrite_params,
         )

     raise ConfigurationError(
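With the select.py changes in place, callers hand the training overrides to get_optimizer directly. A usage sketch, where model stands for a loaded TFLite model and config for a rewrite OptimizationSettings instance (values borrowed from tests/test_nn_select.py below):

    optimizer = get_optimizer(model, config, {"batch_size": 64, "learning_rate": 0.003})
    optimizer = get_optimizer(model, config, None)   # falls back to default TrainingParameters()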
diff --git a/src/mlia/target/common/optimization.py b/src/mlia/target/common/optimization.py
index 8c5d184..1423189 100644
--- a/src/mlia/target/common/optimization.py
+++ b/src/mlia/target/common/optimization.py
@@ -86,7 +86,7 @@ class OptimizingDataCollector(ContextAwareDataCollector):
     def optimize_model(
         self,
         opt_settings: list[OptimizationSettings],
-        training_parameters: list[dict | None],
+        training_parameters: dict | None,
         model: KerasModel | TFLiteModel,
     ) -> Any:
         """Run optimization."""
"""Run optimization."""
@@ -123,12 +123,12 @@ class OptimizingDataCollector(ContextAwareDataCollector):
             context=context,
         )

-    def _get_training_settings(self, context: Context) -> list[dict]:
+    def _get_training_settings(self, context: Context) -> dict:
         """Get optimization settings."""
         return self.get_parameter(  # type: ignore
             OptimizingDataCollector.name(),
             "training_parameters",
-            expected_type=list,
+            expected_type=dict,
             expected=False,
             context=context,
         )
@@ -228,9 +228,8 @@ def add_common_optimization_params(advisor_parameters: dict, extra_args: dict) -
         raise TypeError("Optimization targets value has wrong format.")

     rewrite_parameters = extra_args.get("optimization_profile")
-    if not rewrite_parameters:
-        training_parameters = None
-    else:
+    training_parameters = None
+    if rewrite_parameters:
         if not isinstance(rewrite_parameters, dict):
             raise TypeError("Training Parameter values has wrong format.")
         training_parameters = extra_args["optimization_profile"].get("training")
@@ -239,7 +238,7 @@ def add_common_optimization_params(advisor_parameters: dict, extra_args: dict) -
         {
             "common_optimizations": {
                 "optimizations": [optimization_targets],
-                "training_parameters": [training_parameters],
+                "training_parameters": training_parameters,
             },
         }
     )
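The collector-facing effect of the optimization.py changes, sketched from the update call above and the assertions in the tests that follow (extra_args keys other than optimization_profile elided):

    extra_args = {"optimization_profile": {"training": {"batch_size": 64}}}
    # add_common_optimization_params(advisor_parameters, extra_args) now stores
    #     "training_parameters": {"batch_size": 64}    # previously [{"batch_size": 64}]
    # and with no optimization profile supplied:
    #     "training_parameters": None                  # previously [None]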
diff --git a/tests/test_common_optimization.py b/tests/test_common_optimization.py
index 05a5b55..341e0d2 100644
--- a/tests/test_common_optimization.py
+++ b/tests/test_common_optimization.py
@@ -57,7 +57,7 @@ def test_optimizing_data_collector(
         config_parameters={
             "common_optimizations": {
                 "optimizations": optimizations,
-                "training_parameters": [training_parameters],
+                "training_parameters": training_parameters,
             }
         }
     )
@@ -94,7 +94,7 @@ def test_optimizing_data_collector(
     collector.set_context(context)
     collector.collect_data()

     assert optimize_model_mock.call_args.args[0] == opt_settings[0]
-    assert optimize_model_mock.call_args.args[1] == [training_parameters]
+    assert optimize_model_mock.call_args.args[1] == training_parameters
     assert fake_optimizer.invocation_count == 1
@@ -158,10 +158,12 @@ def test_add_common_optimization_params(extra_args: dict, error_to_raise: Any) -
     ]

     if not extra_args.get("optimization_profile"):
-        assert advisor_parameters["common_optimizations"][
-            "training_parameters"
-        ] == [None]
+        assert (
+            advisor_parameters["common_optimizations"]["training_parameters"]
+            is None
+        )
     else:
-        assert advisor_parameters["common_optimizations"][
-            "training_parameters"
-        ] == list(extra_args["optimization_profile"].values())
+        assert (
+            advisor_parameters["common_optimizations"]["training_parameters"]
+            == extra_args["optimization_profile"]["training"]
+        )
diff --git a/tests/test_nn_select.py b/tests/test_nn_select.py
index aac07b4..4095076 100644
--- a/tests/test_nn_select.py
+++ b/tests/test_nn_select.py
@@ -183,11 +183,11 @@ def test_get_optimizer(
 @pytest.mark.parametrize(
     "rewrite_parameters",
-    [[None], [{"batch_size": 64, "learning_rate": 0.003}]],
+    [None, {"batch_size": 64, "learning_rate": 0.003}],
 )
 @pytest.mark.skip_set_training_steps
 def test_get_optimizer_training_parameters(
-    rewrite_parameters: list[dict], test_tflite_model: Path
+    rewrite_parameters: dict | None, test_tflite_model: Path
 ) -> None:
     """Test function get_optimzer with various combinations of parameters."""
     config = OptimizationSettings(
@@ -198,20 +198,18 @@ def test_get_optimizer_training_parameters(
     )

     optimizer = cast(
         RewritingOptimizer,
-        get_optimizer(test_tflite_model, config, list(rewrite_parameters)),
+        get_optimizer(test_tflite_model, config, rewrite_parameters),
     )

-    assert len(rewrite_parameters) == 1
-
     assert isinstance(
         optimizer.optimizer_configuration.train_params, TrainingParameters
     )
-    if not rewrite_parameters[0]:
+    if not rewrite_parameters:
         assert asdict(TrainingParameters()) == asdict(
             optimizer.optimizer_configuration.train_params
         )
     else:
-        assert asdict(TrainingParameters()) | rewrite_parameters[0] == asdict(
+        assert asdict(TrainingParameters()) | rewrite_parameters == asdict(
             optimizer.optimizer_configuration.train_params
         )
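The surviving assertion leans on PEP 584 dict union: asdict(TrainingParameters()) | rewrite_parameters yields the defaults with the supplied keys overriding them, which is exactly what TrainingParameters(**rewrite_parameters) should produce. A standalone illustration with plain dicts (the default values here are made up, not the real TrainingParameters defaults):

    defaults = {"batch_size": 32, "learning_rate": 1e-3, "steps": 48000}
    overrides = {"batch_size": 64, "learning_rate": 0.003}
    assert defaults | overrides == {"batch_size": 64, "learning_rate": 0.003, "steps": 48000}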
diff --git a/tests/test_target_cortex_a_advisor.py b/tests/test_target_cortex_a_advisor.py
index 59d54b5..7bb57c3 100644
--- a/tests/test_target_cortex_a_advisor.py
+++ b/tests/test_target_cortex_a_advisor.py
@@ -47,7 +47,7 @@ def test_configure_and_get_cortex_a_advisor(test_tflite_model: Path) -> None:
                     },
                 ]
             ],
-            "training_parameters": [None],
+            "training_parameters": None,
         },
     }
diff --git a/tests/test_target_tosa_advisor.py b/tests/test_target_tosa_advisor.py
index cc47321..020acc5 100644
--- a/tests/test_target_tosa_advisor.py
+++ b/tests/test_target_tosa_advisor.py
@@ -47,7 +47,7 @@ def test_configure_and_get_tosa_advisor(
                     },
                 ]
             ],
-            "training_parameters": [None],
+            "training_parameters": None,
         },
         "tosa_inference_advisor": {
             "model": str(test_tflite_model),