diff options
Diffstat (limited to 'src/mlia/target/ethos_u')
-rw-r--r-- | src/mlia/target/ethos_u/advice_generation.py | 20 | ||||
-rw-r--r-- | src/mlia/target/ethos_u/data_analysis.py | 9 | ||||
-rw-r--r-- | src/mlia/target/ethos_u/data_collection.py | 13 | ||||
-rw-r--r-- | src/mlia/target/ethos_u/handlers.py | 4 | ||||
-rw-r--r-- | src/mlia/target/ethos_u/reporters.py | 34 |
5 files changed, 66 insertions, 14 deletions
diff --git a/src/mlia/target/ethos_u/advice_generation.py b/src/mlia/target/ethos_u/advice_generation.py index daae4f4..a9f9eac 100644 --- a/src/mlia/target/ethos_u/advice_generation.py +++ b/src/mlia/target/ethos_u/advice_generation.py @@ -12,6 +12,10 @@ from mlia.core.advice_generation import FactBasedAdviceProducer from mlia.core.common import AdviceCategory from mlia.core.common import DataItem from mlia.nn.tensorflow.optimizations.select import OptimizationSettings +from mlia.target.common.reporters import handle_model_is_not_tflite_compatible_common +from mlia.target.common.reporters import handle_tflite_check_failed_common +from mlia.target.common.reporters import ModelIsNotTFLiteCompatible +from mlia.target.common.reporters import TFLiteCompatibilityCheckFailed from mlia.target.ethos_u.data_analysis import AllOperatorsSupportedOnNPU from mlia.target.ethos_u.data_analysis import HasCPUOnlyOperators from mlia.target.ethos_u.data_analysis import HasUnsupportedOnNPUOperators @@ -147,6 +151,22 @@ class EthosUAdviceProducer(FactBasedAdviceProducer): ] ) + @produce_advice.register + @advice_category(AdviceCategory.COMPATIBILITY) + def handle_model_is_not_tflite_compatible( + self, data_item: ModelIsNotTFLiteCompatible + ) -> None: + """Advice for TensorFlow Lite compatibility.""" + handle_model_is_not_tflite_compatible_common(self, data_item) + + @produce_advice.register + @advice_category(AdviceCategory.COMPATIBILITY) + def handle_tflite_check_failed( + self, _data_item: TFLiteCompatibilityCheckFailed + ) -> None: + """Advice for the failed TensorFlow Lite compatibility checks.""" + handle_tflite_check_failed_common(self, _data_item) + @staticmethod def get_next_optimization_targets( opt_type: list[OptimizationSettings], diff --git a/src/mlia/target/ethos_u/data_analysis.py b/src/mlia/target/ethos_u/data_analysis.py index 6b66734..3df4bff 100644 --- a/src/mlia/target/ethos_u/data_analysis.py +++ b/src/mlia/target/ethos_u/data_analysis.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Ethos-U data analysis module.""" from __future__ import annotations @@ -11,6 +11,8 @@ from mlia.core.common import DataItem from mlia.core.data_analysis import Fact from mlia.core.data_analysis import FactExtractor from mlia.nn.tensorflow.optimizations.select import OptimizationSettings +from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo +from mlia.target.common.reporters import analyze_tflite_compatibility_common from mlia.target.ethos_u.performance import OptimizationPerformanceMetrics @@ -151,3 +153,8 @@ class EthosUDataAnalyzer(FactExtractor): diffs.append(diff) self.add_fact(OptimizationResults(diffs)) + + @analyze_data.register + def analyze_tflite_compatibility(self, data_item: TFLiteCompatibilityInfo) -> None: + """Analyze TensorFlow Lite compatibility information.""" + analyze_tflite_compatibility_common(self, data_item) diff --git a/src/mlia/target/ethos_u/data_collection.py b/src/mlia/target/ethos_u/data_collection.py index 4fdfe96..8348393 100644 --- a/src/mlia/target/ethos_u/data_collection.py +++ b/src/mlia/target/ethos_u/data_collection.py @@ -17,6 +17,9 @@ from mlia.nn.tensorflow.config import get_tflite_model from mlia.nn.tensorflow.config import KerasModel from mlia.nn.tensorflow.optimizations.select import get_optimizer from mlia.nn.tensorflow.optimizations.select import OptimizationSettings +from mlia.nn.tensorflow.tflite_compat import TFLiteChecker +from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo +from mlia.nn.tensorflow.utils import is_tflite_model from mlia.nn.tensorflow.utils import save_keras_model from mlia.target.ethos_u.config import EthosUConfiguration from mlia.target.ethos_u.performance import EthosUPerformanceEstimator @@ -36,8 +39,16 @@ class EthosUOperatorCompatibility(ContextAwareDataCollector): self.model = model self.target_config = target_config - def collect_data(self) -> Operators: + def collect_data(self) -> Operators | TFLiteCompatibilityInfo | None: """Collect operator compatibility information.""" + if not is_tflite_model(self.model): + with log_action("Checking TensorFlow Lite compatibility ..."): + tflite_checker = TFLiteChecker() + tflite_compat = tflite_checker.check_compatibility(self.model) + + if not tflite_compat.compatible: + return tflite_compat + tflite_model = get_tflite_model(self.model, self.context) with log_action("Checking operator compatibility ..."): diff --git a/src/mlia/target/ethos_u/handlers.py b/src/mlia/target/ethos_u/handlers.py index b9c89e8..1b15c55 100644 --- a/src/mlia/target/ethos_u/handlers.py +++ b/src/mlia/target/ethos_u/handlers.py @@ -8,6 +8,7 @@ import logging from mlia.backend.vela.compat import Operators from mlia.core.events import CollectedDataEvent from mlia.core.handlers import WorkflowEventsHandler +from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo from mlia.target.ethos_u.events import EthosUAdvisorEventHandler from mlia.target.ethos_u.events import EthosUAdvisorStartedEvent from mlia.target.ethos_u.performance import OptimizationPerformanceMetrics @@ -49,6 +50,9 @@ class EthosUEventHandler(WorkflowEventsHandler, EthosUAdvisorEventHandler): space=True, ) + if isinstance(data_item, TFLiteCompatibilityInfo) and not data_item.compatible: + self.reporter.submit(data_item, delay_print=True) + def on_ethos_u_advisor_started(self, event: EthosUAdvisorStartedEvent) -> None: """Handle EthosUAdvisorStarted event.""" self.reporter.submit(event.target_config) diff --git a/src/mlia/target/ethos_u/reporters.py b/src/mlia/target/ethos_u/reporters.py index 2a5b5d3..4964462 100644 --- a/src/mlia/target/ethos_u/reporters.py +++ b/src/mlia/target/ethos_u/reporters.py @@ -23,6 +23,8 @@ from mlia.core.reporting import Report from mlia.core.reporting import ReportItem from mlia.core.reporting import SingleRow from mlia.core.reporting import Table +from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo +from mlia.target.common.reporters import report_tflite_compatiblity from mlia.target.ethos_u.config import EthosUConfiguration from mlia.target.ethos_u.performance import PerformanceMetrics from mlia.utils.console import style_improvement @@ -363,23 +365,31 @@ def report_perf_metrics( def ethos_u_formatters(data: Any) -> Callable[[Any], Report]: """Find appropriate formatter for the provided data.""" + report: Callable[[Any], Report] | None = None + if isinstance(data, PerformanceMetrics) or is_list_of(data, PerformanceMetrics, 2): - return report_perf_metrics + report = report_perf_metrics - if is_list_of(data, Advice): - return report_advice + elif is_list_of(data, Advice): + report = report_advice - if is_list_of(data, Operator): - return report_operators + elif is_list_of(data, Operator): + report = report_operators - if isinstance(data, Operators): - return report_operators_stat + elif isinstance(data, Operators): + report = report_operators_stat - if isinstance(data, EthosUConfiguration): - return report_target_details + elif isinstance(data, EthosUConfiguration): + report = report_target_details - if isinstance(data, (list, tuple)): + elif isinstance(data, (list, tuple)): formatters = [ethos_u_formatters(item) for item in data] - return CompoundFormatter(formatters) + report = CompoundFormatter(formatters) + + elif isinstance(data, TFLiteCompatibilityInfo): + report = report_tflite_compatiblity + + else: + raise Exception(f"Unable to find appropriate formatter for {data}") - raise Exception(f"Unable to find appropriate formatter for {data}") + return report # type: ignore |