diff options
author | Nathan Bailey <nathan.bailey@arm.com> | 2024-03-25 13:05:32 +0000 |
---|---|---|
committer | Nathan Bailey <nathan.bailey@arm.com> | 2024-04-12 14:08:07 +0000 |
commit | ec59b3c95106daebe2ce0e57592b2bf9e6562f54 (patch) | |
tree | 3a3861c73d5963cb8ef1d21dd6929e24123fc898 | |
parent | 4de782fde8e38ec92bb5bc60e156de027f13bfba (diff) | |
download | mlia-ec59b3c95106daebe2ce0e57592b2bf9e6562f54.tar.gz |
fix: Change training_parameters to return empty list instead of list of None if needed.
Extension to MLIA-1004
Signed-off-by: Nathan Bailey <nathan.bailey@arm.com>
Change-Id: Ib40c2e5932c1210a1d141200815a76e33f5ab078
-rw-r--r-- | src/mlia/nn/select.py | 23 | ||||
-rw-r--r-- | src/mlia/target/common/optimization.py | 13 | ||||
-rw-r--r-- | tests/test_common_optimization.py | 18 | ||||
-rw-r--r-- | tests/test_nn_select.py | 12 | ||||
-rw-r--r-- | tests/test_target_cortex_a_advisor.py | 2 | ||||
-rw-r--r-- | tests/test_target_tosa_advisor.py | 2 |
6 files changed, 32 insertions, 38 deletions
diff --git a/src/mlia/nn/select.py b/src/mlia/nn/select.py index 81a614f..b61e713 100644 --- a/src/mlia/nn/select.py +++ b/src/mlia/nn/select.py @@ -117,7 +117,7 @@ class MultiStageOptimizer(Optimizer): def get_optimizer( model: keras.Model | KerasModel | TFLiteModel, config: OptimizerConfiguration | OptimizationSettings | list[OptimizationSettings], - training_parameters: list[dict | None] | None = None, + training_parameters: dict | None = None, ) -> Optimizer: """Get optimizer for provided configuration.""" if isinstance(model, KerasModel): @@ -151,7 +151,7 @@ def get_optimizer( def _get_optimizer( model: keras.Model | Path, optimization_settings: OptimizationSettings | list[OptimizationSettings], - training_parameters: list[dict | None] | None = None, + training_parameters: dict | None = None, ) -> Optimizer: if isinstance(optimization_settings, OptimizationSettings): optimization_settings = [optimization_settings] @@ -173,22 +173,17 @@ def _get_optimizer( def _get_rewrite_params( - training_parameters: list[dict | None] | None = None, -) -> list: + training_parameters: dict | None = None, +) -> TrainingParameters: """Get the rewrite TrainingParameters. Return the default constructed TrainingParameters() per default, but can be overwritten in the unit tests. """ - if training_parameters is None: - return [TrainingParameters()] + if not training_parameters: + return TrainingParameters() - if training_parameters[0] is None: - train_params = TrainingParameters() - else: - train_params = TrainingParameters(**training_parameters[0]) - - return [train_params] + return TrainingParameters(**training_parameters) def _get_optimizer_configuration( @@ -196,7 +191,7 @@ def _get_optimizer_configuration( optimization_target: int | float | str, layers_to_optimize: list[str] | None = None, dataset: Path | None = None, - training_parameters: list[dict | None] | None = None, + training_parameters: dict | None = None, ) -> OptimizerConfiguration: """Get optimizer configuration for provided parameters.""" _check_optimizer_params(optimization_type, optimization_target) @@ -222,7 +217,7 @@ def _get_optimizer_configuration( optimization_target=str(optimization_target), layers_to_optimize=layers_to_optimize, dataset=dataset, - train_params=rewrite_params[0], + train_params=rewrite_params, ) raise ConfigurationError( diff --git a/src/mlia/target/common/optimization.py b/src/mlia/target/common/optimization.py index 8c5d184..1423189 100644 --- a/src/mlia/target/common/optimization.py +++ b/src/mlia/target/common/optimization.py @@ -86,7 +86,7 @@ class OptimizingDataCollector(ContextAwareDataCollector): def optimize_model( self, opt_settings: list[OptimizationSettings], - training_parameters: list[dict | None], + training_parameters: dict | None, model: KerasModel | TFLiteModel, ) -> Any: """Run optimization.""" @@ -123,12 +123,12 @@ class OptimizingDataCollector(ContextAwareDataCollector): context=context, ) - def _get_training_settings(self, context: Context) -> list[dict]: + def _get_training_settings(self, context: Context) -> dict: """Get optimization settings.""" return self.get_parameter( # type: ignore OptimizingDataCollector.name(), "training_parameters", - expected_type=list, + expected_type=dict, expected=False, context=context, ) @@ -228,9 +228,8 @@ def add_common_optimization_params(advisor_parameters: dict, extra_args: dict) - raise TypeError("Optimization targets value has wrong format.") rewrite_parameters = extra_args.get("optimization_profile") - if not rewrite_parameters: - training_parameters = None - else: + training_parameters = None + if rewrite_parameters: if not isinstance(rewrite_parameters, dict): raise TypeError("Training Parameter values has wrong format.") training_parameters = extra_args["optimization_profile"].get("training") @@ -239,7 +238,7 @@ def add_common_optimization_params(advisor_parameters: dict, extra_args: dict) - { "common_optimizations": { "optimizations": [optimization_targets], - "training_parameters": [training_parameters], + "training_parameters": training_parameters, }, } ) diff --git a/tests/test_common_optimization.py b/tests/test_common_optimization.py index 05a5b55..341e0d2 100644 --- a/tests/test_common_optimization.py +++ b/tests/test_common_optimization.py @@ -57,7 +57,7 @@ def test_optimizing_data_collector( config_parameters={ "common_optimizations": { "optimizations": optimizations, - "training_parameters": [training_parameters], + "training_parameters": training_parameters, } } ) @@ -94,7 +94,7 @@ def test_optimizing_data_collector( collector.set_context(context) collector.collect_data() assert optimize_model_mock.call_args.args[0] == opt_settings[0] - assert optimize_model_mock.call_args.args[1] == [training_parameters] + assert optimize_model_mock.call_args.args[1] == training_parameters assert fake_optimizer.invocation_count == 1 @@ -158,10 +158,12 @@ def test_add_common_optimization_params(extra_args: dict, error_to_raise: Any) - ] if not extra_args.get("optimization_profile"): - assert advisor_parameters["common_optimizations"][ - "training_parameters" - ] == [None] + assert ( + advisor_parameters["common_optimizations"]["training_parameters"] + is None + ) else: - assert advisor_parameters["common_optimizations"][ - "training_parameters" - ] == list(extra_args["optimization_profile"].values()) + assert ( + advisor_parameters["common_optimizations"]["training_parameters"] + == extra_args["optimization_profile"]["training"] + ) diff --git a/tests/test_nn_select.py b/tests/test_nn_select.py index aac07b4..4095076 100644 --- a/tests/test_nn_select.py +++ b/tests/test_nn_select.py @@ -183,11 +183,11 @@ def test_get_optimizer( @pytest.mark.parametrize( "rewrite_parameters", - [[None], [{"batch_size": 64, "learning_rate": 0.003}]], + [None, {"batch_size": 64, "learning_rate": 0.003}], ) @pytest.mark.skip_set_training_steps def test_get_optimizer_training_parameters( - rewrite_parameters: list[dict], test_tflite_model: Path + rewrite_parameters: dict | None, test_tflite_model: Path ) -> None: """Test function get_optimzer with various combinations of parameters.""" config = OptimizationSettings( @@ -198,20 +198,18 @@ def test_get_optimizer_training_parameters( ) optimizer = cast( RewritingOptimizer, - get_optimizer(test_tflite_model, config, list(rewrite_parameters)), + get_optimizer(test_tflite_model, config, rewrite_parameters), ) - assert len(rewrite_parameters) == 1 - assert isinstance( optimizer.optimizer_configuration.train_params, TrainingParameters ) - if not rewrite_parameters[0]: + if not rewrite_parameters: assert asdict(TrainingParameters()) == asdict( optimizer.optimizer_configuration.train_params ) else: - assert asdict(TrainingParameters()) | rewrite_parameters[0] == asdict( + assert asdict(TrainingParameters()) | rewrite_parameters == asdict( optimizer.optimizer_configuration.train_params ) diff --git a/tests/test_target_cortex_a_advisor.py b/tests/test_target_cortex_a_advisor.py index 59d54b5..7bb57c3 100644 --- a/tests/test_target_cortex_a_advisor.py +++ b/tests/test_target_cortex_a_advisor.py @@ -47,7 +47,7 @@ def test_configure_and_get_cortex_a_advisor(test_tflite_model: Path) -> None: }, ] ], - "training_parameters": [None], + "training_parameters": None, }, } diff --git a/tests/test_target_tosa_advisor.py b/tests/test_target_tosa_advisor.py index cc47321..020acc5 100644 --- a/tests/test_target_tosa_advisor.py +++ b/tests/test_target_tosa_advisor.py @@ -47,7 +47,7 @@ def test_configure_and_get_tosa_advisor( }, ] ], - "training_parameters": [None], + "training_parameters": None, }, "tosa_inference_advisor": { "model": str(test_tflite_model), |