diff options
author | Nathan Bailey <nathan.bailey@arm.com> | 2024-03-25 13:05:32 +0000 |
---|---|---|
committer | Nathan Bailey <nathan.bailey@arm.com> | 2024-04-12 14:08:07 +0000 |
commit | ec59b3c95106daebe2ce0e57592b2bf9e6562f54 (patch) | |
tree | 3a3861c73d5963cb8ef1d21dd6929e24123fc898 /src/mlia/target/common/optimization.py | |
parent | 4de782fde8e38ec92bb5bc60e156de027f13bfba (diff) | |
download | mlia-ec59b3c95106daebe2ce0e57592b2bf9e6562f54.tar.gz |
fix: Change training_parameters to return empty list instead of list of None if needed.
Extension to MLIA-1004
Signed-off-by: Nathan Bailey <nathan.bailey@arm.com>
Change-Id: Ib40c2e5932c1210a1d141200815a76e33f5ab078
Diffstat (limited to 'src/mlia/target/common/optimization.py')
-rw-r--r-- | src/mlia/target/common/optimization.py | 13 |
1 files changed, 6 insertions, 7 deletions
diff --git a/src/mlia/target/common/optimization.py b/src/mlia/target/common/optimization.py index 8c5d184..1423189 100644 --- a/src/mlia/target/common/optimization.py +++ b/src/mlia/target/common/optimization.py @@ -86,7 +86,7 @@ class OptimizingDataCollector(ContextAwareDataCollector): def optimize_model( self, opt_settings: list[OptimizationSettings], - training_parameters: list[dict | None], + training_parameters: dict | None, model: KerasModel | TFLiteModel, ) -> Any: """Run optimization.""" @@ -123,12 +123,12 @@ class OptimizingDataCollector(ContextAwareDataCollector): context=context, ) - def _get_training_settings(self, context: Context) -> list[dict]: + def _get_training_settings(self, context: Context) -> dict: """Get optimization settings.""" return self.get_parameter( # type: ignore OptimizingDataCollector.name(), "training_parameters", - expected_type=list, + expected_type=dict, expected=False, context=context, ) @@ -228,9 +228,8 @@ def add_common_optimization_params(advisor_parameters: dict, extra_args: dict) - raise TypeError("Optimization targets value has wrong format.") rewrite_parameters = extra_args.get("optimization_profile") - if not rewrite_parameters: - training_parameters = None - else: + training_parameters = None + if rewrite_parameters: if not isinstance(rewrite_parameters, dict): raise TypeError("Training Parameter values has wrong format.") training_parameters = extra_args["optimization_profile"].get("training") @@ -239,7 +238,7 @@ def add_common_optimization_params(advisor_parameters: dict, extra_args: dict) - { "common_optimizations": { "optimizations": [optimization_targets], - "training_parameters": [training_parameters], + "training_parameters": training_parameters, }, } ) |