diff options
Diffstat (limited to 'src/mlia/target/ethos_u/data_collection.py')
-rw-r--r-- | src/mlia/target/ethos_u/data_collection.py | 13 |
1 files changed, 12 insertions, 1 deletions
diff --git a/src/mlia/target/ethos_u/data_collection.py b/src/mlia/target/ethos_u/data_collection.py index 4fdfe96..8348393 100644 --- a/src/mlia/target/ethos_u/data_collection.py +++ b/src/mlia/target/ethos_u/data_collection.py @@ -17,6 +17,9 @@ from mlia.nn.tensorflow.config import get_tflite_model from mlia.nn.tensorflow.config import KerasModel from mlia.nn.tensorflow.optimizations.select import get_optimizer from mlia.nn.tensorflow.optimizations.select import OptimizationSettings +from mlia.nn.tensorflow.tflite_compat import TFLiteChecker +from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo +from mlia.nn.tensorflow.utils import is_tflite_model from mlia.nn.tensorflow.utils import save_keras_model from mlia.target.ethos_u.config import EthosUConfiguration from mlia.target.ethos_u.performance import EthosUPerformanceEstimator @@ -36,8 +39,16 @@ class EthosUOperatorCompatibility(ContextAwareDataCollector): self.model = model self.target_config = target_config - def collect_data(self) -> Operators: + def collect_data(self) -> Operators | TFLiteCompatibilityInfo | None: """Collect operator compatibility information.""" + if not is_tflite_model(self.model): + with log_action("Checking TensorFlow Lite compatibility ..."): + tflite_checker = TFLiteChecker() + tflite_compat = tflite_checker.check_compatibility(self.model) + + if not tflite_compat.compatible: + return tflite_compat + tflite_model = get_tflite_model(self.model, self.context) with log_action("Checking operator compatibility ..."): |