author     Dmitrii Agibov <dmitrii.agibov@arm.com>  2022-10-24 15:08:08 +0100
committer  Dmitrii Agibov <dmitrii.agibov@arm.com>  2022-10-26 17:08:13 +0100
commit     58a65fee574c00329cf92b387a6d2513dcbf6100 (patch)
tree       47e3185f78b4298ab029785ddee68456e44cac10
parent     9d34cb72d45a6d0a2ec1063ebf32536c1efdba75 (diff)
download   mlia-58a65fee574c00329cf92b387a6d2513dcbf6100.tar.gz
MLIA-433 Add TensorFlow Lite compatibility check

- Add ability to intercept low level TensorFlow output
- Produce advice for the models that could not be converted to the TensorFlow Lite format
- Refactor utility functions for TensorFlow Lite conversion
- Add TensorFlow Lite compatibility checker

Change-Id: I47d120d2619ced7b143bc92c5184515b81c0220d
-rw-r--r--  src/mlia/cli/logging.py | 35
-rw-r--r--  src/mlia/core/reporters.py | 22
-rw-r--r--  src/mlia/devices/cortexa/advice_generation.py | 35
-rw-r--r--  src/mlia/devices/cortexa/advisor.py | 6
-rw-r--r--  src/mlia/devices/cortexa/data_analysis.py | 31
-rw-r--r--  src/mlia/devices/cortexa/data_collection.py | 25
-rw-r--r--  src/mlia/devices/cortexa/handlers.py | 4
-rw-r--r--  src/mlia/devices/cortexa/operators.py | 8
-rw-r--r--  src/mlia/devices/cortexa/reporters.py | 108
-rw-r--r--  src/mlia/devices/ethosu/data_collection.py | 11
-rw-r--r--  src/mlia/devices/ethosu/performance.py | 114
-rw-r--r--  src/mlia/devices/ethosu/reporters.py | 14
-rw-r--r--  src/mlia/devices/tosa/data_collection.py | 11
-rw-r--r--  src/mlia/devices/tosa/reporters.py | 14
-rw-r--r--  src/mlia/nn/tensorflow/config.py | 14
-rw-r--r--  src/mlia/nn/tensorflow/tflite_compat.py | 132
-rw-r--r--  src/mlia/nn/tensorflow/utils.py | 159
-rw-r--r--  src/mlia/utils/logging.py | 60
-rw-r--r--  tests/test_devices_cortex_a_data_analysis.py | 35
-rw-r--r--  tests/test_devices_cortexa_advice_generation.py (renamed from tests/test_devices_cortex_a_advice_generation.py) | 38
-rw-r--r--  tests/test_devices_cortexa_data_analysis.py | 72
-rw-r--r--  tests/test_devices_cortexa_data_collection.py (renamed from tests/test_devices_cortex_a_data_collection.py) | 0
-rw-r--r--  tests/test_nn_tensorflow_tflite_compat.py | 210
-rw-r--r--  tests/test_nn_tensorflow_utils.py | 44
-rw-r--r--  tests/test_utils_logging.py | 24
25 files changed, 950 insertions(+), 276 deletions(-)
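
For orientation, a minimal sketch of how the compatibility checker added by this change might be exercised; the model path is a placeholder and not part of the commit:

    from mlia.nn.tensorflow.tflite_compat import TFLiteChecker

    # Run the TensorFlow Lite conversion check against a Keras model file
    # (placeholder path).
    checker = TFLiteChecker()
    info = checker.check_compatibility("model.h5")

    if not info.compatible:
        # Each conversion error carries the operator name, an error code and
        # the first line of the converter's error message.
        for err in info.conversion_errors or []:
            print(err.operator, err.code.name, err.message)
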
diff --git a/src/mlia/cli/logging.py b/src/mlia/cli/logging.py
index e786394..5c5c4b8 100644
--- a/src/mlia/cli/logging.py
+++ b/src/mlia/cli/logging.py
@@ -6,6 +6,7 @@ from __future__ import annotations
import logging
import sys
from pathlib import Path
+from typing import Iterable
from mlia.utils.logging import attach_handlers
from mlia.utils.logging import create_log_handler
@@ -31,34 +32,33 @@ def setup_logging(
:param verbose: enable extended logging for the tools loggers
:param log_filename: name of the log file in the logs directory
"""
- mlia_logger, *tools_loggers = (
+ mlia_logger, tensorflow_logger, py_warnings_logger = (
logging.getLogger(logger_name)
for logger_name in ["mlia", "tensorflow", "py.warnings"]
)
# enable debug output, actual message filtering depends on
- # the provided parameters and being done on the handlers level
- mlia_logger.setLevel(logging.DEBUG)
+ # the provided parameters and being done at the handlers level
+ for logger in [mlia_logger, tensorflow_logger]:
+ logger.setLevel(logging.DEBUG)
mlia_handlers = _get_mlia_handlers(logs_dir, log_filename, verbose)
attach_handlers(mlia_handlers, [mlia_logger])
tools_handlers = _get_tools_handlers(logs_dir, log_filename, verbose)
- attach_handlers(tools_handlers, tools_loggers)
+ attach_handlers(tools_handlers, [tensorflow_logger, py_warnings_logger])
def _get_mlia_handlers(
logs_dir: str | Path | None,
log_filename: str,
verbose: bool,
-) -> list[logging.Handler]:
+) -> Iterable[logging.Handler]:
"""Get handlers for the MLIA loggers."""
- result = []
- stdout_handler = create_log_handler(
+ yield create_log_handler(
stream=sys.stdout,
log_level=logging.INFO,
)
- result.append(stdout_handler)
if verbose:
mlia_verbose_handler = create_log_handler(
@@ -67,50 +67,43 @@ def _get_mlia_handlers(
log_format=_CONSOLE_DEBUG_FORMAT,
log_filter=LogFilter.equals(logging.DEBUG),
)
- result.append(mlia_verbose_handler)
+ yield mlia_verbose_handler
if logs_dir:
- mlia_file_handler = create_log_handler(
+ yield create_log_handler(
file_path=_get_log_file(logs_dir, log_filename),
log_level=logging.DEBUG,
log_format=_FILE_DEBUG_FORMAT,
log_filter=LogFilter.skip(logging.INFO),
delay=True,
)
- result.append(mlia_file_handler)
-
- return result
def _get_tools_handlers(
logs_dir: str | Path | None,
log_filename: str,
verbose: bool,
-) -> list[logging.Handler]:
+) -> Iterable[logging.Handler]:
"""Get handler for the tools loggers."""
- result = []
if verbose:
- verbose_stdout_handler = create_log_handler(
+ yield create_log_handler(
stream=sys.stdout,
log_level=logging.DEBUG,
log_format=_CONSOLE_DEBUG_FORMAT,
)
- result.append(verbose_stdout_handler)
if logs_dir:
- file_handler = create_log_handler(
+ yield create_log_handler(
file_path=_get_log_file(logs_dir, log_filename),
log_level=logging.DEBUG,
log_format=_FILE_DEBUG_FORMAT,
delay=True,
)
- result.append(file_handler)
-
- return result
def _get_log_file(logs_dir: str | Path, log_filename: str) -> Path:
"""Get the log file path."""
logs_dir_path = Path(logs_dir)
logs_dir_path.mkdir(exist_ok=True)
+
return logs_dir_path / log_filename
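
The handler helpers above now yield handlers instead of building lists; a standalone sketch of the same pattern (not MLIA code, names are illustrative):

    import logging
    import sys
    from typing import Iterable


    def make_handlers(verbose: bool) -> Iterable[logging.Handler]:
        """Yield handlers lazily instead of collecting them in a list."""
        yield logging.StreamHandler(sys.stdout)

        if verbose:
            debug_handler = logging.StreamHandler(sys.stderr)
            debug_handler.setLevel(logging.DEBUG)
            yield debug_handler


    # attach_handlers() accepts any iterable, so the generator is consumed
    # directly: every yielded handler is added to the logger.
    for handler in make_handlers(verbose=True):
        logging.getLogger("demo").addHandler(handler)
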
diff --git a/src/mlia/core/reporters.py b/src/mlia/core/reporters.py
new file mode 100644
index 0000000..de73ad7
--- /dev/null
+++ b/src/mlia/core/reporters.py
@@ -0,0 +1,22 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Reports module."""
+from __future__ import annotations
+
+from mlia.core.advice_generation import Advice
+from mlia.core.reporting import Column
+from mlia.core.reporting import Report
+from mlia.core.reporting import Table
+
+
+def report_advice(advice: list[Advice]) -> Report:
+ """Generate report for the advice."""
+ return Table(
+ columns=[
+ Column("#", only_for=["plain_text"]),
+ Column("Advice", alias="advice_message"),
+ ],
+ rows=[(i + 1, a.messages) for i, a in enumerate(advice)],
+ name="Advice",
+ alias="advice",
+ )
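
A brief usage sketch for the shared report_advice() helper introduced above (the advice strings are illustrative):

    from mlia.core.advice_generation import Advice
    from mlia.core.reporters import report_advice

    # Build an advice table; rows are numbered in plain-text output only.
    advice_table = report_advice(
        [
            Advice(["Model is fully compatible with Cortex-A."]),
            Advice(["Please refer to the operators table for more information."]),
        ]
    )
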
diff --git a/src/mlia/devices/cortexa/advice_generation.py b/src/mlia/devices/cortexa/advice_generation.py
index 33d5a5f..0f3553f 100644
--- a/src/mlia/devices/cortexa/advice_generation.py
+++ b/src/mlia/devices/cortexa/advice_generation.py
@@ -9,6 +9,7 @@ from mlia.core.common import AdviceCategory
from mlia.core.common import DataItem
from mlia.devices.cortexa.data_analysis import ModelIsCortexACompatible
from mlia.devices.cortexa.data_analysis import ModelIsNotCortexACompatible
+from mlia.devices.cortexa.data_analysis import ModelIsNotTFLiteCompatible
class CortexAAdviceProducer(FactBasedAdviceProducer):
@@ -38,3 +39,37 @@ class CortexAAdviceProducer(FactBasedAdviceProducer):
"Please, refer to the operators table for more information."
]
)
+
+ @produce_advice.register
+ @advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS)
+ def handle_model_is_not_tflite_compatible(
+ self, data_item: ModelIsNotTFLiteCompatible
+ ) -> None:
+ """Advice for TensorFlow Lite compatibility."""
+ if data_item.flex_ops:
+ self.add_advice(
+ [
+ "The following operators are not natively "
+ "supported by TensorFlow Lite: "
+ f"{', '.join(data_item.flex_ops)}.",
+ "Please refer to the TensorFlow documentation for more details.",
+ ]
+ )
+
+ if data_item.custom_ops:
+ self.add_advice(
+ [
+ "The following operators are custom and not natively "
+ "supported by TensorFlow Lite: "
+ f"{', '.join(data_item.custom_ops)}.",
+ "Please refer to the TensorFlow documentation for more details.",
+ ]
+ )
+
+ if not data_item.flex_ops and not data_item.custom_ops:
+ self.add_advice(
+ [
+ "Model could not be converted into TensorFlow Lite format.",
+ "Please refer to the table for more details.",
+ ]
+ )
diff --git a/src/mlia/devices/cortexa/advisor.py b/src/mlia/devices/cortexa/advisor.py
index 98c155b..ffbbea5 100644
--- a/src/mlia/devices/cortexa/advisor.py
+++ b/src/mlia/devices/cortexa/advisor.py
@@ -68,16 +68,14 @@ def configure_and_get_cortexa_advisor(
target_profile: str,
model: str | Path,
output: PathOrFileLike | None = None,
- **extra_args: Any,
+ **_extra_args: Any,
) -> InferenceAdvisor:
"""Create and configure Cortex-A advisor."""
if context.event_handlers is None:
context.event_handlers = [CortexAEventHandler(output)]
if context.config_parameters is None:
- context.config_parameters = _get_config_parameters(
- model, target_profile, **extra_args
- )
+ context.config_parameters = _get_config_parameters(model, target_profile)
return CortexAInferenceAdvisor()
diff --git a/src/mlia/devices/cortexa/data_analysis.py b/src/mlia/devices/cortexa/data_analysis.py
index dff95ce..d2b6f35 100644
--- a/src/mlia/devices/cortexa/data_analysis.py
+++ b/src/mlia/devices/cortexa/data_analysis.py
@@ -1,6 +1,8 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Cortex-A data analysis module."""
+from __future__ import annotations
+
from dataclasses import dataclass
from functools import singledispatchmethod
@@ -8,6 +10,8 @@ from mlia.core.common import DataItem
from mlia.core.data_analysis import Fact
from mlia.core.data_analysis import FactExtractor
from mlia.devices.cortexa.operators import CortexACompatibilityInfo
+from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
+from mlia.nn.tensorflow.tflite_compat import TFLiteConversionErrorCode
class CortexADataAnalyzer(FactExtractor):
@@ -27,6 +31,25 @@ class CortexADataAnalyzer(FactExtractor):
else:
self.add_fact(ModelIsNotCortexACompatible())
+ @analyze_data.register
+ def analyze_tflite_compatibility(self, data_item: TFLiteCompatibilityInfo) -> None:
+ """Analyze TensorFlow Lite compatibility information."""
+ if data_item.compatible:
+ return
+
+ custom_ops, flex_ops = [], []
+ if data_item.conversion_errors:
+ custom_ops = data_item.unsupported_ops_by_code(
+ TFLiteConversionErrorCode.NEEDS_CUSTOM_OPS
+ )
+ flex_ops = data_item.unsupported_ops_by_code(
+ TFLiteConversionErrorCode.NEEDS_FLEX_OPS
+ )
+
+ self.add_fact(
+ ModelIsNotTFLiteCompatible(custom_ops=custom_ops, flex_ops=flex_ops)
+ )
+
@dataclass
class ModelIsCortexACompatible(Fact):
@@ -36,3 +59,11 @@ class ModelIsCortexACompatible(Fact):
@dataclass
class ModelIsNotCortexACompatible(Fact):
"""Model is not compatible with Cortex-A."""
+
+
+@dataclass
+class ModelIsNotTFLiteCompatible(Fact):
+ """Model could not be converted into TensorFlow Lite format."""
+
+ custom_ops: list[str] | None = None
+ flex_ops: list[str] | None = None
diff --git a/src/mlia/devices/cortexa/data_collection.py b/src/mlia/devices/cortexa/data_collection.py
index 00c95e6..f4d5a82 100644
--- a/src/mlia/devices/cortexa/data_collection.py
+++ b/src/mlia/devices/cortexa/data_collection.py
@@ -10,6 +10,11 @@ from mlia.core.data_collection import ContextAwareDataCollector
from mlia.devices.cortexa.operators import CortexACompatibilityInfo
from mlia.devices.cortexa.operators import get_cortex_a_compatibility_info
from mlia.nn.tensorflow.config import get_tflite_model
+from mlia.nn.tensorflow.tflite_compat import TFLiteChecker
+from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
+from mlia.nn.tensorflow.utils import is_tflite_model
+from mlia.utils.logging import log_action
+
logger = logging.getLogger(__name__)
@@ -21,14 +26,24 @@ class CortexAOperatorCompatibility(ContextAwareDataCollector):
"""Init operator compatibility data collector."""
self.model = model
- def collect_data(self) -> CortexACompatibilityInfo:
+ def collect_data(self) -> TFLiteCompatibilityInfo | CortexACompatibilityInfo | None:
"""Collect operator compatibility information."""
+ if not is_tflite_model(self.model):
+ with log_action("Checking TensorFlow Lite compatibility ..."):
+ tflite_checker = TFLiteChecker()
+ tflite_compat = tflite_checker.check_compatibility(self.model)
+
+ if not tflite_compat.compatible:
+ return tflite_compat
+
tflite_model = get_tflite_model(self.model, self.context)
- logger.info("Checking operator compatibility ...")
- ops = get_cortex_a_compatibility_info(Path(tflite_model.model_path))
- logger.info("Done\n")
- return ops
+ with log_action("Checking operator compatibility ..."):
+ return (
+ get_cortex_a_compatibility_info( # pylint: disable=assignment-from-none
+ Path(tflite_model.model_path)
+ )
+ )
@classmethod
def name(cls) -> str:
diff --git a/src/mlia/devices/cortexa/handlers.py b/src/mlia/devices/cortexa/handlers.py
index f54ceff..7ed2b75 100644
--- a/src/mlia/devices/cortexa/handlers.py
+++ b/src/mlia/devices/cortexa/handlers.py
@@ -12,6 +12,7 @@ from mlia.devices.cortexa.events import CortexAAdvisorEventHandler
from mlia.devices.cortexa.events import CortexAAdvisorStartedEvent
from mlia.devices.cortexa.operators import CortexACompatibilityInfo
from mlia.devices.cortexa.reporters import cortex_a_formatters
+from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
logger = logging.getLogger(__name__)
@@ -30,6 +31,9 @@ class CortexAEventHandler(WorkflowEventsHandler, CortexAAdvisorEventHandler):
if isinstance(data_item, CortexACompatibilityInfo):
self.reporter.submit(data_item.operators, delay_print=True)
+ if isinstance(data_item, TFLiteCompatibilityInfo) and not data_item.compatible:
+ self.reporter.submit(data_item, delay_print=True)
+
def on_cortex_a_advisor_started(self, event: CortexAAdvisorStartedEvent) -> None:
"""Handle CortexAAdvisorStarted event."""
self.reporter.submit(event.device)
diff --git a/src/mlia/devices/cortexa/operators.py b/src/mlia/devices/cortexa/operators.py
index 6a314b7..8fd2571 100644
--- a/src/mlia/devices/cortexa/operators.py
+++ b/src/mlia/devices/cortexa/operators.py
@@ -21,9 +21,11 @@ class CortexACompatibilityInfo:
"""Model's operators."""
cortex_a_compatible: bool
- operators: list[Operator]
+ operators: list[Operator] | None = None
-def get_cortex_a_compatibility_info(model_path: Path) -> CortexACompatibilityInfo:
+def get_cortex_a_compatibility_info(
+ _model_path: Path,
+) -> CortexACompatibilityInfo | None:
"""Return list of model's operators."""
- raise NotImplementedError()
+ return None
diff --git a/src/mlia/devices/cortexa/reporters.py b/src/mlia/devices/cortexa/reporters.py
index 076b9ca..a55caba 100644
--- a/src/mlia/devices/cortexa/reporters.py
+++ b/src/mlia/devices/cortexa/reporters.py
@@ -7,25 +7,118 @@ from typing import Any
from typing import Callable
from mlia.core.advice_generation import Advice
+from mlia.core.reporters import report_advice
+from mlia.core.reporting import Cell
+from mlia.core.reporting import Column
+from mlia.core.reporting import Format
+from mlia.core.reporting import NestedReport
from mlia.core.reporting import Report
+from mlia.core.reporting import ReportItem
+from mlia.core.reporting import Table
from mlia.devices.cortexa.config import CortexAConfiguration
from mlia.devices.cortexa.operators import Operator
+from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
+from mlia.utils.console import style_improvement
from mlia.utils.types import is_list_of
def report_device(device: CortexAConfiguration) -> Report:
"""Generate report for the device."""
- raise NotImplementedError()
+ return NestedReport(
+ "Device information",
+ "device",
+ [
+ ReportItem("Target", alias="target", value=device.target),
+ ],
+ )
-def report_advice(advice: list[Advice]) -> Report:
- """Generate report for the advice."""
- raise NotImplementedError()
+def report_tflite_compatiblity(compat_info: TFLiteCompatibilityInfo) -> Report:
+ """Generate report for the TensorFlow Lite compatibility information."""
+ if compat_info.conversion_errors:
+ return Table(
+ [
+ Column("#", only_for=["plain_text"]),
+ Column("Operator", alias="operator"),
+ Column(
+ "Operator location",
+ alias="operator_location",
+ fmt=Format(wrap_width=25),
+ ),
+ Column("Error code", alias="error_code"),
+ Column(
+ "Error message", alias="error_message", fmt=Format(wrap_width=25)
+ ),
+ ],
+ [
+ (
+ index + 1,
+ err.operator,
+ ", ".join(err.location),
+ err.code.name,
+ err.message,
+ )
+ for index, err in enumerate(compat_info.conversion_errors)
+ ],
+ name="TensorFlow Lite conversion errors",
+ alias="tensorflow_lite_conversion_errors",
+ )
+ return Table(
+ columns=[
+ Column("Reason", alias="reason"),
+ Column(
+ "Exception details",
+ alias="exception_details",
+ fmt=Format(wrap_width=40),
+ ),
+ ],
+ rows=[
+ (
+ "TensorFlow Lite compatibility check failed with exception",
+ str(compat_info.conversion_exception),
+ ),
+ ],
+ name="TensorFlow Lite compatibility errors",
+ alias="tflite_compatibility",
+ )
-def report_cortex_a_operators(operators: list[Operator]) -> Report:
+
+def report_cortex_a_operators(ops: list[Operator]) -> Report:
"""Generate report for the operators."""
- raise NotImplementedError()
+ return Table(
+ [
+ Column("#", only_for=["plain_text"]),
+ Column(
+ "Operator location",
+ alias="operator_location",
+ fmt=Format(wrap_width=30),
+ ),
+ Column("Operator name", alias="operator_name", fmt=Format(wrap_width=20)),
+ Column(
+ "Cortex-A compatibility",
+ alias="cortex_a_compatible",
+ fmt=Format(wrap_width=25),
+ ),
+ ],
+ [
+ (
+ index + 1,
+ op.location,
+ op.name,
+ Cell(
+ op.is_cortex_a_compatible,
+ Format(
+ style=style_improvement(op.is_cortex_a_compatible),
+ str_fmt=lambda v: "Compatible" if v else "Not compatible",
+ ),
+ ),
+ )
+ for index, op in enumerate(ops)
+ ],
+ name="Operators",
+ alias="operators",
+ )
def cortex_a_formatters(data: Any) -> Callable[[Any], Report]:
@@ -36,6 +129,9 @@ def cortex_a_formatters(data: Any) -> Callable[[Any], Report]:
if isinstance(data, CortexAConfiguration):
return report_device
+ if isinstance(data, TFLiteCompatibilityInfo):
+ return report_tflite_compatiblity
+
if is_list_of(data, Operator):
return report_cortex_a_operators
diff --git a/src/mlia/devices/ethosu/data_collection.py b/src/mlia/devices/ethosu/data_collection.py
index 6ddebac..c8d5293 100644
--- a/src/mlia/devices/ethosu/data_collection.py
+++ b/src/mlia/devices/ethosu/data_collection.py
@@ -22,6 +22,7 @@ from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
from mlia.nn.tensorflow.utils import save_keras_model
from mlia.tools.vela_wrapper import Operators
from mlia.tools.vela_wrapper import supported_operators
+from mlia.utils.logging import log_action
from mlia.utils.types import is_list_of
logger = logging.getLogger(__name__)
@@ -39,12 +40,10 @@ class EthosUOperatorCompatibility(ContextAwareDataCollector):
"""Collect operator compatibility information."""
tflite_model = get_tflite_model(self.model, self.context)
- logger.info("Checking operator compatibility ...")
- ops = supported_operators(
- Path(tflite_model.model_path), self.device.compiler_options
- )
- logger.info("Done\n")
- return ops
+ with log_action("Checking operator compatibility ..."):
+ return supported_operators(
+ Path(tflite_model.model_path), self.device.compiler_options
+ )
@classmethod
def name(cls) -> str:
diff --git a/src/mlia/devices/ethosu/performance.py b/src/mlia/devices/ethosu/performance.py
index acc82e0..431dd89 100644
--- a/src/mlia/devices/ethosu/performance.py
+++ b/src/mlia/devices/ethosu/performance.py
@@ -17,6 +17,7 @@ from mlia.devices.ethosu.config import EthosUConfiguration
from mlia.nn.tensorflow.config import get_tflite_model
from mlia.nn.tensorflow.config import ModelConfiguration
from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+from mlia.utils.logging import log_action
logger = logging.getLogger(__name__)
@@ -125,25 +126,24 @@ class VelaPerformanceEstimator(
def estimate(self, model: Path | ModelConfiguration) -> MemoryUsage:
"""Estimate performance."""
- logger.info("Getting the memory usage metrics ...")
-
- model_path = (
- Path(model.model_path) if isinstance(model, ModelConfiguration) else model
- )
-
- vela_perf_metrics = vela.estimate_performance(
- model_path, self.device.compiler_options
- )
-
- memory_usage = MemoryUsage(
- vela_perf_metrics.sram_memory_area_size,
- vela_perf_metrics.dram_memory_area_size,
- vela_perf_metrics.unknown_memory_area_size,
- vela_perf_metrics.on_chip_flash_memory_area_size,
- vela_perf_metrics.off_chip_flash_memory_area_size,
- )
- logger.info("Done\n")
- return memory_usage
+ with log_action("Getting the memory usage metrics ..."):
+ model_path = (
+ Path(model.model_path)
+ if isinstance(model, ModelConfiguration)
+ else model
+ )
+
+ vela_perf_metrics = vela.estimate_performance(
+ model_path, self.device.compiler_options
+ )
+
+ return MemoryUsage(
+ vela_perf_metrics.sram_memory_area_size,
+ vela_perf_metrics.dram_memory_area_size,
+ vela_perf_metrics.unknown_memory_area_size,
+ vela_perf_metrics.on_chip_flash_memory_area_size,
+ vela_perf_metrics.off_chip_flash_memory_area_size,
+ )
class CorstonePerformanceEstimator(
@@ -161,44 +161,44 @@ class CorstonePerformanceEstimator(
def estimate(self, model: Path | ModelConfiguration) -> NPUCycles:
"""Estimate performance."""
- logger.info("Getting the performance metrics for '%s' ...", self.backend)
- logger.info(
- "WARNING: This task may require several minutes (press ctrl-c to interrupt)"
- )
-
- model_path = (
- Path(model.model_path) if isinstance(model, ModelConfiguration) else model
- )
-
- optimized_model_path = self.context.get_model_path(
- f"{model_path.stem}_vela.tflite"
- )
-
- vela.optimize_model(
- model_path, self.device.compiler_options, optimized_model_path
- )
-
- model_info = backend_manager.ModelInfo(model_path=optimized_model_path)
- device_info = backend_manager.DeviceInfo(
- device_type=self.device.target, # type: ignore
- mac=self.device.mac,
- )
-
- corstone_perf_metrics = backend_manager.estimate_performance(
- model_info, device_info, self.backend
- )
-
- npu_cycles = NPUCycles(
- corstone_perf_metrics.npu_active_cycles,
- corstone_perf_metrics.npu_idle_cycles,
- corstone_perf_metrics.npu_total_cycles,
- corstone_perf_metrics.npu_axi0_rd_data_beat_received,
- corstone_perf_metrics.npu_axi0_wr_data_beat_written,
- corstone_perf_metrics.npu_axi1_rd_data_beat_received,
- )
-
- logger.info("Done\n")
- return npu_cycles
+ with log_action(f"Getting the performance metrics for '{self.backend}' ..."):
+ logger.info(
+ "WARNING: This task may require several minutes "
+ "(press ctrl-c to interrupt)"
+ )
+
+ model_path = (
+ Path(model.model_path)
+ if isinstance(model, ModelConfiguration)
+ else model
+ )
+
+ optimized_model_path = self.context.get_model_path(
+ f"{model_path.stem}_vela.tflite"
+ )
+
+ vela.optimize_model(
+ model_path, self.device.compiler_options, optimized_model_path
+ )
+
+ model_info = backend_manager.ModelInfo(model_path=optimized_model_path)
+ device_info = backend_manager.DeviceInfo(
+ device_type=self.device.target, # type: ignore
+ mac=self.device.mac,
+ )
+
+ corstone_perf_metrics = backend_manager.estimate_performance(
+ model_info, device_info, self.backend
+ )
+
+ return NPUCycles(
+ corstone_perf_metrics.npu_active_cycles,
+ corstone_perf_metrics.npu_idle_cycles,
+ corstone_perf_metrics.npu_total_cycles,
+ corstone_perf_metrics.npu_axi0_rd_data_beat_received,
+ corstone_perf_metrics.npu_axi0_wr_data_beat_written,
+ corstone_perf_metrics.npu_axi1_rd_data_beat_received,
+ )
class EthosUPerformanceEstimator(
diff --git a/src/mlia/devices/ethosu/reporters.py b/src/mlia/devices/ethosu/reporters.py
index 9181043..f0fcb39 100644
--- a/src/mlia/devices/ethosu/reporters.py
+++ b/src/mlia/devices/ethosu/reporters.py
@@ -8,6 +8,7 @@ from typing import Any
from typing import Callable
from mlia.core.advice_generation import Advice
+from mlia.core.reporters import report_advice
from mlia.core.reporting import BytesCell
from mlia.core.reporting import Cell
from mlia.core.reporting import ClockCell
@@ -360,19 +361,6 @@ def report_perf_metrics(
)
-def report_advice(advice: list[Advice]) -> Report:
- """Generate report for the advice."""
- return Table(
- columns=[
- Column("#", only_for=["plain_text"]),
- Column("Advice", alias="advice_message"),
- ],
- rows=[(i + 1, a.messages) for i, a in enumerate(advice)],
- name="Advice",
- alias="advice",
- )
-
-
def ethos_u_formatters(data: Any) -> Callable[[Any], Report]:
"""Find appropriate formatter for the provided data."""
if isinstance(data, PerformanceMetrics) or is_list_of(data, PerformanceMetrics, 2):
diff --git a/src/mlia/devices/tosa/data_collection.py b/src/mlia/devices/tosa/data_collection.py
index 843d5ab..3809903 100644
--- a/src/mlia/devices/tosa/data_collection.py
+++ b/src/mlia/devices/tosa/data_collection.py
@@ -1,15 +1,13 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""TOSA data collection module."""
-import logging
from pathlib import Path
from mlia.core.data_collection import ContextAwareDataCollector
from mlia.devices.tosa.operators import get_tosa_compatibility_info
from mlia.devices.tosa.operators import TOSACompatibilityInfo
from mlia.nn.tensorflow.config import get_tflite_model
-
-logger = logging.getLogger(__name__)
+from mlia.utils.logging import log_action
class TOSAOperatorCompatibility(ContextAwareDataCollector):
@@ -23,11 +21,8 @@ class TOSAOperatorCompatibility(ContextAwareDataCollector):
"""Collect TOSA compatibility information."""
tflite_model = get_tflite_model(self.model, self.context)
- logger.info("Checking operator compatibility ...")
- tosa_info = get_tosa_compatibility_info(tflite_model.model_path)
- logger.info("Done\n")
-
- return tosa_info
+ with log_action("Checking operator compatibility ..."):
+ return get_tosa_compatibility_info(tflite_model.model_path)
@classmethod
def name(cls) -> str:
diff --git a/src/mlia/devices/tosa/reporters.py b/src/mlia/devices/tosa/reporters.py
index 4363793..26c93fd 100644
--- a/src/mlia/devices/tosa/reporters.py
+++ b/src/mlia/devices/tosa/reporters.py
@@ -7,6 +7,7 @@ from typing import Any
from typing import Callable
from mlia.core.advice_generation import Advice
+from mlia.core.reporters import report_advice
from mlia.core.reporting import Cell
from mlia.core.reporting import Column
from mlia.core.reporting import Format
@@ -31,19 +32,6 @@ def report_device(device: TOSAConfiguration) -> Report:
)
-def report_advice(advice: list[Advice]) -> Report:
- """Generate report for the advice."""
- return Table(
- columns=[
- Column("#", only_for=["plain_text"]),
- Column("Advice", alias="advice_message"),
- ],
- rows=[(i + 1, a.messages) for i, a in enumerate(advice)],
- name="Advice",
- alias="advice",
- )
-
-
def report_tosa_operators(ops: list[Operator]) -> Report:
"""Generate report for the operators."""
return Table(
diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py
index 03d1d0f..0c3133a 100644
--- a/src/mlia/nn/tensorflow/config.py
+++ b/src/mlia/nn/tensorflow/config.py
@@ -11,12 +11,12 @@ from typing import List
import tensorflow as tf
from mlia.core.context import Context
-from mlia.nn.tensorflow.utils import convert_tf_to_tflite
from mlia.nn.tensorflow.utils import convert_to_tflite
from mlia.nn.tensorflow.utils import is_keras_model
-from mlia.nn.tensorflow.utils import is_tf_model
+from mlia.nn.tensorflow.utils import is_saved_model
from mlia.nn.tensorflow.utils import is_tflite_model
from mlia.nn.tensorflow.utils import save_tflite_model
+from mlia.utils.logging import log_action
logger = logging.getLogger(__name__)
@@ -53,10 +53,8 @@ class KerasModel(ModelConfiguration):
self, tflite_model_path: str | Path, quantized: bool = False
) -> TFLiteModel:
"""Convert model to TensorFlow Lite format."""
- logger.info("Converting Keras to TensorFlow Lite ...")
-
- converted_model = convert_to_tflite(self.get_keras_model(), quantized)
- logger.info("Done\n")
+ with log_action("Converting Keras to TensorFlow Lite ..."):
+ converted_model = convert_to_tflite(self.get_keras_model(), quantized)
save_tflite_model(converted_model, tflite_model_path)
logger.debug(
@@ -95,7 +93,7 @@ class TfModel(ModelConfiguration): # pylint: disable=abstract-method
self, tflite_model_path: str | Path, quantized: bool = False
) -> TFLiteModel:
"""Convert model to TensorFlow Lite format."""
- converted_model = convert_tf_to_tflite(self.model_path, quantized)
+ converted_model = convert_to_tflite(self.model_path, quantized)
save_tflite_model(converted_model, tflite_model_path)
return TFLiteModel(tflite_model_path)
@@ -109,7 +107,7 @@ def get_model(model: str | Path) -> ModelConfiguration:
if is_keras_model(model):
return KerasModel(model)
- if is_tf_model(model):
+ if is_saved_model(model):
return TfModel(model)
raise Exception(
diff --git a/src/mlia/nn/tensorflow/tflite_compat.py b/src/mlia/nn/tensorflow/tflite_compat.py
new file mode 100644
index 0000000..960a5c3
--- /dev/null
+++ b/src/mlia/nn/tensorflow/tflite_compat.py
@@ -0,0 +1,132 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Functions for checking TensorFlow Lite compatibility."""
+from __future__ import annotations
+
+import logging
+from dataclasses import dataclass
+from enum import auto
+from enum import Enum
+from typing import Any
+from typing import cast
+from typing import List
+
+from tensorflow.lite.python import convert
+from tensorflow.lite.python.metrics import converter_error_data_pb2
+
+from mlia.nn.tensorflow.utils import get_tflite_converter
+from mlia.utils.logging import redirect_raw_output
+
+
+logger = logging.getLogger(__name__)
+
+
+class TFLiteConversionErrorCode(Enum):
+ """TensorFlow Lite conversion error codes."""
+
+ NEEDS_FLEX_OPS = auto()
+ NEEDS_CUSTOM_OPS = auto()
+ UNSUPPORTED_CONTROL_FLOW_V1 = auto()
+ GPU_NOT_COMPATIBLE = auto()
+ UNKNOWN = auto()
+
+
+@dataclass
+class TFLiteConversionError:
+ """TensorFlow Lite conversion error details."""
+
+ message: str
+ code: TFLiteConversionErrorCode
+ operator: str
+ location: list[str]
+
+
+@dataclass
+class TFLiteCompatibilityInfo:
+ """TensorFlow Lite compatibility information."""
+
+ compatible: bool
+ conversion_exception: Exception | None = None
+ conversion_errors: list[TFLiteConversionError] | None = None
+
+ def unsupported_ops_by_code(self, code: TFLiteConversionErrorCode) -> list[str]:
+ """Filter unsupported operators by error code."""
+ if not self.conversion_errors:
+ return []
+
+ return [err.operator for err in self.conversion_errors if err.code == code]
+
+
+class TFLiteChecker:
+ """Class for checking TensorFlow Lite compatibility."""
+
+ def __init__(self, quantized: bool = False) -> None:
+ """Init TensorFlow Lite checker."""
+ self.quantized = quantized
+
+ def check_compatibility(self, model: Any) -> TFLiteCompatibilityInfo:
+ """Check TensorFlow Lite compatibility for the provided model."""
+ try:
+ logger.debug("Check TensorFlow Lite compatibility for %s", model)
+ converter = get_tflite_converter(model, quantized=self.quantized)
+
+ # there is an issue with intercepting TensorFlow output
+ # not all output could be captured, for now just intercept
+ # stderr output
+ with redirect_raw_output(
+ logging.getLogger("tensorflow"), stdout_level=None
+ ):
+ converter.convert()
+ except convert.ConverterError as err:
+ return self._process_exception(err)
+ except Exception as err: # pylint: disable=broad-except
+ return TFLiteCompatibilityInfo(compatible=False, conversion_exception=err)
+ else:
+ return TFLiteCompatibilityInfo(compatible=True)
+
+ def _process_exception(
+ self, err: convert.ConverterError
+ ) -> TFLiteCompatibilityInfo:
+ """Parse error details if possible."""
+ conversion_errors = None
+ if hasattr(err, "errors"):
+ conversion_errors = [
+ TFLiteConversionError(
+ message=error.error_message.splitlines()[0],
+ code=self._convert_error_code(error.error_code),
+ operator=error.operator.name,
+ location=cast(
+ List[str],
+ [loc.name for loc in error.location.call if loc.name]
+ if hasattr(error, "location")
+ else [],
+ ),
+ )
+ for error in err.errors
+ ]
+
+ return TFLiteCompatibilityInfo(
+ compatible=False,
+ conversion_exception=err,
+ conversion_errors=conversion_errors,
+ )
+
+ @staticmethod
+ def _convert_error_code(code: int) -> TFLiteConversionErrorCode:
+ """Convert internal error codes."""
+ # pylint: disable=no-member
+ error_data = converter_error_data_pb2.ConverterErrorData
+ if code == error_data.ERROR_NEEDS_FLEX_OPS:
+ return TFLiteConversionErrorCode.NEEDS_FLEX_OPS
+
+ if code == error_data.ERROR_NEEDS_CUSTOM_OPS:
+ return TFLiteConversionErrorCode.NEEDS_CUSTOM_OPS
+
+ if code == error_data.ERROR_UNSUPPORTED_CONTROL_FLOW_V1:
+ return TFLiteConversionErrorCode.UNSUPPORTED_CONTROL_FLOW_V1
+
+ if code == converter_error_data_pb2.ConverterErrorData.ERROR_GPU_NOT_COMPATIBLE:
+ return TFLiteConversionErrorCode.GPU_NOT_COMPATIBLE
+ # pylint: enable=no-member
+
+ return TFLiteConversionErrorCode.UNKNOWN
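
The error codes above are intended to be filtered with unsupported_ops_by_code(); a hedged sketch mirroring how the Cortex-A data analysis consumes them (the model object is a stand-in):

    from mlia.nn.tensorflow.tflite_compat import TFLiteChecker
    from mlia.nn.tensorflow.tflite_compat import TFLiteConversionErrorCode

    info = TFLiteChecker().check_compatibility(model)  # `model` is a placeholder

    if not info.compatible:
        # Split unsupported operators by the reason reported by the converter.
        flex_ops = info.unsupported_ops_by_code(TFLiteConversionErrorCode.NEEDS_FLEX_OPS)
        custom_ops = info.unsupported_ops_by_code(TFLiteConversionErrorCode.NEEDS_CUSTOM_OPS)
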
diff --git a/src/mlia/nn/tensorflow/utils.py b/src/mlia/nn/tensorflow/utils.py
index 7970329..287e6ff 100644
--- a/src/mlia/nn/tensorflow/utils.py
+++ b/src/mlia/nn/tensorflow/utils.py
@@ -6,143 +6,122 @@ from __future__ import annotations
import logging
from pathlib import Path
+from typing import Any
from typing import Callable
+from typing import cast
from typing import Iterable
import numpy as np
import tensorflow as tf
-from tensorflow.lite.python.interpreter import Interpreter
from mlia.utils.logging import redirect_output
-def representative_dataset(model: tf.keras.Model) -> Callable:
+def representative_dataset(
+ input_shape: Any, sample_count: int = 100, input_dtype: type = np.float32
+) -> Callable:
"""Sample dataset used for quantization."""
- input_shape = model.input_shape
+ if input_shape[0] != 1:
+ raise Exception("Only the input batch_size=1 is supported!")
def dataset() -> Iterable:
- for _ in range(100):
- if input_shape[0] != 1:
- raise Exception("Only the input batch_size=1 is supported!")
+ for _ in range(sample_count):
data = np.random.rand(*input_shape)
- yield [data.astype(np.float32)]
+ yield [data.astype(input_dtype)]
return dataset
def get_tf_tensor_shape(model: str) -> list:
"""Get input shape for the TensorFlow tensor model."""
- # Loading the model
loaded = tf.saved_model.load(model)
- # The model signature must have 'serving_default' as a key
- if "serving_default" not in loaded.signatures.keys():
- raise Exception(
- "Unsupported TensorFlow model signature, must have 'serving_default'"
- )
- # Get the signature inputs
- inputs_tensor_info = loaded.signatures["serving_default"].inputs
- dims = []
- # Build a list of all inputs shape sizes
- for input_key in inputs_tensor_info:
- if input_key.get_shape():
- dims.extend(list(input_key.get_shape()))
- return dims
-
-
-def representative_tf_dataset(model: str) -> Callable:
- """Sample dataset used for quantization."""
- if not (input_shape := get_tf_tensor_shape(model)):
- raise Exception("Unable to get input shape")
- def dataset() -> Iterable:
- for _ in range(100):
- data = np.random.rand(*input_shape)
- yield [data.astype(np.float32)]
+ try:
+ default_signature_key = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
+ default_signature = loaded.signatures[default_signature_key]
+ inputs_tensor_info = default_signature.inputs
+ except KeyError as err:
+ raise Exception(f"Signature '{default_signature_key}' not found") from err
- return dataset
+ return [
+ dim
+ for input_key in inputs_tensor_info
+ if (shape := input_key.get_shape())
+ for dim in shape
+ ]
-def convert_to_tflite(model: tf.keras.Model, quantized: bool = False) -> Interpreter:
+def convert_to_tflite(model: tf.keras.Model | str, quantized: bool = False) -> bytes:
"""Convert Keras model to TensorFlow Lite."""
- if not isinstance(model, tf.keras.Model):
- raise Exception("Invalid model type")
-
- converter = tf.lite.TFLiteConverter.from_keras_model(model)
-
- if quantized:
- converter.optimizations = [tf.lite.Optimize.DEFAULT]
- converter.representative_dataset = representative_dataset(model)
- converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
- converter.inference_input_type = tf.int8
- converter.inference_output_type = tf.int8
+ converter = get_tflite_converter(model, quantized)
with redirect_output(logging.getLogger("tensorflow")):
- tflite_model = converter.convert()
-
- return tflite_model
-
-
-def convert_tf_to_tflite(model: str, quantized: bool = False) -> Interpreter:
- """Convert TensorFlow model to TensorFlow Lite."""
- if not isinstance(model, str):
- raise Exception("Invalid model type")
-
- converter = tf.lite.TFLiteConverter.from_saved_model(model)
+ return cast(bytes, converter.convert())
- if quantized:
- converter.optimizations = [tf.lite.Optimize.DEFAULT]
- converter.representative_dataset = representative_tf_dataset(model)
- converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
- converter.inference_input_type = tf.int8
- converter.inference_output_type = tf.int8
- with redirect_output(logging.getLogger("tensorflow")):
- tflite_model = converter.convert()
-
- return tflite_model
-
-
-def save_keras_model(model: tf.keras.Model, save_path: str | Path) -> None:
+def save_keras_model(
+ model: tf.keras.Model, save_path: str | Path, include_optimizer: bool = True
+) -> None:
"""Save Keras model at provided path."""
- # Checkpoint: saving the optimizer is necessary.
- model.save(save_path, include_optimizer=True)
+ model.save(save_path, include_optimizer=include_optimizer)
-def save_tflite_model(model: tf.lite.TFLiteConverter, save_path: str | Path) -> None:
+def save_tflite_model(tflite_model: bytes, save_path: str | Path) -> None:
"""Save TensorFlow Lite model at provided path."""
with open(save_path, "wb") as file:
- file.write(model)
+ file.write(tflite_model)
def is_tflite_model(model: str | Path) -> bool:
- """Check if model type is supported by TensorFlow Lite API.
-
- TensorFlow Lite model is indicated by the model file extension .tflite
- """
+ """Check if path contains TensorFlow Lite model."""
model_path = Path(model)
+
return model_path.suffix == ".tflite"
def is_keras_model(model: str | Path) -> bool:
- """Check if model type is supported by Keras API.
-
- Keras model is indicated by:
- 1. if it's a directory (meaning saved model),
- it should contain keras_metadata.pb file
- 2. or if the model file extension is .h5/.hdf5
- """
+ """Check if path contains a Keras model."""
model_path = Path(model)
if model_path.is_dir():
- return (model_path / "keras_metadata.pb").exists()
- return model_path.suffix in (".h5", ".hdf5")
+ return model_path.joinpath("keras_metadata.pb").exists()
+ return model_path.suffix in (".h5", ".hdf5")
-def is_tf_model(model: str | Path) -> bool:
- """Check if model type is supported by TensorFlow API.
- TensorFlow model is indicated if its directory (meaning saved model)
- doesn't contain keras_metadata.pb file
- """
+def is_saved_model(model: str | Path) -> bool:
+ """Check if path contains SavedModel model."""
model_path = Path(model)
+
return model_path.is_dir() and not is_keras_model(model)
+
+
+def get_tflite_converter(
+ model: tf.keras.Model | str | Path, quantized: bool = False
+) -> tf.lite.TFLiteConverter:
+ """Configure TensorFlow Lite converter for the provided model."""
+ if isinstance(model, (str, Path)):
+ # converter's methods accept string as input parameter
+ model = str(model)
+
+ if isinstance(model, tf.keras.Model):
+ converter = tf.lite.TFLiteConverter.from_keras_model(model)
+ input_shape = model.input_shape
+ elif isinstance(model, str) and is_saved_model(model):
+ converter = tf.lite.TFLiteConverter.from_saved_model(model)
+ input_shape = get_tf_tensor_shape(model)
+ elif isinstance(model, str) and is_keras_model(model):
+ keras_model = tf.keras.models.load_model(model)
+ input_shape = keras_model.input_shape
+ converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
+ else:
+ raise ValueError(f"Unable to create TensorFlow Lite converter for {model}")
+
+ if quantized:
+ converter.optimizations = [tf.lite.Optimize.DEFAULT]
+ converter.representative_dataset = representative_dataset(input_shape)
+ converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
+ converter.inference_input_type = tf.int8
+ converter.inference_output_type = tf.int8
+
+ return converter
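
A short sketch of the consolidated conversion helpers above; the Keras model is a stand-in built inline:

    import tensorflow as tf

    from mlia.nn.tensorflow.utils import convert_to_tflite
    from mlia.nn.tensorflow.utils import save_tflite_model

    # A Keras model instance, a SavedModel directory or a .h5/.hdf5 path is accepted.
    model = tf.keras.Sequential(
        [tf.keras.layers.Dense(units=1, input_shape=[1], batch_size=1)]
    )

    tflite_model = convert_to_tflite(model)  # returns the serialized model as bytes
    save_tflite_model(tflite_model, "model.tflite")
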
diff --git a/src/mlia/utils/logging.py b/src/mlia/utils/logging.py
index 793500a..cf7ad27 100644
--- a/src/mlia/utils/logging.py
+++ b/src/mlia/utils/logging.py
@@ -4,6 +4,9 @@
from __future__ import annotations
import logging
+import os
+import sys
+import tempfile
from contextlib import contextmanager
from contextlib import ExitStack
from contextlib import redirect_stderr
@@ -12,6 +15,8 @@ from pathlib import Path
from typing import Any
from typing import Callable
from typing import Generator
+from typing import Iterable
+from typing import TextIO
class LoggerWriter:
@@ -35,7 +40,7 @@ class LoggerWriter:
def redirect_output(
logger: logging.Logger,
stdout_level: int = logging.INFO,
- stderr_level: int = logging.INFO,
+ stderr_level: int = logging.ERROR,
) -> Generator[None, None, None]:
"""Redirect standard output to the logger."""
stdout_to_log = LoggerWriter(logger, stdout_level)
@@ -48,6 +53,47 @@ def redirect_output(
yield
+@contextmanager
+def redirect_raw(
+ logger: logging.Logger, output: TextIO, log_level: int
+) -> Generator[None, None, None]:
+ """Redirect output using file descriptors."""
+ with tempfile.TemporaryFile(mode="r+") as tmp:
+ old_output_fd: int | None = None
+ try:
+ output_fd = output.fileno()
+ old_output_fd = os.dup(output_fd)
+ os.dup2(tmp.fileno(), output_fd)
+
+ yield
+ finally:
+ if old_output_fd is not None:
+ os.dup2(old_output_fd, output_fd)
+ os.close(old_output_fd)
+
+ tmp.seek(0)
+ for line in tmp.readlines():
+ logger.log(log_level, line.rstrip())
+
+
+@contextmanager
+def redirect_raw_output(
+ logger: logging.Logger,
+ stdout_level: int | None = logging.INFO,
+ stderr_level: int | None = logging.ERROR,
+) -> Generator[None, None, None]:
+ """Redirect output on the process level."""
+ with ExitStack() as exit_stack:
+ for level, output in [
+ (stdout_level, sys.stdout),
+ (stderr_level, sys.stderr),
+ ]:
+ if level is not None:
+ exit_stack.enter_context(redirect_raw(logger, output, level))
+
+ yield
+
+
class LogFilter(logging.Filter):
"""Configurable log filter."""
@@ -112,9 +158,19 @@ def create_log_handler(
def attach_handlers(
- handlers: list[logging.Handler], loggers: list[logging.Logger]
+ handlers: Iterable[logging.Handler], loggers: Iterable[logging.Logger]
) -> None:
"""Attach handlers to the loggers."""
for handler in handlers:
for logger in loggers:
logger.addHandler(handler)
+
+
+@contextmanager
+def log_action(action: str) -> Generator[None, None, None]:
+ """Log action."""
+ logger = logging.getLogger(__name__)
+
+ logger.info(action)
+ yield
+ logger.info("Done\n")
diff --git a/tests/test_devices_cortex_a_data_analysis.py b/tests/test_devices_cortex_a_data_analysis.py
deleted file mode 100644
index 4724c81..0000000
--- a/tests/test_devices_cortex_a_data_analysis.py
+++ /dev/null
@@ -1,35 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""Tests for Cortex-A data analysis module."""
-from __future__ import annotations
-
-import pytest
-
-from mlia.core.common import DataItem
-from mlia.core.data_analysis import Fact
-from mlia.devices.cortexa.data_analysis import CortexADataAnalyzer
-from mlia.devices.cortexa.data_analysis import ModelIsCortexACompatible
-from mlia.devices.cortexa.data_analysis import ModelIsNotCortexACompatible
-from mlia.devices.cortexa.operators import CortexACompatibilityInfo
-
-
-@pytest.mark.parametrize(
- "input_data, expected_facts",
- [
- [
- CortexACompatibilityInfo(True, []),
- [ModelIsCortexACompatible()],
- ],
- [
- CortexACompatibilityInfo(False, []),
- [ModelIsNotCortexACompatible()],
- ],
- ],
-)
-def test_cortex_a_data_analyzer(
- input_data: DataItem, expected_facts: list[Fact]
-) -> None:
- """Test Cortex-A data analyzer."""
- analyzer = CortexADataAnalyzer()
- analyzer.analyze_data(input_data)
- assert analyzer.get_analyzed_data() == expected_facts
diff --git a/tests/test_devices_cortex_a_advice_generation.py b/tests/test_devices_cortexa_advice_generation.py
index 69529d4..ead8ae6 100644
--- a/tests/test_devices_cortex_a_advice_generation.py
+++ b/tests/test_devices_cortexa_advice_generation.py
@@ -12,6 +12,7 @@ from mlia.core.context import ExecutionContext
from mlia.devices.cortexa.advice_generation import CortexAAdviceProducer
from mlia.devices.cortexa.data_analysis import ModelIsCortexACompatible
from mlia.devices.cortexa.data_analysis import ModelIsNotCortexACompatible
+from mlia.devices.cortexa.data_analysis import ModelIsNotTFLiteCompatible
@pytest.mark.parametrize(
@@ -34,6 +35,43 @@ from mlia.devices.cortexa.data_analysis import ModelIsNotCortexACompatible
AdviceCategory.OPERATORS,
[Advice(["Model is fully compatible with Cortex-A."])],
],
+ [
+ ModelIsNotTFLiteCompatible(
+ flex_ops=["flex_op1", "flex_op2"],
+ custom_ops=["custom_op1", "custom_op2"],
+ ),
+ AdviceCategory.OPERATORS,
+ [
+ Advice(
+ [
+ "The following operators are not natively "
+ "supported by TensorFlow Lite: flex_op1, flex_op2.",
+ "Please refer to the TensorFlow documentation for "
+ "more details.",
+ ]
+ ),
+ Advice(
+ [
+ "The following operators are custom and not natively "
+ "supported by TensorFlow Lite: custom_op1, custom_op2.",
+ "Please refer to the TensorFlow documentation for "
+ "more details.",
+ ]
+ ),
+ ],
+ ],
+ [
+ ModelIsNotTFLiteCompatible(),
+ AdviceCategory.OPERATORS,
+ [
+ Advice(
+ [
+ "Model could not be converted into TensorFlow Lite format.",
+ "Please refer to the table for more details.",
+ ]
+ ),
+ ],
+ ],
],
)
def test_cortex_a_advice_producer(
diff --git a/tests/test_devices_cortexa_data_analysis.py b/tests/test_devices_cortexa_data_analysis.py
new file mode 100644
index 0000000..b491e52
--- /dev/null
+++ b/tests/test_devices_cortexa_data_analysis.py
@@ -0,0 +1,72 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for Cortex-A data analysis module."""
+from __future__ import annotations
+
+import pytest
+
+from mlia.core.common import DataItem
+from mlia.core.data_analysis import Fact
+from mlia.devices.cortexa.data_analysis import CortexADataAnalyzer
+from mlia.devices.cortexa.data_analysis import ModelIsCortexACompatible
+from mlia.devices.cortexa.data_analysis import ModelIsNotCortexACompatible
+from mlia.devices.cortexa.data_analysis import ModelIsNotTFLiteCompatible
+from mlia.devices.cortexa.operators import CortexACompatibilityInfo
+from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
+from mlia.nn.tensorflow.tflite_compat import TFLiteConversionError
+from mlia.nn.tensorflow.tflite_compat import TFLiteConversionErrorCode
+
+
+@pytest.mark.parametrize(
+ "input_data, expected_facts",
+ [
+ [
+ CortexACompatibilityInfo(True, []),
+ [ModelIsCortexACompatible()],
+ ],
+ [
+ CortexACompatibilityInfo(False, []),
+ [ModelIsNotCortexACompatible()],
+ ],
+ [
+ TFLiteCompatibilityInfo(compatible=True),
+ [],
+ ],
+ [
+ TFLiteCompatibilityInfo(compatible=False),
+ [ModelIsNotTFLiteCompatible(custom_ops=[], flex_ops=[])],
+ ],
+ [
+ TFLiteCompatibilityInfo(
+ compatible=False,
+ conversion_errors=[
+ TFLiteConversionError(
+ "error",
+ TFLiteConversionErrorCode.NEEDS_CUSTOM_OPS,
+ "custom_op1",
+ [],
+ ),
+ TFLiteConversionError(
+ "error",
+ TFLiteConversionErrorCode.NEEDS_FLEX_OPS,
+ "flex_op1",
+ [],
+ ),
+ ],
+ ),
+ [
+ ModelIsNotTFLiteCompatible(
+ custom_ops=["custom_op1"],
+ flex_ops=["flex_op1"],
+ )
+ ],
+ ],
+ ],
+)
+def test_cortex_a_data_analyzer(
+ input_data: DataItem, expected_facts: list[Fact]
+) -> None:
+ """Test Cortex-A data analyzer."""
+ analyzer = CortexADataAnalyzer()
+ analyzer.analyze_data(input_data)
+ assert analyzer.get_analyzed_data() == expected_facts
diff --git a/tests/test_devices_cortex_a_data_collection.py b/tests/test_devices_cortexa_data_collection.py
index 7ea3e52..7ea3e52 100644
--- a/tests/test_devices_cortex_a_data_collection.py
+++ b/tests/test_devices_cortexa_data_collection.py
diff --git a/tests/test_nn_tensorflow_tflite_compat.py b/tests/test_nn_tensorflow_tflite_compat.py
new file mode 100644
index 0000000..c330fdb
--- /dev/null
+++ b/tests/test_nn_tensorflow_tflite_compat.py
@@ -0,0 +1,210 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for tflite_compat module."""
+from __future__ import annotations
+
+from unittest.mock import MagicMock
+
+import pytest
+import tensorflow as tf
+from tensorflow.lite.python import convert
+from tensorflow.lite.python.metrics import converter_error_data_pb2
+
+from mlia.nn.tensorflow.tflite_compat import TFLiteChecker
+from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
+from mlia.nn.tensorflow.tflite_compat import TFLiteConversionError
+from mlia.nn.tensorflow.tflite_compat import TFLiteConversionErrorCode
+
+
+def test_not_fully_compatible_model_flex_ops() -> None:
+ """Test models that requires TF_SELECT_OPS."""
+ model = tf.keras.models.Sequential(
+ [
+ tf.keras.layers.Dense(units=1, input_shape=[1], batch_size=1),
+ tf.keras.layers.Dense(units=16, activation="gelu"),
+ tf.keras.layers.Dense(units=1),
+ ]
+ )
+
+ checker = TFLiteChecker()
+ result = checker.check_compatibility(model)
+
+ assert result.compatible is False
+ assert isinstance(result.conversion_exception, convert.ConverterError)
+ assert result.conversion_errors is not None
+ assert len(result.conversion_errors) == 1
+
+ conv_err = result.conversion_errors[0]
+ assert isinstance(conv_err, TFLiteConversionError)
+ assert conv_err.message == "'tf.Erf' op is neither a custom op nor a flex op"
+ assert conv_err.code == TFLiteConversionErrorCode.NEEDS_FLEX_OPS
+ assert conv_err.operator == "tf.Erf"
+ assert len(conv_err.location) == 3
+
+
+def _get_tflite_conversion_error(
+ error_message: str = "Conversion error",
+ custom_op: bool = False,
+ flex_op: bool = False,
+ unsupported_flow_v1: bool = False,
+ gpu_not_compatible: bool = False,
+ unknown_reason: bool = False,
+) -> convert.ConverterError:
+ """Create TensorFlow Lite conversion error."""
+ error_data = converter_error_data_pb2.ConverterErrorData
+ convert_error = convert.ConverterError(error_message)
+
+ # pylint: disable=no-member
+ def _add_error(operator: str, error_code: int) -> None:
+ convert_error.append_error(
+ error_data(
+ operator=error_data.Operator(name=operator),
+ error_code=error_code,
+ error_message=error_message,
+ )
+ )
+
+ if custom_op:
+ _add_error("custom_op", error_data.ERROR_NEEDS_CUSTOM_OPS)
+
+ if flex_op:
+ _add_error("flex_op", error_data.ERROR_NEEDS_FLEX_OPS)
+
+ if unsupported_flow_v1:
+ _add_error("flow_op", error_data.ERROR_UNSUPPORTED_CONTROL_FLOW_V1)
+
+ if gpu_not_compatible:
+ _add_error("non_gpu_op", error_data.ERROR_GPU_NOT_COMPATIBLE)
+
+ if unknown_reason:
+ _add_error("unknown_op", None) # type: ignore
+ # pylint: enable=no-member
+
+ return convert_error
+
+
+# pylint: disable=undefined-variable,unused-variable
+@pytest.mark.parametrize(
+ "conversion_error, expected_result",
+ [
+ (None, TFLiteCompatibilityInfo(compatible=True)),
+ (
+ err := _get_tflite_conversion_error(custom_op=True),
+ TFLiteCompatibilityInfo(
+ compatible=False,
+ conversion_exception=err,
+ conversion_errors=[
+ TFLiteConversionError(
+ message="Conversion error",
+ code=TFLiteConversionErrorCode.NEEDS_CUSTOM_OPS,
+ operator="custom_op",
+ location=[],
+ )
+ ],
+ ),
+ ),
+ (
+ err := _get_tflite_conversion_error(flex_op=True),
+ TFLiteCompatibilityInfo(
+ compatible=False,
+ conversion_exception=err,
+ conversion_errors=[
+ TFLiteConversionError(
+ message="Conversion error",
+ code=TFLiteConversionErrorCode.NEEDS_FLEX_OPS,
+ operator="flex_op",
+ location=[],
+ )
+ ],
+ ),
+ ),
+ (
+ err := _get_tflite_conversion_error(unknown_reason=True),
+ TFLiteCompatibilityInfo(
+ compatible=False,
+ conversion_exception=err,
+ conversion_errors=[
+ TFLiteConversionError(
+ message="Conversion error",
+ code=TFLiteConversionErrorCode.UNKNOWN,
+ operator="unknown_op",
+ location=[],
+ )
+ ],
+ ),
+ ),
+ (
+ err := _get_tflite_conversion_error(
+ flex_op=True,
+ custom_op=True,
+ gpu_not_compatible=True,
+ unsupported_flow_v1=True,
+ ),
+ TFLiteCompatibilityInfo(
+ compatible=False,
+ conversion_exception=err,
+ conversion_errors=[
+ TFLiteConversionError(
+ message="Conversion error",
+ code=TFLiteConversionErrorCode.NEEDS_CUSTOM_OPS,
+ operator="custom_op",
+ location=[],
+ ),
+ TFLiteConversionError(
+ message="Conversion error",
+ code=TFLiteConversionErrorCode.NEEDS_FLEX_OPS,
+ operator="flex_op",
+ location=[],
+ ),
+ TFLiteConversionError(
+ message="Conversion error",
+ code=TFLiteConversionErrorCode.UNSUPPORTED_CONTROL_FLOW_V1,
+ operator="flow_op",
+ location=[],
+ ),
+ TFLiteConversionError(
+ message="Conversion error",
+ code=TFLiteConversionErrorCode.GPU_NOT_COMPATIBLE,
+ operator="non_gpu_op",
+ location=[],
+ ),
+ ],
+ ),
+ ),
+ (
+ err := _get_tflite_conversion_error(),
+ TFLiteCompatibilityInfo(
+ compatible=False,
+ conversion_exception=err,
+ conversion_errors=[],
+ ),
+ ),
+ (
+ err := ValueError("Some unknown issue"),
+ TFLiteCompatibilityInfo(
+ compatible=False,
+ conversion_exception=err,
+ ),
+ ),
+ ],
+)
+# pylint: enable=undefined-variable,unused-variable
+def test_tflite_compatibility(
+ conversion_error: convert.ConverterError | ValueError | None,
+ expected_result: TFLiteCompatibilityInfo,
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """Test TensorFlow Lite compatibility."""
+ converter_mock = MagicMock()
+
+ if conversion_error is not None:
+ converter_mock.convert.side_effect = conversion_error
+
+ monkeypatch.setattr(
+ "mlia.nn.tensorflow.tflite_compat.get_tflite_converter",
+ lambda *args, **kwargs: converter_mock,
+ )
+
+ checker = TFLiteChecker()
+ result = checker.check_compatibility(MagicMock())
+ assert result == expected_result
diff --git a/tests/test_nn_tensorflow_utils.py b/tests/test_nn_tensorflow_utils.py
index 199c7db..5131171 100644
--- a/tests/test_nn_tensorflow_utils.py
+++ b/tests/test_nn_tensorflow_utils.py
@@ -3,6 +3,7 @@
"""Test for module utils/test_utils."""
from pathlib import Path
+import numpy as np
import pytest
import tensorflow as tf
@@ -10,16 +11,43 @@ from mlia.nn.tensorflow.utils import convert_to_tflite
from mlia.nn.tensorflow.utils import get_tf_tensor_shape
from mlia.nn.tensorflow.utils import is_keras_model
from mlia.nn.tensorflow.utils import is_tflite_model
+from mlia.nn.tensorflow.utils import representative_dataset
from mlia.nn.tensorflow.utils import save_keras_model
from mlia.nn.tensorflow.utils import save_tflite_model
-def test_convert_to_tflite(test_keras_model: Path) -> None:
+def test_generate_representative_dataset() -> None:
+ """Test function for generating representative dataset."""
+ dataset = representative_dataset([1, 3, 3], 5)
+ data = list(dataset())
+
+ assert len(data) == 5
+ for elem in data:
+ assert isinstance(elem, list)
+ assert len(elem) == 1
+
+ ndarray = elem[0]
+ assert ndarray.dtype == np.float32
+ assert isinstance(ndarray, np.ndarray)
+
+
+def test_generate_representative_dataset_wrong_shape() -> None:
+ """Test that only shape with batch size=1 is supported."""
+ with pytest.raises(Exception, match="Only the input batch_size=1 is supported!"):
+ representative_dataset([2, 3, 3], 5)
+
+
+def test_convert_saved_model_to_tflite(test_tf_model: Path) -> None:
+ """Test converting SavedModel to TensorFlow Lite."""
+ result = convert_to_tflite(test_tf_model.as_posix())
+ assert isinstance(result, bytes)
+
+
+def test_convert_keras_to_tflite(test_keras_model: Path) -> None:
"""Test converting Keras model to TensorFlow Lite."""
keras_model = tf.keras.models.load_model(str(test_keras_model))
- tflite_model = convert_to_tflite(keras_model)
-
- assert tflite_model
+ result = convert_to_tflite(keras_model)
+ assert isinstance(result, bytes)
def test_save_keras_model(tmp_path: Path, test_keras_model: Path) -> None:
@@ -46,6 +74,14 @@ def test_save_tflite_model(tmp_path: Path, test_keras_model: Path) -> None:
assert interpreter
+def test_convert_unknown_model_to_tflite() -> None:
+ """Test that unknown model type cannot be converted to TensorFlow Lite."""
+ with pytest.raises(
+ ValueError, match="Unable to create TensorFlow Lite converter for 123"
+ ):
+ convert_to_tflite(123)
+
+
@pytest.mark.parametrize(
"model_path, expected_result",
[
diff --git a/tests/test_utils_logging.py b/tests/test_utils_logging.py
index 1e212b2..ac835c6 100644
--- a/tests/test_utils_logging.py
+++ b/tests/test_utils_logging.py
@@ -8,10 +8,14 @@ import sys
from contextlib import ExitStack as does_not_raise
from pathlib import Path
from typing import Any
+from typing import Callable
+from unittest.mock import MagicMock
import pytest
-from mlia.cli.logging import create_log_handler
+from mlia.utils.logging import create_log_handler
+from mlia.utils.logging import redirect_output
+from mlia.utils.logging import redirect_raw_output
@pytest.mark.parametrize(
@@ -62,3 +66,21 @@ def test_create_log_handler(
delay=delay,
)
assert isinstance(handler, expected_class)
+
+
+@pytest.mark.parametrize(
+ "redirect_context_manager",
+ [
+ redirect_raw_output,
+ redirect_output,
+ ],
+)
+def test_output_redirection(redirect_context_manager: Callable) -> None:
+ """Test output redirection via context manager."""
+ print("before redirect")
+ logger_mock = MagicMock()
+ with redirect_context_manager(logger_mock):
+ print("output redirected")
+ print("after redirect")
+
+ logger_mock.log.assert_called_once_with(logging.INFO, "output redirected")