aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDhruv Chauhan <dhruv.chauhan@arm.com>2023-03-20 10:22:08 +0000
committerDhruv Chauhan <dhruv.chauhan@arm.com>2023-03-24 12:44:23 +0000
commitedf436c48029aa4e2b4ca5d17eee5a8f07ecbd6f (patch)
tree2375038e77873f6cd499b8938bc8b816daea3fc8
parent803a91c0723533f62148528a81f9d0411b57438e (diff)
downloadmlia-edf436c48029aa4e2b4ca5d17eee5a8f07ecbd6f.tar.gz
MLIA-711 Extend TensorFlow Lite Compatibility Check
- Unify the TensorFlow Lite compatibility check across Cortex-A, TOSA and Ethos-U targets - Display tables/messages with parsed information - Do not display raw TensorFlow Lite errors, and return with exit code 0 Change-Id: I9333fdb6cbe592f1ed7395d392412168492a1479
-rw-r--r--src/mlia/target/common/reporters.py159
-rw-r--r--src/mlia/target/cortex_a/advice_generation.py53
-rw-r--r--src/mlia/target/cortex_a/data_analysis.py36
-rw-r--r--src/mlia/target/cortex_a/reporters.py54
-rw-r--r--src/mlia/target/ethos_u/advice_generation.py20
-rw-r--r--src/mlia/target/ethos_u/data_analysis.py9
-rw-r--r--src/mlia/target/ethos_u/data_collection.py13
-rw-r--r--src/mlia/target/ethos_u/handlers.py4
-rw-r--r--src/mlia/target/ethos_u/reporters.py34
-rw-r--r--src/mlia/target/tosa/advice_generation.py20
-rw-r--r--src/mlia/target/tosa/data_analysis.py11
-rw-r--r--src/mlia/target/tosa/data_collection.py17
-rw-r--r--src/mlia/target/tosa/handlers.py4
-rw-r--r--src/mlia/target/tosa/reporters.py5
-rw-r--r--tests/test_target_cortex_a_advice_generation.py6
-rw-r--r--tests/test_target_cortex_a_data_analysis.py6
-rw-r--r--tests/test_target_ethos_u_data_analysis.py52
-rw-r--r--tests/test_target_ethos_u_reporters.py24
-rw-r--r--tests/test_target_tosa_data_analysis.py54
19 files changed, 425 insertions, 156 deletions
diff --git a/src/mlia/target/common/reporters.py b/src/mlia/target/common/reporters.py
new file mode 100644
index 0000000..366e154
--- /dev/null
+++ b/src/mlia/target/common/reporters.py
@@ -0,0 +1,159 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Common reports module."""
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from mlia.core.data_analysis import Fact
+from mlia.core.reporting import Column
+from mlia.core.reporting import Format
+from mlia.core.reporting import Report
+from mlia.core.reporting import Table
+from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
+
+
+@dataclass
+class ModelIsNotTFLiteCompatible(Fact):
+ """Model could not be converted into TensorFlow Lite format."""
+
+ custom_ops: list[str] | None = None
+ flex_ops: list[str] | None = None
+
+
+@dataclass
+class TFLiteCompatibilityCheckFailed(Fact):
+ """TensorFlow Lite compatibility check failed by unknown reason."""
+
+
+@dataclass
+class ModelHasCustomOperators(Fact):
+ """Model could not be loaded because it contains custom ops."""
+
+
+def report_tflite_compatiblity(compat_info: TFLiteCompatibilityInfo) -> Report:
+ """Generate report for the TensorFlow Lite compatibility information."""
+ if compat_info.conversion_errors:
+ return Table(
+ [
+ Column("#", only_for=["plain_text"]),
+ Column("Operator", alias="operator"),
+ Column(
+ "Operator location",
+ alias="operator_location",
+ fmt=Format(wrap_width=25),
+ ),
+ Column("Error code", alias="error_code"),
+ Column(
+ "Error message", alias="error_message", fmt=Format(wrap_width=25)
+ ),
+ ],
+ [
+ (
+ index + 1,
+ err.operator,
+ ", ".join(err.location),
+ err.code.name,
+ err.message,
+ )
+ for index, err in enumerate(compat_info.conversion_errors)
+ ],
+ name="TensorFlow Lite conversion errors",
+ alias="tensorflow_lite_conversion_errors",
+ )
+
+ return Table(
+ columns=[
+ Column("Reason", alias="reason"),
+ Column(
+ "Exception details",
+ alias="exception_details",
+ fmt=Format(wrap_width=40),
+ ),
+ ],
+ rows=[
+ (
+ "TensorFlow Lite compatibility check failed with exception",
+ str(compat_info.conversion_exception),
+ ),
+ ],
+ name="TensorFlow Lite compatibility errors",
+ alias="tflite_compatibility",
+ )
+
+
+def handle_model_is_not_tflite_compatible_common( # type: ignore
+ self, data_item: ModelIsNotTFLiteCompatible
+) -> None:
+ """Advice for TensorFlow Lite compatibility."""
+ if data_item.flex_ops:
+ self.add_advice(
+ [
+ "The following operators are not natively "
+ "supported by TensorFlow Lite: "
+ f"{', '.join(data_item.flex_ops)}.",
+ "Using select TensorFlow operators in TensorFlow Lite model "
+ "requires special initialization of TFLiteConverter and "
+ "TensorFlow Lite run-time.",
+ "Please refer to the TensorFlow documentation for more "
+ "details: https://www.tensorflow.org/lite/guide/ops_select",
+ "Note, such models are not supported by the ML Inference Advisor.",
+ ]
+ )
+
+ if data_item.custom_ops:
+ self.add_advice(
+ [
+ "The following operators appear to be custom and not natively "
+ "supported by TensorFlow Lite: "
+ f"{', '.join(data_item.custom_ops)}.",
+ "Using custom operators in TensorFlow Lite model "
+ "requires special initialization of TFLiteConverter and "
+ "TensorFlow Lite run-time.",
+ "Please refer to the TensorFlow documentation for more "
+ "details: https://www.tensorflow.org/lite/guide/ops_custom",
+ "Note, such models are not supported by the ML Inference Advisor.",
+ ]
+ )
+
+ if not data_item.flex_ops and not data_item.custom_ops:
+ self.add_advice(
+ [
+ "Model could not be converted into TensorFlow Lite format.",
+ "Please refer to the table for more details.",
+ ]
+ )
+
+
+def handle_tflite_check_failed_common( # type: ignore
+ self, _data_item: TFLiteCompatibilityCheckFailed
+) -> None:
+ """Advice for the failed TensorFlow Lite compatibility checks."""
+ self.add_advice(
+ [
+ "Model could not be converted into TensorFlow Lite format.",
+ "Please refer to the table for more details.",
+ ]
+ )
+
+
+def analyze_tflite_compatibility_common( # type: ignore
+ self, data_item: TFLiteCompatibilityInfo
+) -> None:
+ """Analyze TensorFlow Lite compatibility information."""
+ if data_item.compatible:
+ return
+
+ if data_item.conversion_failed_with_errors:
+ self.add_fact(
+ ModelIsNotTFLiteCompatible(
+ custom_ops=data_item.required_custom_ops,
+ flex_ops=data_item.required_flex_ops,
+ )
+ )
+
+ if data_item.check_failed_with_unknown_error:
+ self.add_fact(TFLiteCompatibilityCheckFailed())
+
+ if data_item.conversion_failed_for_model_with_custom_ops:
+ self.add_fact(ModelHasCustomOperators())
diff --git a/src/mlia/target/cortex_a/advice_generation.py b/src/mlia/target/cortex_a/advice_generation.py
index 98e8c06..1011d6c 100644
--- a/src/mlia/target/cortex_a/advice_generation.py
+++ b/src/mlia/target/cortex_a/advice_generation.py
@@ -7,11 +7,13 @@ from mlia.core.advice_generation import advice_category
from mlia.core.advice_generation import FactBasedAdviceProducer
from mlia.core.common import AdviceCategory
from mlia.core.common import DataItem
-from mlia.target.cortex_a.data_analysis import ModelHasCustomOperators
+from mlia.target.common.reporters import handle_model_is_not_tflite_compatible_common
+from mlia.target.common.reporters import handle_tflite_check_failed_common
+from mlia.target.common.reporters import ModelHasCustomOperators
+from mlia.target.common.reporters import ModelIsNotTFLiteCompatible
+from mlia.target.common.reporters import TFLiteCompatibilityCheckFailed
from mlia.target.cortex_a.data_analysis import ModelIsCortexACompatible
from mlia.target.cortex_a.data_analysis import ModelIsNotCortexACompatible
-from mlia.target.cortex_a.data_analysis import ModelIsNotTFLiteCompatible
-from mlia.target.cortex_a.data_analysis import TFLiteCompatibilityCheckFailed
class CortexAAdviceProducer(FactBasedAdviceProducer):
@@ -88,43 +90,7 @@ class CortexAAdviceProducer(FactBasedAdviceProducer):
self, data_item: ModelIsNotTFLiteCompatible
) -> None:
"""Advice for TensorFlow Lite compatibility."""
- if data_item.flex_ops:
- self.add_advice(
- [
- "The following operators are not natively "
- "supported by TensorFlow Lite: "
- f"{', '.join(data_item.flex_ops)}.",
- "Using select TensorFlow operators in TensorFlow Lite model "
- "requires special initialization of TFLiteConverter and "
- "TensorFlow Lite run-time.",
- "Please refer to the TensorFlow documentation for more "
- "details: https://www.tensorflow.org/lite/guide/ops_select",
- "Note, such models are not supported by the ML Inference Advisor.",
- ]
- )
-
- if data_item.custom_ops:
- self.add_advice(
- [
- "The following operators appear to be custom and not natively "
- "supported by TensorFlow Lite: "
- f"{', '.join(data_item.custom_ops)}.",
- "Using custom operators in TensorFlow Lite model "
- "requires special initialization of TFLiteConverter and "
- "TensorFlow Lite run-time.",
- "Please refer to the TensorFlow documentation for more "
- "details: https://www.tensorflow.org/lite/guide/ops_custom",
- "Note, such models are not supported by the ML Inference Advisor.",
- ]
- )
-
- if not data_item.flex_ops and not data_item.custom_ops:
- self.add_advice(
- [
- "Model could not be converted into TensorFlow Lite format.",
- "Please refer to the table for more details.",
- ]
- )
+ handle_model_is_not_tflite_compatible_common(self, data_item)
@produce_advice.register
@advice_category(AdviceCategory.COMPATIBILITY)
@@ -132,12 +98,7 @@ class CortexAAdviceProducer(FactBasedAdviceProducer):
self, _data_item: TFLiteCompatibilityCheckFailed
) -> None:
"""Advice for the failed TensorFlow Lite compatibility checks."""
- self.add_advice(
- [
- "Model could not be converted into TensorFlow Lite format.",
- "Please refer to the table for more details.",
- ]
- )
+ handle_tflite_check_failed_common(self, _data_item)
@produce_advice.register
@advice_category(AdviceCategory.COMPATIBILITY)
diff --git a/src/mlia/target/cortex_a/data_analysis.py b/src/mlia/target/cortex_a/data_analysis.py
index 089c1a2..3161618 100644
--- a/src/mlia/target/cortex_a/data_analysis.py
+++ b/src/mlia/target/cortex_a/data_analysis.py
@@ -12,6 +12,7 @@ from mlia.core.common import DataItem
from mlia.core.data_analysis import Fact
from mlia.core.data_analysis import FactExtractor
from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
+from mlia.target.common.reporters import analyze_tflite_compatibility_common
from mlia.target.cortex_a.operators import CortexACompatibilityInfo
@@ -64,22 +65,7 @@ class CortexADataAnalyzer(FactExtractor):
@analyze_data.register
def analyze_tflite_compatibility(self, data_item: TFLiteCompatibilityInfo) -> None:
"""Analyze TensorFlow Lite compatibility information."""
- if data_item.compatible:
- return
-
- if data_item.conversion_failed_with_errors:
- self.add_fact(
- ModelIsNotTFLiteCompatible(
- custom_ops=data_item.required_custom_ops,
- flex_ops=data_item.required_flex_ops,
- )
- )
-
- if data_item.check_failed_with_unknown_error:
- self.add_fact(TFLiteCompatibilityCheckFailed())
-
- if data_item.conversion_failed_for_model_with_custom_ops:
- self.add_fact(ModelHasCustomOperators())
+ analyze_tflite_compatibility_common(self, data_item)
@dataclass
@@ -107,21 +93,3 @@ class ModelIsNotCortexACompatible(CortexACompatibility):
unsupported_ops: set[str]
activation_func_support: dict[str, ActivationFunctionSupport]
-
-
-@dataclass
-class ModelIsNotTFLiteCompatible(Fact):
- """Model could not be converted into TensorFlow Lite format."""
-
- custom_ops: list[str] | None = None
- flex_ops: list[str] | None = None
-
-
-@dataclass
-class TFLiteCompatibilityCheckFailed(Fact):
- """TensorFlow Lite compatibility check failed by unknown reason."""
-
-
-@dataclass
-class ModelHasCustomOperators(Fact):
- """Model could not be loaded because it contains custom ops."""
diff --git a/src/mlia/target/cortex_a/reporters.py b/src/mlia/target/cortex_a/reporters.py
index 65d7906..fc80c9f 100644
--- a/src/mlia/target/cortex_a/reporters.py
+++ b/src/mlia/target/cortex_a/reporters.py
@@ -17,6 +17,7 @@ from mlia.core.reporting import Report
from mlia.core.reporting import ReportItem
from mlia.core.reporting import Table
from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
+from mlia.target.common.reporters import report_tflite_compatiblity
from mlia.target.cortex_a.config import CortexAConfiguration
from mlia.target.cortex_a.operators import CortexACompatibilityInfo
from mlia.utils.console import style_improvement
@@ -34,57 +35,6 @@ def report_target(target_config: CortexAConfiguration) -> Report:
)
-def report_tflite_compatiblity(compat_info: TFLiteCompatibilityInfo) -> Report:
- """Generate report for the TensorFlow Lite compatibility information."""
- if compat_info.conversion_errors:
- return Table(
- [
- Column("#", only_for=["plain_text"]),
- Column("Operator", alias="operator"),
- Column(
- "Operator location",
- alias="operator_location",
- fmt=Format(wrap_width=25),
- ),
- Column("Error code", alias="error_code"),
- Column(
- "Error message", alias="error_message", fmt=Format(wrap_width=25)
- ),
- ],
- [
- (
- index + 1,
- err.operator,
- ", ".join(err.location),
- err.code.name,
- err.message,
- )
- for index, err in enumerate(compat_info.conversion_errors)
- ],
- name="TensorFlow Lite conversion errors",
- alias="tensorflow_lite_conversion_errors",
- )
-
- return Table(
- columns=[
- Column("Reason", alias="reason"),
- Column(
- "Exception details",
- alias="exception_details",
- fmt=Format(wrap_width=40),
- ),
- ],
- rows=[
- (
- "TensorFlow Lite compatibility check failed with exception",
- str(compat_info.conversion_exception),
- ),
- ],
- name="TensorFlow Lite compatibility errors",
- alias="tflite_compatibility",
- )
-
-
def report_cortex_a_operators(op_compat: CortexACompatibilityInfo) -> Report:
"""Generate report for the operators."""
return Table(
@@ -132,7 +82,7 @@ def cortex_a_formatters(data: Any) -> Callable[[Any], Report]:
return report_target
if isinstance(data, TFLiteCompatibilityInfo):
- return report_tflite_compatiblity
+ return report_tflite_compatiblity # type: ignore
if isinstance(data, CortexACompatibilityInfo):
return report_cortex_a_operators
diff --git a/src/mlia/target/ethos_u/advice_generation.py b/src/mlia/target/ethos_u/advice_generation.py
index daae4f4..a9f9eac 100644
--- a/src/mlia/target/ethos_u/advice_generation.py
+++ b/src/mlia/target/ethos_u/advice_generation.py
@@ -12,6 +12,10 @@ from mlia.core.advice_generation import FactBasedAdviceProducer
from mlia.core.common import AdviceCategory
from mlia.core.common import DataItem
from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+from mlia.target.common.reporters import handle_model_is_not_tflite_compatible_common
+from mlia.target.common.reporters import handle_tflite_check_failed_common
+from mlia.target.common.reporters import ModelIsNotTFLiteCompatible
+from mlia.target.common.reporters import TFLiteCompatibilityCheckFailed
from mlia.target.ethos_u.data_analysis import AllOperatorsSupportedOnNPU
from mlia.target.ethos_u.data_analysis import HasCPUOnlyOperators
from mlia.target.ethos_u.data_analysis import HasUnsupportedOnNPUOperators
@@ -147,6 +151,22 @@ class EthosUAdviceProducer(FactBasedAdviceProducer):
]
)
+ @produce_advice.register
+ @advice_category(AdviceCategory.COMPATIBILITY)
+ def handle_model_is_not_tflite_compatible(
+ self, data_item: ModelIsNotTFLiteCompatible
+ ) -> None:
+ """Advice for TensorFlow Lite compatibility."""
+ handle_model_is_not_tflite_compatible_common(self, data_item)
+
+ @produce_advice.register
+ @advice_category(AdviceCategory.COMPATIBILITY)
+ def handle_tflite_check_failed(
+ self, _data_item: TFLiteCompatibilityCheckFailed
+ ) -> None:
+ """Advice for the failed TensorFlow Lite compatibility checks."""
+ handle_tflite_check_failed_common(self, _data_item)
+
@staticmethod
def get_next_optimization_targets(
opt_type: list[OptimizationSettings],
diff --git a/src/mlia/target/ethos_u/data_analysis.py b/src/mlia/target/ethos_u/data_analysis.py
index 6b66734..3df4bff 100644
--- a/src/mlia/target/ethos_u/data_analysis.py
+++ b/src/mlia/target/ethos_u/data_analysis.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Ethos-U data analysis module."""
from __future__ import annotations
@@ -11,6 +11,8 @@ from mlia.core.common import DataItem
from mlia.core.data_analysis import Fact
from mlia.core.data_analysis import FactExtractor
from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
+from mlia.target.common.reporters import analyze_tflite_compatibility_common
from mlia.target.ethos_u.performance import OptimizationPerformanceMetrics
@@ -151,3 +153,8 @@ class EthosUDataAnalyzer(FactExtractor):
diffs.append(diff)
self.add_fact(OptimizationResults(diffs))
+
+ @analyze_data.register
+ def analyze_tflite_compatibility(self, data_item: TFLiteCompatibilityInfo) -> None:
+ """Analyze TensorFlow Lite compatibility information."""
+ analyze_tflite_compatibility_common(self, data_item)
diff --git a/src/mlia/target/ethos_u/data_collection.py b/src/mlia/target/ethos_u/data_collection.py
index 4fdfe96..8348393 100644
--- a/src/mlia/target/ethos_u/data_collection.py
+++ b/src/mlia/target/ethos_u/data_collection.py
@@ -17,6 +17,9 @@ from mlia.nn.tensorflow.config import get_tflite_model
from mlia.nn.tensorflow.config import KerasModel
from mlia.nn.tensorflow.optimizations.select import get_optimizer
from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+from mlia.nn.tensorflow.tflite_compat import TFLiteChecker
+from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
+from mlia.nn.tensorflow.utils import is_tflite_model
from mlia.nn.tensorflow.utils import save_keras_model
from mlia.target.ethos_u.config import EthosUConfiguration
from mlia.target.ethos_u.performance import EthosUPerformanceEstimator
@@ -36,8 +39,16 @@ class EthosUOperatorCompatibility(ContextAwareDataCollector):
self.model = model
self.target_config = target_config
- def collect_data(self) -> Operators:
+ def collect_data(self) -> Operators | TFLiteCompatibilityInfo | None:
"""Collect operator compatibility information."""
+ if not is_tflite_model(self.model):
+ with log_action("Checking TensorFlow Lite compatibility ..."):
+ tflite_checker = TFLiteChecker()
+ tflite_compat = tflite_checker.check_compatibility(self.model)
+
+ if not tflite_compat.compatible:
+ return tflite_compat
+
tflite_model = get_tflite_model(self.model, self.context)
with log_action("Checking operator compatibility ..."):
diff --git a/src/mlia/target/ethos_u/handlers.py b/src/mlia/target/ethos_u/handlers.py
index b9c89e8..1b15c55 100644
--- a/src/mlia/target/ethos_u/handlers.py
+++ b/src/mlia/target/ethos_u/handlers.py
@@ -8,6 +8,7 @@ import logging
from mlia.backend.vela.compat import Operators
from mlia.core.events import CollectedDataEvent
from mlia.core.handlers import WorkflowEventsHandler
+from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
from mlia.target.ethos_u.events import EthosUAdvisorEventHandler
from mlia.target.ethos_u.events import EthosUAdvisorStartedEvent
from mlia.target.ethos_u.performance import OptimizationPerformanceMetrics
@@ -49,6 +50,9 @@ class EthosUEventHandler(WorkflowEventsHandler, EthosUAdvisorEventHandler):
space=True,
)
+ if isinstance(data_item, TFLiteCompatibilityInfo) and not data_item.compatible:
+ self.reporter.submit(data_item, delay_print=True)
+
def on_ethos_u_advisor_started(self, event: EthosUAdvisorStartedEvent) -> None:
"""Handle EthosUAdvisorStarted event."""
self.reporter.submit(event.target_config)
diff --git a/src/mlia/target/ethos_u/reporters.py b/src/mlia/target/ethos_u/reporters.py
index 2a5b5d3..4964462 100644
--- a/src/mlia/target/ethos_u/reporters.py
+++ b/src/mlia/target/ethos_u/reporters.py
@@ -23,6 +23,8 @@ from mlia.core.reporting import Report
from mlia.core.reporting import ReportItem
from mlia.core.reporting import SingleRow
from mlia.core.reporting import Table
+from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
+from mlia.target.common.reporters import report_tflite_compatiblity
from mlia.target.ethos_u.config import EthosUConfiguration
from mlia.target.ethos_u.performance import PerformanceMetrics
from mlia.utils.console import style_improvement
@@ -363,23 +365,31 @@ def report_perf_metrics(
def ethos_u_formatters(data: Any) -> Callable[[Any], Report]:
"""Find appropriate formatter for the provided data."""
+ report: Callable[[Any], Report] | None = None
+
if isinstance(data, PerformanceMetrics) or is_list_of(data, PerformanceMetrics, 2):
- return report_perf_metrics
+ report = report_perf_metrics
- if is_list_of(data, Advice):
- return report_advice
+ elif is_list_of(data, Advice):
+ report = report_advice
- if is_list_of(data, Operator):
- return report_operators
+ elif is_list_of(data, Operator):
+ report = report_operators
- if isinstance(data, Operators):
- return report_operators_stat
+ elif isinstance(data, Operators):
+ report = report_operators_stat
- if isinstance(data, EthosUConfiguration):
- return report_target_details
+ elif isinstance(data, EthosUConfiguration):
+ report = report_target_details
- if isinstance(data, (list, tuple)):
+ elif isinstance(data, (list, tuple)):
formatters = [ethos_u_formatters(item) for item in data]
- return CompoundFormatter(formatters)
+ report = CompoundFormatter(formatters)
+
+ elif isinstance(data, TFLiteCompatibilityInfo):
+ report = report_tflite_compatiblity
+
+ else:
+ raise Exception(f"Unable to find appropriate formatter for {data}")
- raise Exception(f"Unable to find appropriate formatter for {data}")
+ return report # type: ignore
diff --git a/src/mlia/target/tosa/advice_generation.py b/src/mlia/target/tosa/advice_generation.py
index b8b9abf..ad321b2 100644
--- a/src/mlia/target/tosa/advice_generation.py
+++ b/src/mlia/target/tosa/advice_generation.py
@@ -7,6 +7,10 @@ from mlia.core.advice_generation import advice_category
from mlia.core.advice_generation import FactBasedAdviceProducer
from mlia.core.common import AdviceCategory
from mlia.core.common import DataItem
+from mlia.target.common.reporters import handle_model_is_not_tflite_compatible_common
+from mlia.target.common.reporters import handle_tflite_check_failed_common
+from mlia.target.common.reporters import ModelIsNotTFLiteCompatible
+from mlia.target.common.reporters import TFLiteCompatibilityCheckFailed
from mlia.target.tosa.data_analysis import ModelIsNotTOSACompatible
from mlia.target.tosa.data_analysis import ModelIsTOSACompatible
@@ -38,3 +42,19 @@ class TOSAAdviceProducer(FactBasedAdviceProducer):
"Please, refer to the operators table for more information."
]
)
+
+ @produce_advice.register
+ @advice_category(AdviceCategory.COMPATIBILITY)
+ def handle_model_is_not_tflite_compatible(
+ self, data_item: ModelIsNotTFLiteCompatible
+ ) -> None:
+ """Advice for TensorFlow Lite compatibility."""
+ handle_model_is_not_tflite_compatible_common(self, data_item)
+
+ @produce_advice.register
+ @advice_category(AdviceCategory.COMPATIBILITY)
+ def handle_tflite_check_failed(
+ self, _data_item: TFLiteCompatibilityCheckFailed
+ ) -> None:
+ """Advice for the failed TensorFlow Lite compatibility checks."""
+ handle_tflite_check_failed_common(self, _data_item)
diff --git a/src/mlia/target/tosa/data_analysis.py b/src/mlia/target/tosa/data_analysis.py
index 7cbd61d..7b31441 100644
--- a/src/mlia/target/tosa/data_analysis.py
+++ b/src/mlia/target/tosa/data_analysis.py
@@ -1,6 +1,8 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""TOSA data analysis module."""
+from __future__ import annotations
+
from dataclasses import dataclass
from functools import singledispatchmethod
@@ -8,6 +10,8 @@ from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo
from mlia.core.common import DataItem
from mlia.core.data_analysis import Fact
from mlia.core.data_analysis import FactExtractor
+from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
+from mlia.target.common.reporters import analyze_tflite_compatibility_common
@dataclass
@@ -34,3 +38,8 @@ class TOSADataAnalyzer(FactExtractor):
self.add_fact(ModelIsTOSACompatible())
else:
self.add_fact(ModelIsNotTOSACompatible())
+
+ @analyze_data.register
+ def analyze_tflite_compatibility(self, data_item: TFLiteCompatibilityInfo) -> None:
+ """Analyze TensorFlow Lite compatibility information."""
+ analyze_tflite_compatibility_common(self, data_item)
diff --git a/src/mlia/target/tosa/data_collection.py b/src/mlia/target/tosa/data_collection.py
index 105c501..19ea6eb 100644
--- a/src/mlia/target/tosa/data_collection.py
+++ b/src/mlia/target/tosa/data_collection.py
@@ -1,12 +1,17 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""TOSA data collection module."""
+from __future__ import annotations
+
from pathlib import Path
from mlia.backend.tosa_checker.compat import get_tosa_compatibility_info
from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo
from mlia.core.data_collection import ContextAwareDataCollector
from mlia.nn.tensorflow.config import get_tflite_model
+from mlia.nn.tensorflow.tflite_compat import TFLiteChecker
+from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
+from mlia.nn.tensorflow.utils import is_tflite_model
from mlia.utils.logging import log_action
@@ -17,8 +22,16 @@ class TOSAOperatorCompatibility(ContextAwareDataCollector):
"""Init the data collector."""
self.model = model
- def collect_data(self) -> TOSACompatibilityInfo:
+ def collect_data(self) -> TFLiteCompatibilityInfo | TOSACompatibilityInfo | None:
"""Collect TOSA compatibility information."""
+ if not is_tflite_model(self.model):
+ with log_action("Checking TensorFlow Lite compatibility ..."):
+ tflite_checker = TFLiteChecker()
+ tflite_compat = tflite_checker.check_compatibility(self.model)
+
+ if not tflite_compat.compatible:
+ return tflite_compat
+
tflite_model = get_tflite_model(self.model, self.context)
with log_action("Checking operator compatibility ..."):
diff --git a/src/mlia/target/tosa/handlers.py b/src/mlia/target/tosa/handlers.py
index 131afa7..26e7226 100644
--- a/src/mlia/target/tosa/handlers.py
+++ b/src/mlia/target/tosa/handlers.py
@@ -9,6 +9,7 @@ import logging
from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo
from mlia.core.events import CollectedDataEvent
from mlia.core.handlers import WorkflowEventsHandler
+from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
from mlia.target.tosa.events import TOSAAdvisorEventHandler
from mlia.target.tosa.events import TOSAAdvisorStartedEvent
from mlia.target.tosa.reporters import tosa_formatters
@@ -34,3 +35,6 @@ class TOSAEventHandler(WorkflowEventsHandler, TOSAAdvisorEventHandler):
if isinstance(data_item, TOSACompatibilityInfo):
self.reporter.submit(data_item, delay_print=True)
+
+ if isinstance(data_item, TFLiteCompatibilityInfo) and not data_item.compatible:
+ self.reporter.submit(data_item, delay_print=True)
diff --git a/src/mlia/target/tosa/reporters.py b/src/mlia/target/tosa/reporters.py
index f54c06b..decae0c 100644
--- a/src/mlia/target/tosa/reporters.py
+++ b/src/mlia/target/tosa/reporters.py
@@ -22,6 +22,8 @@ from mlia.core.reporting import NestedReport
from mlia.core.reporting import Report
from mlia.core.reporting import ReportItem
from mlia.core.reporting import Table
+from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
+from mlia.target.common.reporters import report_tflite_compatiblity
from mlia.target.tosa.config import TOSAConfiguration
from mlia.target.tosa.metadata import TOSAMetadata
from mlia.utils.console import style_improvement
@@ -163,4 +165,7 @@ def tosa_formatters(data: Any) -> Callable[[Any], Report]:
if isinstance(data, TOSACompatibilityInfo):
return report_tosa_compatibility
+ if isinstance(data, TFLiteCompatibilityInfo):
+ return report_tflite_compatiblity # type: ignore
+
raise Exception(f"Unable to find appropriate formatter for {data}")
diff --git a/tests/test_target_cortex_a_advice_generation.py b/tests/test_target_cortex_a_advice_generation.py
index 9596d47..916bdc1 100644
--- a/tests/test_target_cortex_a_advice_generation.py
+++ b/tests/test_target_cortex_a_advice_generation.py
@@ -13,13 +13,13 @@ from mlia.core.common import AdviceCategory
from mlia.core.common import DataItem
from mlia.core.context import ExecutionContext
from mlia.nn.tensorflow.tflite_graph import TFL_ACTIVATION_FUNCTION
+from mlia.target.common.reporters import ModelHasCustomOperators
+from mlia.target.common.reporters import ModelIsNotTFLiteCompatible
+from mlia.target.common.reporters import TFLiteCompatibilityCheckFailed
from mlia.target.cortex_a.advice_generation import CortexAAdviceProducer
from mlia.target.cortex_a.config import CortexAConfiguration
-from mlia.target.cortex_a.data_analysis import ModelHasCustomOperators
from mlia.target.cortex_a.data_analysis import ModelIsCortexACompatible
from mlia.target.cortex_a.data_analysis import ModelIsNotCortexACompatible
-from mlia.target.cortex_a.data_analysis import ModelIsNotTFLiteCompatible
-from mlia.target.cortex_a.data_analysis import TFLiteCompatibilityCheckFailed
VERSION = CortexAConfiguration.load_profile("cortex-a").armnn_tflite_delegate_version
BACKEND_INFO = (
diff --git a/tests/test_target_cortex_a_data_analysis.py b/tests/test_target_cortex_a_data_analysis.py
index 0a6b490..e033ef9 100644
--- a/tests/test_target_cortex_a_data_analysis.py
+++ b/tests/test_target_cortex_a_data_analysis.py
@@ -15,13 +15,13 @@ from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityStatus
from mlia.nn.tensorflow.tflite_compat import TFLiteConversionError
from mlia.nn.tensorflow.tflite_compat import TFLiteConversionErrorCode
from mlia.nn.tensorflow.tflite_graph import TFL_ACTIVATION_FUNCTION
+from mlia.target.common.reporters import ModelHasCustomOperators
+from mlia.target.common.reporters import ModelIsNotTFLiteCompatible
+from mlia.target.common.reporters import TFLiteCompatibilityCheckFailed
from mlia.target.cortex_a.config import CortexAConfiguration
from mlia.target.cortex_a.data_analysis import CortexADataAnalyzer
-from mlia.target.cortex_a.data_analysis import ModelHasCustomOperators
from mlia.target.cortex_a.data_analysis import ModelIsCortexACompatible
from mlia.target.cortex_a.data_analysis import ModelIsNotCortexACompatible
-from mlia.target.cortex_a.data_analysis import ModelIsNotTFLiteCompatible
-from mlia.target.cortex_a.data_analysis import TFLiteCompatibilityCheckFailed
from mlia.target.cortex_a.operators import CortexACompatibilityInfo
from mlia.target.cortex_a.operators import Operator
diff --git a/tests/test_target_ethos_u_data_analysis.py b/tests/test_target_ethos_u_data_analysis.py
index 8e63946..80f0603 100644
--- a/tests/test_target_ethos_u_data_analysis.py
+++ b/tests/test_target_ethos_u_data_analysis.py
@@ -13,6 +13,13 @@ from mlia.backend.vela.compat import Operators
from mlia.core.common import DataItem
from mlia.core.data_analysis import Fact
from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
+from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityStatus
+from mlia.nn.tensorflow.tflite_compat import TFLiteConversionError
+from mlia.nn.tensorflow.tflite_compat import TFLiteConversionErrorCode
+from mlia.target.common.reporters import ModelHasCustomOperators
+from mlia.target.common.reporters import ModelIsNotTFLiteCompatible
+from mlia.target.common.reporters import TFLiteCompatibilityCheckFailed
from mlia.target.ethos_u.config import EthosUConfiguration
from mlia.target.ethos_u.data_analysis import AllOperatorsSupportedOnNPU
from mlia.target.ethos_u.data_analysis import EthosUDataAnalyzer
@@ -139,6 +146,51 @@ def test_perf_metrics_diff() -> None:
),
[],
],
+ [
+ TFLiteCompatibilityInfo(status=TFLiteCompatibilityStatus.COMPATIBLE),
+ [],
+ ],
+ [
+ TFLiteCompatibilityInfo(
+ status=TFLiteCompatibilityStatus.MODEL_WITH_CUSTOM_OP_ERROR
+ ),
+ [ModelHasCustomOperators()],
+ ],
+ [
+ TFLiteCompatibilityInfo(status=TFLiteCompatibilityStatus.UNKNOWN_ERROR),
+ [TFLiteCompatibilityCheckFailed()],
+ ],
+ [
+ TFLiteCompatibilityInfo(
+ status=TFLiteCompatibilityStatus.TFLITE_CONVERSION_ERROR
+ ),
+ [ModelIsNotTFLiteCompatible(custom_ops=[], flex_ops=[])],
+ ],
+ [
+ TFLiteCompatibilityInfo(
+ status=TFLiteCompatibilityStatus.TFLITE_CONVERSION_ERROR,
+ conversion_errors=[
+ TFLiteConversionError(
+ "error",
+ TFLiteConversionErrorCode.NEEDS_CUSTOM_OPS,
+ "custom_op1",
+ [],
+ ),
+ TFLiteConversionError(
+ "error",
+ TFLiteConversionErrorCode.NEEDS_FLEX_OPS,
+ "flex_op1",
+ [],
+ ),
+ ],
+ ),
+ [
+ ModelIsNotTFLiteCompatible(
+ custom_ops=["custom_op1"],
+ flex_ops=["flex_op1"],
+ )
+ ],
+ ],
],
)
def test_ethos_u_data_analyzer(
diff --git a/tests/test_target_ethos_u_reporters.py b/tests/test_target_ethos_u_reporters.py
index b8014e4..debeeb2 100644
--- a/tests/test_target_ethos_u_reporters.py
+++ b/tests/test_target_ethos_u_reporters.py
@@ -3,6 +3,7 @@
"""Tests for reports module."""
from __future__ import annotations
+from typing import Any
from typing import cast
import pytest
@@ -11,7 +12,10 @@ from mlia.backend.vela.compat import NpuSupported
from mlia.backend.vela.compat import Operator
from mlia.core.reporting import Report
from mlia.core.reporting import Table
+from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
+from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityStatus
from mlia.target.ethos_u.config import EthosUConfiguration
+from mlia.target.ethos_u.reporters import ethos_u_formatters
from mlia.target.ethos_u.reporters import report_operators
from mlia.target.ethos_u.reporters import report_target_details
from mlia.target.registry import profile
@@ -231,3 +235,23 @@ def test_report_target_details(
json_dict = report.to_json()
assert json_dict == expected_json_dict
+
+
+@pytest.mark.parametrize(
+ "data",
+ (TFLiteCompatibilityInfo(status=TFLiteCompatibilityStatus.COMPATIBLE),),
+)
+def test_ethos_u_formatters(data: Any) -> None:
+ """Test function ethos_u_formatters() with valid input."""
+ formatter = ethos_u_formatters(data)
+ report = formatter(data)
+ assert isinstance(report, Report)
+
+
+def test_ethos_u_formatters_invalid_data() -> None:
+ """Test function ethos_u_formatters() with invalid input."""
+ with pytest.raises(
+ Exception,
+ match=r"^Unable to find appropriate formatter for .*",
+ ):
+ ethos_u_formatters(200)
diff --git a/tests/test_target_tosa_data_analysis.py b/tests/test_target_tosa_data_analysis.py
index 41e977f..23adcc8 100644
--- a/tests/test_target_tosa_data_analysis.py
+++ b/tests/test_target_tosa_data_analysis.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for TOSA data analysis module."""
from __future__ import annotations
@@ -8,6 +8,13 @@ import pytest
from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo
from mlia.core.common import DataItem
from mlia.core.data_analysis import Fact
+from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
+from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityStatus
+from mlia.nn.tensorflow.tflite_compat import TFLiteConversionError
+from mlia.nn.tensorflow.tflite_compat import TFLiteConversionErrorCode
+from mlia.target.common.reporters import ModelHasCustomOperators
+from mlia.target.common.reporters import ModelIsNotTFLiteCompatible
+from mlia.target.common.reporters import TFLiteCompatibilityCheckFailed
from mlia.target.tosa.data_analysis import ModelIsNotTOSACompatible
from mlia.target.tosa.data_analysis import ModelIsTOSACompatible
from mlia.target.tosa.data_analysis import TOSADataAnalyzer
@@ -24,6 +31,51 @@ from mlia.target.tosa.data_analysis import TOSADataAnalyzer
TOSACompatibilityInfo(False, []),
[ModelIsNotTOSACompatible()],
],
+ [
+ TFLiteCompatibilityInfo(status=TFLiteCompatibilityStatus.COMPATIBLE),
+ [],
+ ],
+ [
+ TFLiteCompatibilityInfo(
+ status=TFLiteCompatibilityStatus.MODEL_WITH_CUSTOM_OP_ERROR
+ ),
+ [ModelHasCustomOperators()],
+ ],
+ [
+ TFLiteCompatibilityInfo(status=TFLiteCompatibilityStatus.UNKNOWN_ERROR),
+ [TFLiteCompatibilityCheckFailed()],
+ ],
+ [
+ TFLiteCompatibilityInfo(
+ status=TFLiteCompatibilityStatus.TFLITE_CONVERSION_ERROR
+ ),
+ [ModelIsNotTFLiteCompatible(custom_ops=[], flex_ops=[])],
+ ],
+ [
+ TFLiteCompatibilityInfo(
+ status=TFLiteCompatibilityStatus.TFLITE_CONVERSION_ERROR,
+ conversion_errors=[
+ TFLiteConversionError(
+ "error",
+ TFLiteConversionErrorCode.NEEDS_CUSTOM_OPS,
+ "custom_op1",
+ [],
+ ),
+ TFLiteConversionError(
+ "error",
+ TFLiteConversionErrorCode.NEEDS_FLEX_OPS,
+ "flex_op1",
+ [],
+ ),
+ ],
+ ),
+ [
+ ModelIsNotTFLiteCompatible(
+ custom_ops=["custom_op1"],
+ flex_ops=["flex_op1"],
+ )
+ ],
+ ],
],
)
def test_tosa_data_analyzer(input_data: DataItem, expected_facts: list[Fact]) -> None: