aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/target/ethos_u
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/target/ethos_u')
-rw-r--r--src/mlia/target/ethos_u/advice_generation.py20
-rw-r--r--src/mlia/target/ethos_u/data_analysis.py9
-rw-r--r--src/mlia/target/ethos_u/data_collection.py13
-rw-r--r--src/mlia/target/ethos_u/handlers.py4
-rw-r--r--src/mlia/target/ethos_u/reporters.py34
5 files changed, 66 insertions, 14 deletions
diff --git a/src/mlia/target/ethos_u/advice_generation.py b/src/mlia/target/ethos_u/advice_generation.py
index daae4f4..a9f9eac 100644
--- a/src/mlia/target/ethos_u/advice_generation.py
+++ b/src/mlia/target/ethos_u/advice_generation.py
@@ -12,6 +12,10 @@ from mlia.core.advice_generation import FactBasedAdviceProducer
from mlia.core.common import AdviceCategory
from mlia.core.common import DataItem
from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+from mlia.target.common.reporters import handle_model_is_not_tflite_compatible_common
+from mlia.target.common.reporters import handle_tflite_check_failed_common
+from mlia.target.common.reporters import ModelIsNotTFLiteCompatible
+from mlia.target.common.reporters import TFLiteCompatibilityCheckFailed
from mlia.target.ethos_u.data_analysis import AllOperatorsSupportedOnNPU
from mlia.target.ethos_u.data_analysis import HasCPUOnlyOperators
from mlia.target.ethos_u.data_analysis import HasUnsupportedOnNPUOperators
@@ -147,6 +151,22 @@ class EthosUAdviceProducer(FactBasedAdviceProducer):
]
)
+ @produce_advice.register
+ @advice_category(AdviceCategory.COMPATIBILITY)
+ def handle_model_is_not_tflite_compatible(
+ self, data_item: ModelIsNotTFLiteCompatible
+ ) -> None:
+ """Advice for TensorFlow Lite compatibility."""
+ handle_model_is_not_tflite_compatible_common(self, data_item)
+
+ @produce_advice.register
+ @advice_category(AdviceCategory.COMPATIBILITY)
+ def handle_tflite_check_failed(
+ self, _data_item: TFLiteCompatibilityCheckFailed
+ ) -> None:
+ """Advice for the failed TensorFlow Lite compatibility checks."""
+ handle_tflite_check_failed_common(self, _data_item)
+
@staticmethod
def get_next_optimization_targets(
opt_type: list[OptimizationSettings],
diff --git a/src/mlia/target/ethos_u/data_analysis.py b/src/mlia/target/ethos_u/data_analysis.py
index 6b66734..3df4bff 100644
--- a/src/mlia/target/ethos_u/data_analysis.py
+++ b/src/mlia/target/ethos_u/data_analysis.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Ethos-U data analysis module."""
from __future__ import annotations
@@ -11,6 +11,8 @@ from mlia.core.common import DataItem
from mlia.core.data_analysis import Fact
from mlia.core.data_analysis import FactExtractor
from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
+from mlia.target.common.reporters import analyze_tflite_compatibility_common
from mlia.target.ethos_u.performance import OptimizationPerformanceMetrics
@@ -151,3 +153,8 @@ class EthosUDataAnalyzer(FactExtractor):
diffs.append(diff)
self.add_fact(OptimizationResults(diffs))
+
+ @analyze_data.register
+ def analyze_tflite_compatibility(self, data_item: TFLiteCompatibilityInfo) -> None:
+ """Analyze TensorFlow Lite compatibility information."""
+ analyze_tflite_compatibility_common(self, data_item)
diff --git a/src/mlia/target/ethos_u/data_collection.py b/src/mlia/target/ethos_u/data_collection.py
index 4fdfe96..8348393 100644
--- a/src/mlia/target/ethos_u/data_collection.py
+++ b/src/mlia/target/ethos_u/data_collection.py
@@ -17,6 +17,9 @@ from mlia.nn.tensorflow.config import get_tflite_model
from mlia.nn.tensorflow.config import KerasModel
from mlia.nn.tensorflow.optimizations.select import get_optimizer
from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+from mlia.nn.tensorflow.tflite_compat import TFLiteChecker
+from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
+from mlia.nn.tensorflow.utils import is_tflite_model
from mlia.nn.tensorflow.utils import save_keras_model
from mlia.target.ethos_u.config import EthosUConfiguration
from mlia.target.ethos_u.performance import EthosUPerformanceEstimator
@@ -36,8 +39,16 @@ class EthosUOperatorCompatibility(ContextAwareDataCollector):
self.model = model
self.target_config = target_config
- def collect_data(self) -> Operators:
+ def collect_data(self) -> Operators | TFLiteCompatibilityInfo | None:
"""Collect operator compatibility information."""
+ if not is_tflite_model(self.model):
+ with log_action("Checking TensorFlow Lite compatibility ..."):
+ tflite_checker = TFLiteChecker()
+ tflite_compat = tflite_checker.check_compatibility(self.model)
+
+ if not tflite_compat.compatible:
+ return tflite_compat
+
tflite_model = get_tflite_model(self.model, self.context)
with log_action("Checking operator compatibility ..."):
diff --git a/src/mlia/target/ethos_u/handlers.py b/src/mlia/target/ethos_u/handlers.py
index b9c89e8..1b15c55 100644
--- a/src/mlia/target/ethos_u/handlers.py
+++ b/src/mlia/target/ethos_u/handlers.py
@@ -8,6 +8,7 @@ import logging
from mlia.backend.vela.compat import Operators
from mlia.core.events import CollectedDataEvent
from mlia.core.handlers import WorkflowEventsHandler
+from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
from mlia.target.ethos_u.events import EthosUAdvisorEventHandler
from mlia.target.ethos_u.events import EthosUAdvisorStartedEvent
from mlia.target.ethos_u.performance import OptimizationPerformanceMetrics
@@ -49,6 +50,9 @@ class EthosUEventHandler(WorkflowEventsHandler, EthosUAdvisorEventHandler):
space=True,
)
+ if isinstance(data_item, TFLiteCompatibilityInfo) and not data_item.compatible:
+ self.reporter.submit(data_item, delay_print=True)
+
def on_ethos_u_advisor_started(self, event: EthosUAdvisorStartedEvent) -> None:
"""Handle EthosUAdvisorStarted event."""
self.reporter.submit(event.target_config)
diff --git a/src/mlia/target/ethos_u/reporters.py b/src/mlia/target/ethos_u/reporters.py
index 2a5b5d3..4964462 100644
--- a/src/mlia/target/ethos_u/reporters.py
+++ b/src/mlia/target/ethos_u/reporters.py
@@ -23,6 +23,8 @@ from mlia.core.reporting import Report
from mlia.core.reporting import ReportItem
from mlia.core.reporting import SingleRow
from mlia.core.reporting import Table
+from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
+from mlia.target.common.reporters import report_tflite_compatiblity
from mlia.target.ethos_u.config import EthosUConfiguration
from mlia.target.ethos_u.performance import PerformanceMetrics
from mlia.utils.console import style_improvement
@@ -363,23 +365,31 @@ def report_perf_metrics(
def ethos_u_formatters(data: Any) -> Callable[[Any], Report]:
"""Find appropriate formatter for the provided data."""
+ report: Callable[[Any], Report] | None = None
+
if isinstance(data, PerformanceMetrics) or is_list_of(data, PerformanceMetrics, 2):
- return report_perf_metrics
+ report = report_perf_metrics
- if is_list_of(data, Advice):
- return report_advice
+ elif is_list_of(data, Advice):
+ report = report_advice
- if is_list_of(data, Operator):
- return report_operators
+ elif is_list_of(data, Operator):
+ report = report_operators
- if isinstance(data, Operators):
- return report_operators_stat
+ elif isinstance(data, Operators):
+ report = report_operators_stat
- if isinstance(data, EthosUConfiguration):
- return report_target_details
+ elif isinstance(data, EthosUConfiguration):
+ report = report_target_details
- if isinstance(data, (list, tuple)):
+ elif isinstance(data, (list, tuple)):
formatters = [ethos_u_formatters(item) for item in data]
- return CompoundFormatter(formatters)
+ report = CompoundFormatter(formatters)
+
+ elif isinstance(data, TFLiteCompatibilityInfo):
+ report = report_tflite_compatiblity
+
+ else:
+ raise Exception(f"Unable to find appropriate formatter for {data}")
- raise Exception(f"Unable to find appropriate formatter for {data}")
+ return report # type: ignore