aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/target/ethos_u/data_collection.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/target/ethos_u/data_collection.py')
-rw-r--r--src/mlia/target/ethos_u/data_collection.py13
1 files changed, 12 insertions, 1 deletions
diff --git a/src/mlia/target/ethos_u/data_collection.py b/src/mlia/target/ethos_u/data_collection.py
index 4fdfe96..8348393 100644
--- a/src/mlia/target/ethos_u/data_collection.py
+++ b/src/mlia/target/ethos_u/data_collection.py
@@ -17,6 +17,9 @@ from mlia.nn.tensorflow.config import get_tflite_model
from mlia.nn.tensorflow.config import KerasModel
from mlia.nn.tensorflow.optimizations.select import get_optimizer
from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+from mlia.nn.tensorflow.tflite_compat import TFLiteChecker
+from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
+from mlia.nn.tensorflow.utils import is_tflite_model
from mlia.nn.tensorflow.utils import save_keras_model
from mlia.target.ethos_u.config import EthosUConfiguration
from mlia.target.ethos_u.performance import EthosUPerformanceEstimator
@@ -36,8 +39,16 @@ class EthosUOperatorCompatibility(ContextAwareDataCollector):
self.model = model
self.target_config = target_config
- def collect_data(self) -> Operators:
+ def collect_data(self) -> Operators | TFLiteCompatibilityInfo | None:
"""Collect operator compatibility information."""
+ if not is_tflite_model(self.model):
+ with log_action("Checking TensorFlow Lite compatibility ..."):
+ tflite_checker = TFLiteChecker()
+ tflite_compat = tflite_checker.check_compatibility(self.model)
+
+ if not tflite_compat.compatible:
+ return tflite_compat
+
tflite_model = get_tflite_model(self.model, self.context)
with log_action("Checking operator compatibility ..."):