diff options
Diffstat (limited to 'src/mlia/target')
-rw-r--r-- | src/mlia/target/ethos_u/data_collection.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/src/mlia/target/ethos_u/data_collection.py b/src/mlia/target/ethos_u/data_collection.py index ba8b0fe..4ea6120 100644 --- a/src/mlia/target/ethos_u/data_collection.py +++ b/src/mlia/target/ethos_u/data_collection.py @@ -106,15 +106,14 @@ class OptimizeModel: self.context = context self.opt_settings = opt_settings - def __call__(self, keras_model: KerasModel) -> Any: + def __call__(self, model: KerasModel | TFLiteModel) -> Any: """Run optimization.""" - optimizer = get_optimizer(keras_model, self.opt_settings) + optimizer = get_optimizer(model, self.opt_settings) opts_as_str = ", ".join(str(opt) for opt in self.opt_settings) logger.info("Applying model optimizations - [%s]", opts_as_str) optimizer.apply_optimization() - - model = optimizer.get_model() + model = optimizer.get_model() # type: ignore if isinstance(model, Path): return model @@ -178,6 +177,7 @@ class EthosUOptimizationPerformance(ContextAwareDataCollector): self.target, self.backends, ) + original_metrics, *optimized_metrics = estimate_performance( model, estimator, optimizers # type: ignore ) |