aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/target/ethos_u/data_collection.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/target/ethos_u/data_collection.py')
-rw-r--r--src/mlia/target/ethos_u/data_collection.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/src/mlia/target/ethos_u/data_collection.py b/src/mlia/target/ethos_u/data_collection.py
index ba8b0fe..4ea6120 100644
--- a/src/mlia/target/ethos_u/data_collection.py
+++ b/src/mlia/target/ethos_u/data_collection.py
@@ -106,15 +106,14 @@ class OptimizeModel:
self.context = context
self.opt_settings = opt_settings
- def __call__(self, keras_model: KerasModel) -> Any:
+ def __call__(self, model: KerasModel | TFLiteModel) -> Any:
"""Run optimization."""
- optimizer = get_optimizer(keras_model, self.opt_settings)
+ optimizer = get_optimizer(model, self.opt_settings)
opts_as_str = ", ".join(str(opt) for opt in self.opt_settings)
logger.info("Applying model optimizations - [%s]", opts_as_str)
optimizer.apply_optimization()
-
- model = optimizer.get_model()
+ model = optimizer.get_model() # type: ignore
if isinstance(model, Path):
return model
@@ -178,6 +177,7 @@ class EthosUOptimizationPerformance(ContextAwareDataCollector):
self.target,
self.backends,
)
+
original_metrics, *optimized_metrics = estimate_performance(
model, estimator, optimizers # type: ignore
)