aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/devices/cortexa
diff options
context:
space:
mode:
authorDmitrii Agibov <dmitrii.agibov@arm.com>2022-10-24 15:08:08 +0100
committerDmitrii Agibov <dmitrii.agibov@arm.com>2022-10-26 17:08:13 +0100
commit58a65fee574c00329cf92b387a6d2513dcbf6100 (patch)
tree47e3185f78b4298ab029785ddee68456e44cac10 /src/mlia/devices/cortexa
parent9d34cb72d45a6d0a2ec1063ebf32536c1efdba75 (diff)
downloadmlia-58a65fee574c00329cf92b387a6d2513dcbf6100.tar.gz
MLIA-433 Add TensorFlow Lite compatibility check
- Add ability to intercept low level TensorFlow output - Produce advice for the models that could not be converted to the TensorFlow Lite format - Refactor utility functions for TensorFlow Lite conversion - Add TensorFlow Lite compatibility checker Change-Id: I47d120d2619ced7b143bc92c5184515b81c0220d
Diffstat (limited to 'src/mlia/devices/cortexa')
-rw-r--r--src/mlia/devices/cortexa/advice_generation.py35
-rw-r--r--src/mlia/devices/cortexa/advisor.py6
-rw-r--r--src/mlia/devices/cortexa/data_analysis.py31
-rw-r--r--src/mlia/devices/cortexa/data_collection.py25
-rw-r--r--src/mlia/devices/cortexa/handlers.py4
-rw-r--r--src/mlia/devices/cortexa/operators.py8
-rw-r--r--src/mlia/devices/cortexa/reporters.py108
7 files changed, 199 insertions, 18 deletions
diff --git a/src/mlia/devices/cortexa/advice_generation.py b/src/mlia/devices/cortexa/advice_generation.py
index 33d5a5f..0f3553f 100644
--- a/src/mlia/devices/cortexa/advice_generation.py
+++ b/src/mlia/devices/cortexa/advice_generation.py
@@ -9,6 +9,7 @@ from mlia.core.common import AdviceCategory
from mlia.core.common import DataItem
from mlia.devices.cortexa.data_analysis import ModelIsCortexACompatible
from mlia.devices.cortexa.data_analysis import ModelIsNotCortexACompatible
+from mlia.devices.cortexa.data_analysis import ModelIsNotTFLiteCompatible
class CortexAAdviceProducer(FactBasedAdviceProducer):
@@ -38,3 +39,37 @@ class CortexAAdviceProducer(FactBasedAdviceProducer):
"Please, refer to the operators table for more information."
]
)
+
+ @produce_advice.register
+ @advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS)
+ def handle_model_is_not_tflite_compatible(
+ self, data_item: ModelIsNotTFLiteCompatible
+ ) -> None:
+ """Advice for TensorFlow Lite compatibility."""
+ if data_item.flex_ops:
+ self.add_advice(
+ [
+ "The following operators are not natively "
+ "supported by TensorFlow Lite: "
+ f"{', '.join(data_item.flex_ops)}.",
+ "Please refer to the TensorFlow documentation for more details.",
+ ]
+ )
+
+ if data_item.custom_ops:
+ self.add_advice(
+ [
+ "The following operators are custom and not natively "
+ "supported by TensorFlow Lite: "
+ f"{', '.join(data_item.custom_ops)}.",
+ "Please refer to the TensorFlow documentation for more details.",
+ ]
+ )
+
+ if not data_item.flex_ops and not data_item.custom_ops:
+ self.add_advice(
+ [
+ "Model could not be converted into TensorFlow Lite format.",
+ "Please refer to the table for more details.",
+ ]
+ )
diff --git a/src/mlia/devices/cortexa/advisor.py b/src/mlia/devices/cortexa/advisor.py
index 98c155b..ffbbea5 100644
--- a/src/mlia/devices/cortexa/advisor.py
+++ b/src/mlia/devices/cortexa/advisor.py
@@ -68,16 +68,14 @@ def configure_and_get_cortexa_advisor(
target_profile: str,
model: str | Path,
output: PathOrFileLike | None = None,
- **extra_args: Any,
+ **_extra_args: Any,
) -> InferenceAdvisor:
"""Create and configure Cortex-A advisor."""
if context.event_handlers is None:
context.event_handlers = [CortexAEventHandler(output)]
if context.config_parameters is None:
- context.config_parameters = _get_config_parameters(
- model, target_profile, **extra_args
- )
+ context.config_parameters = _get_config_parameters(model, target_profile)
return CortexAInferenceAdvisor()
diff --git a/src/mlia/devices/cortexa/data_analysis.py b/src/mlia/devices/cortexa/data_analysis.py
index dff95ce..d2b6f35 100644
--- a/src/mlia/devices/cortexa/data_analysis.py
+++ b/src/mlia/devices/cortexa/data_analysis.py
@@ -1,6 +1,8 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Cortex-A data analysis module."""
+from __future__ import annotations
+
from dataclasses import dataclass
from functools import singledispatchmethod
@@ -8,6 +10,8 @@ from mlia.core.common import DataItem
from mlia.core.data_analysis import Fact
from mlia.core.data_analysis import FactExtractor
from mlia.devices.cortexa.operators import CortexACompatibilityInfo
+from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
+from mlia.nn.tensorflow.tflite_compat import TFLiteConversionErrorCode
class CortexADataAnalyzer(FactExtractor):
@@ -27,6 +31,25 @@ class CortexADataAnalyzer(FactExtractor):
else:
self.add_fact(ModelIsNotCortexACompatible())
+ @analyze_data.register
+ def analyze_tflite_compatibility(self, data_item: TFLiteCompatibilityInfo) -> None:
+ """Analyze TensorFlow Lite compatibility information."""
+ if data_item.compatible:
+ return
+
+ custom_ops, flex_ops = [], []
+ if data_item.conversion_errors:
+ custom_ops = data_item.unsupported_ops_by_code(
+ TFLiteConversionErrorCode.NEEDS_CUSTOM_OPS
+ )
+ flex_ops = data_item.unsupported_ops_by_code(
+ TFLiteConversionErrorCode.NEEDS_FLEX_OPS
+ )
+
+ self.add_fact(
+ ModelIsNotTFLiteCompatible(custom_ops=custom_ops, flex_ops=flex_ops)
+ )
+
@dataclass
class ModelIsCortexACompatible(Fact):
@@ -36,3 +59,11 @@ class ModelIsCortexACompatible(Fact):
@dataclass
class ModelIsNotCortexACompatible(Fact):
"""Model is not compatible with Cortex-A."""
+
+
+@dataclass
+class ModelIsNotTFLiteCompatible(Fact):
+ """Model could not be converted into TensorFlow Lite format."""
+
+ custom_ops: list[str] | None = None
+ flex_ops: list[str] | None = None
diff --git a/src/mlia/devices/cortexa/data_collection.py b/src/mlia/devices/cortexa/data_collection.py
index 00c95e6..f4d5a82 100644
--- a/src/mlia/devices/cortexa/data_collection.py
+++ b/src/mlia/devices/cortexa/data_collection.py
@@ -10,6 +10,11 @@ from mlia.core.data_collection import ContextAwareDataCollector
from mlia.devices.cortexa.operators import CortexACompatibilityInfo
from mlia.devices.cortexa.operators import get_cortex_a_compatibility_info
from mlia.nn.tensorflow.config import get_tflite_model
+from mlia.nn.tensorflow.tflite_compat import TFLiteChecker
+from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
+from mlia.nn.tensorflow.utils import is_tflite_model
+from mlia.utils.logging import log_action
+
logger = logging.getLogger(__name__)
@@ -21,14 +26,24 @@ class CortexAOperatorCompatibility(ContextAwareDataCollector):
"""Init operator compatibility data collector."""
self.model = model
- def collect_data(self) -> CortexACompatibilityInfo:
+ def collect_data(self) -> TFLiteCompatibilityInfo | CortexACompatibilityInfo | None:
"""Collect operator compatibility information."""
+ if not is_tflite_model(self.model):
+ with log_action("Checking TensorFlow Lite compatibility ..."):
+ tflite_checker = TFLiteChecker()
+ tflite_compat = tflite_checker.check_compatibility(self.model)
+
+ if not tflite_compat.compatible:
+ return tflite_compat
+
tflite_model = get_tflite_model(self.model, self.context)
- logger.info("Checking operator compatibility ...")
- ops = get_cortex_a_compatibility_info(Path(tflite_model.model_path))
- logger.info("Done\n")
- return ops
+ with log_action("Checking operator compatibility ..."):
+ return (
+ get_cortex_a_compatibility_info( # pylint: disable=assignment-from-none
+ Path(tflite_model.model_path)
+ )
+ )
@classmethod
def name(cls) -> str:
diff --git a/src/mlia/devices/cortexa/handlers.py b/src/mlia/devices/cortexa/handlers.py
index f54ceff..7ed2b75 100644
--- a/src/mlia/devices/cortexa/handlers.py
+++ b/src/mlia/devices/cortexa/handlers.py
@@ -12,6 +12,7 @@ from mlia.devices.cortexa.events import CortexAAdvisorEventHandler
from mlia.devices.cortexa.events import CortexAAdvisorStartedEvent
from mlia.devices.cortexa.operators import CortexACompatibilityInfo
from mlia.devices.cortexa.reporters import cortex_a_formatters
+from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
logger = logging.getLogger(__name__)
@@ -30,6 +31,9 @@ class CortexAEventHandler(WorkflowEventsHandler, CortexAAdvisorEventHandler):
if isinstance(data_item, CortexACompatibilityInfo):
self.reporter.submit(data_item.operators, delay_print=True)
+ if isinstance(data_item, TFLiteCompatibilityInfo) and not data_item.compatible:
+ self.reporter.submit(data_item, delay_print=True)
+
def on_cortex_a_advisor_started(self, event: CortexAAdvisorStartedEvent) -> None:
"""Handle CortexAAdvisorStarted event."""
self.reporter.submit(event.device)
diff --git a/src/mlia/devices/cortexa/operators.py b/src/mlia/devices/cortexa/operators.py
index 6a314b7..8fd2571 100644
--- a/src/mlia/devices/cortexa/operators.py
+++ b/src/mlia/devices/cortexa/operators.py
@@ -21,9 +21,11 @@ class CortexACompatibilityInfo:
"""Model's operators."""
cortex_a_compatible: bool
- operators: list[Operator]
+ operators: list[Operator] | None = None
-def get_cortex_a_compatibility_info(model_path: Path) -> CortexACompatibilityInfo:
+def get_cortex_a_compatibility_info(
+ _model_path: Path,
+) -> CortexACompatibilityInfo | None:
"""Return list of model's operators."""
- raise NotImplementedError()
+ return None
diff --git a/src/mlia/devices/cortexa/reporters.py b/src/mlia/devices/cortexa/reporters.py
index 076b9ca..a55caba 100644
--- a/src/mlia/devices/cortexa/reporters.py
+++ b/src/mlia/devices/cortexa/reporters.py
@@ -7,25 +7,118 @@ from typing import Any
from typing import Callable
from mlia.core.advice_generation import Advice
+from mlia.core.reporters import report_advice
+from mlia.core.reporting import Cell
+from mlia.core.reporting import Column
+from mlia.core.reporting import Format
+from mlia.core.reporting import NestedReport
from mlia.core.reporting import Report
+from mlia.core.reporting import ReportItem
+from mlia.core.reporting import Table
from mlia.devices.cortexa.config import CortexAConfiguration
from mlia.devices.cortexa.operators import Operator
+from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
+from mlia.utils.console import style_improvement
from mlia.utils.types import is_list_of
def report_device(device: CortexAConfiguration) -> Report:
"""Generate report for the device."""
- raise NotImplementedError()
+ return NestedReport(
+ "Device information",
+ "device",
+ [
+ ReportItem("Target", alias="target", value=device.target),
+ ],
+ )
-def report_advice(advice: list[Advice]) -> Report:
- """Generate report for the advice."""
- raise NotImplementedError()
+def report_tflite_compatiblity(compat_info: TFLiteCompatibilityInfo) -> Report:
+ """Generate report for the TensorFlow Lite compatibility information."""
+ if compat_info.conversion_errors:
+ return Table(
+ [
+ Column("#", only_for=["plain_text"]),
+ Column("Operator", alias="operator"),
+ Column(
+ "Operator location",
+ alias="operator_location",
+ fmt=Format(wrap_width=25),
+ ),
+ Column("Error code", alias="error_code"),
+ Column(
+ "Error message", alias="error_message", fmt=Format(wrap_width=25)
+ ),
+ ],
+ [
+ (
+ index + 1,
+ err.operator,
+ ", ".join(err.location),
+ err.code.name,
+ err.message,
+ )
+ for index, err in enumerate(compat_info.conversion_errors)
+ ],
+ name="TensorFlow Lite conversion errors",
+ alias="tensorflow_lite_conversion_errors",
+ )
+ return Table(
+ columns=[
+ Column("Reason", alias="reason"),
+ Column(
+ "Exception details",
+ alias="exception_details",
+ fmt=Format(wrap_width=40),
+ ),
+ ],
+ rows=[
+ (
+ "TensorFlow Lite compatibility check failed with exception",
+ str(compat_info.conversion_exception),
+ ),
+ ],
+ name="TensorFlow Lite compatibility errors",
+ alias="tflite_compatibility",
+ )
-def report_cortex_a_operators(operators: list[Operator]) -> Report:
+
+def report_cortex_a_operators(ops: list[Operator]) -> Report:
"""Generate report for the operators."""
- raise NotImplementedError()
+ return Table(
+ [
+ Column("#", only_for=["plain_text"]),
+ Column(
+ "Operator location",
+ alias="operator_location",
+ fmt=Format(wrap_width=30),
+ ),
+ Column("Operator name", alias="operator_name", fmt=Format(wrap_width=20)),
+ Column(
+ "Cortex-A compatibility",
+ alias="cortex_a_compatible",
+ fmt=Format(wrap_width=25),
+ ),
+ ],
+ [
+ (
+ index + 1,
+ op.location,
+ op.name,
+ Cell(
+ op.is_cortex_a_compatible,
+ Format(
+ style=style_improvement(op.is_cortex_a_compatible),
+ str_fmt=lambda v: "Compatible" if v else "Not compatible",
+ ),
+ ),
+ )
+ for index, op in enumerate(ops)
+ ],
+ name="Operators",
+ alias="operators",
+ )
def cortex_a_formatters(data: Any) -> Callable[[Any], Report]:
@@ -36,6 +129,9 @@ def cortex_a_formatters(data: Any) -> Callable[[Any], Report]:
if isinstance(data, CortexAConfiguration):
return report_device
+ if isinstance(data, TFLiteCompatibilityInfo):
+ return report_tflite_compatiblity
+
if is_list_of(data, Operator):
return report_cortex_a_operators