From edf436c48029aa4e2b4ca5d17eee5a8f07ecbd6f Mon Sep 17 00:00:00 2001 From: Dhruv Chauhan Date: Mon, 20 Mar 2023 10:22:08 +0000 Subject: MLIA-711 Extend TensorFlow Lite Compatibility Check - Unify the TensorFlow Lite compatibility check across Cortex-A, TOSA and Ethos-U targets - Display tables/messages with parsed information - Do not display raw TensorFlow Lite errors, and return with exit code 0 Change-Id: I9333fdb6cbe592f1ed7395d392412168492a1479 --- src/mlia/target/common/reporters.py | 159 ++++++++++++++++++++++++ src/mlia/target/cortex_a/advice_generation.py | 53 ++------ src/mlia/target/cortex_a/data_analysis.py | 36 +----- src/mlia/target/cortex_a/reporters.py | 54 +------- src/mlia/target/ethos_u/advice_generation.py | 20 +++ src/mlia/target/ethos_u/data_analysis.py | 9 +- src/mlia/target/ethos_u/data_collection.py | 13 +- src/mlia/target/ethos_u/handlers.py | 4 + src/mlia/target/ethos_u/reporters.py | 34 +++-- src/mlia/target/tosa/advice_generation.py | 20 +++ src/mlia/target/tosa/data_analysis.py | 11 +- src/mlia/target/tosa/data_collection.py | 17 ++- src/mlia/target/tosa/handlers.py | 4 + src/mlia/target/tosa/reporters.py | 5 + tests/test_target_cortex_a_advice_generation.py | 6 +- tests/test_target_cortex_a_data_analysis.py | 6 +- tests/test_target_ethos_u_data_analysis.py | 52 ++++++++ tests/test_target_ethos_u_reporters.py | 24 ++++ tests/test_target_tosa_data_analysis.py | 54 +++++++- 19 files changed, 425 insertions(+), 156 deletions(-) create mode 100644 src/mlia/target/common/reporters.py diff --git a/src/mlia/target/common/reporters.py b/src/mlia/target/common/reporters.py new file mode 100644 index 0000000..366e154 --- /dev/null +++ b/src/mlia/target/common/reporters.py @@ -0,0 +1,159 @@ +# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Common reports module.""" +from __future__ import annotations + +from dataclasses import dataclass + +from mlia.core.data_analysis import Fact +from mlia.core.reporting import Column +from mlia.core.reporting import Format +from mlia.core.reporting import Report +from mlia.core.reporting import Table +from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo + + +@dataclass +class ModelIsNotTFLiteCompatible(Fact): + """Model could not be converted into TensorFlow Lite format.""" + + custom_ops: list[str] | None = None + flex_ops: list[str] | None = None + + +@dataclass +class TFLiteCompatibilityCheckFailed(Fact): + """TensorFlow Lite compatibility check failed by unknown reason.""" + + +@dataclass +class ModelHasCustomOperators(Fact): + """Model could not be loaded because it contains custom ops.""" + + +def report_tflite_compatiblity(compat_info: TFLiteCompatibilityInfo) -> Report: + """Generate report for the TensorFlow Lite compatibility information.""" + if compat_info.conversion_errors: + return Table( + [ + Column("#", only_for=["plain_text"]), + Column("Operator", alias="operator"), + Column( + "Operator location", + alias="operator_location", + fmt=Format(wrap_width=25), + ), + Column("Error code", alias="error_code"), + Column( + "Error message", alias="error_message", fmt=Format(wrap_width=25) + ), + ], + [ + ( + index + 1, + err.operator, + ", ".join(err.location), + err.code.name, + err.message, + ) + for index, err in enumerate(compat_info.conversion_errors) + ], + name="TensorFlow Lite conversion errors", + alias="tensorflow_lite_conversion_errors", + ) + + return Table( + columns=[ + Column("Reason", alias="reason"), + Column( + "Exception details", + alias="exception_details", + fmt=Format(wrap_width=40), + ), + ], + rows=[ + ( + "TensorFlow Lite compatibility check failed with exception", + str(compat_info.conversion_exception), + ), + ], + name="TensorFlow Lite compatibility errors", + alias="tflite_compatibility", + ) + + +def handle_model_is_not_tflite_compatible_common( # type: ignore + self, data_item: ModelIsNotTFLiteCompatible +) -> None: + """Advice for TensorFlow Lite compatibility.""" + if data_item.flex_ops: + self.add_advice( + [ + "The following operators are not natively " + "supported by TensorFlow Lite: " + f"{', '.join(data_item.flex_ops)}.", + "Using select TensorFlow operators in TensorFlow Lite model " + "requires special initialization of TFLiteConverter and " + "TensorFlow Lite run-time.", + "Please refer to the TensorFlow documentation for more " + "details: https://www.tensorflow.org/lite/guide/ops_select", + "Note, such models are not supported by the ML Inference Advisor.", + ] + ) + + if data_item.custom_ops: + self.add_advice( + [ + "The following operators appear to be custom and not natively " + "supported by TensorFlow Lite: " + f"{', '.join(data_item.custom_ops)}.", + "Using custom operators in TensorFlow Lite model " + "requires special initialization of TFLiteConverter and " + "TensorFlow Lite run-time.", + "Please refer to the TensorFlow documentation for more " + "details: https://www.tensorflow.org/lite/guide/ops_custom", + "Note, such models are not supported by the ML Inference Advisor.", + ] + ) + + if not data_item.flex_ops and not data_item.custom_ops: + self.add_advice( + [ + "Model could not be converted into TensorFlow Lite format.", + "Please refer to the table for more details.", + ] + ) + + +def handle_tflite_check_failed_common( # type: ignore + self, _data_item: TFLiteCompatibilityCheckFailed +) -> None: + """Advice for the failed TensorFlow Lite compatibility checks.""" + self.add_advice( + [ + "Model could not be converted into TensorFlow Lite format.", + "Please refer to the table for more details.", + ] + ) + + +def analyze_tflite_compatibility_common( # type: ignore + self, data_item: TFLiteCompatibilityInfo +) -> None: + """Analyze TensorFlow Lite compatibility information.""" + if data_item.compatible: + return + + if data_item.conversion_failed_with_errors: + self.add_fact( + ModelIsNotTFLiteCompatible( + custom_ops=data_item.required_custom_ops, + flex_ops=data_item.required_flex_ops, + ) + ) + + if data_item.check_failed_with_unknown_error: + self.add_fact(TFLiteCompatibilityCheckFailed()) + + if data_item.conversion_failed_for_model_with_custom_ops: + self.add_fact(ModelHasCustomOperators()) diff --git a/src/mlia/target/cortex_a/advice_generation.py b/src/mlia/target/cortex_a/advice_generation.py index 98e8c06..1011d6c 100644 --- a/src/mlia/target/cortex_a/advice_generation.py +++ b/src/mlia/target/cortex_a/advice_generation.py @@ -7,11 +7,13 @@ from mlia.core.advice_generation import advice_category from mlia.core.advice_generation import FactBasedAdviceProducer from mlia.core.common import AdviceCategory from mlia.core.common import DataItem -from mlia.target.cortex_a.data_analysis import ModelHasCustomOperators +from mlia.target.common.reporters import handle_model_is_not_tflite_compatible_common +from mlia.target.common.reporters import handle_tflite_check_failed_common +from mlia.target.common.reporters import ModelHasCustomOperators +from mlia.target.common.reporters import ModelIsNotTFLiteCompatible +from mlia.target.common.reporters import TFLiteCompatibilityCheckFailed from mlia.target.cortex_a.data_analysis import ModelIsCortexACompatible from mlia.target.cortex_a.data_analysis import ModelIsNotCortexACompatible -from mlia.target.cortex_a.data_analysis import ModelIsNotTFLiteCompatible -from mlia.target.cortex_a.data_analysis import TFLiteCompatibilityCheckFailed class CortexAAdviceProducer(FactBasedAdviceProducer): @@ -88,43 +90,7 @@ class CortexAAdviceProducer(FactBasedAdviceProducer): self, data_item: ModelIsNotTFLiteCompatible ) -> None: """Advice for TensorFlow Lite compatibility.""" - if data_item.flex_ops: - self.add_advice( - [ - "The following operators are not natively " - "supported by TensorFlow Lite: " - f"{', '.join(data_item.flex_ops)}.", - "Using select TensorFlow operators in TensorFlow Lite model " - "requires special initialization of TFLiteConverter and " - "TensorFlow Lite run-time.", - "Please refer to the TensorFlow documentation for more " - "details: https://www.tensorflow.org/lite/guide/ops_select", - "Note, such models are not supported by the ML Inference Advisor.", - ] - ) - - if data_item.custom_ops: - self.add_advice( - [ - "The following operators appear to be custom and not natively " - "supported by TensorFlow Lite: " - f"{', '.join(data_item.custom_ops)}.", - "Using custom operators in TensorFlow Lite model " - "requires special initialization of TFLiteConverter and " - "TensorFlow Lite run-time.", - "Please refer to the TensorFlow documentation for more " - "details: https://www.tensorflow.org/lite/guide/ops_custom", - "Note, such models are not supported by the ML Inference Advisor.", - ] - ) - - if not data_item.flex_ops and not data_item.custom_ops: - self.add_advice( - [ - "Model could not be converted into TensorFlow Lite format.", - "Please refer to the table for more details.", - ] - ) + handle_model_is_not_tflite_compatible_common(self, data_item) @produce_advice.register @advice_category(AdviceCategory.COMPATIBILITY) @@ -132,12 +98,7 @@ class CortexAAdviceProducer(FactBasedAdviceProducer): self, _data_item: TFLiteCompatibilityCheckFailed ) -> None: """Advice for the failed TensorFlow Lite compatibility checks.""" - self.add_advice( - [ - "Model could not be converted into TensorFlow Lite format.", - "Please refer to the table for more details.", - ] - ) + handle_tflite_check_failed_common(self, _data_item) @produce_advice.register @advice_category(AdviceCategory.COMPATIBILITY) diff --git a/src/mlia/target/cortex_a/data_analysis.py b/src/mlia/target/cortex_a/data_analysis.py index 089c1a2..3161618 100644 --- a/src/mlia/target/cortex_a/data_analysis.py +++ b/src/mlia/target/cortex_a/data_analysis.py @@ -12,6 +12,7 @@ from mlia.core.common import DataItem from mlia.core.data_analysis import Fact from mlia.core.data_analysis import FactExtractor from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo +from mlia.target.common.reporters import analyze_tflite_compatibility_common from mlia.target.cortex_a.operators import CortexACompatibilityInfo @@ -64,22 +65,7 @@ class CortexADataAnalyzer(FactExtractor): @analyze_data.register def analyze_tflite_compatibility(self, data_item: TFLiteCompatibilityInfo) -> None: """Analyze TensorFlow Lite compatibility information.""" - if data_item.compatible: - return - - if data_item.conversion_failed_with_errors: - self.add_fact( - ModelIsNotTFLiteCompatible( - custom_ops=data_item.required_custom_ops, - flex_ops=data_item.required_flex_ops, - ) - ) - - if data_item.check_failed_with_unknown_error: - self.add_fact(TFLiteCompatibilityCheckFailed()) - - if data_item.conversion_failed_for_model_with_custom_ops: - self.add_fact(ModelHasCustomOperators()) + analyze_tflite_compatibility_common(self, data_item) @dataclass @@ -107,21 +93,3 @@ class ModelIsNotCortexACompatible(CortexACompatibility): unsupported_ops: set[str] activation_func_support: dict[str, ActivationFunctionSupport] - - -@dataclass -class ModelIsNotTFLiteCompatible(Fact): - """Model could not be converted into TensorFlow Lite format.""" - - custom_ops: list[str] | None = None - flex_ops: list[str] | None = None - - -@dataclass -class TFLiteCompatibilityCheckFailed(Fact): - """TensorFlow Lite compatibility check failed by unknown reason.""" - - -@dataclass -class ModelHasCustomOperators(Fact): - """Model could not be loaded because it contains custom ops.""" diff --git a/src/mlia/target/cortex_a/reporters.py b/src/mlia/target/cortex_a/reporters.py index 65d7906..fc80c9f 100644 --- a/src/mlia/target/cortex_a/reporters.py +++ b/src/mlia/target/cortex_a/reporters.py @@ -17,6 +17,7 @@ from mlia.core.reporting import Report from mlia.core.reporting import ReportItem from mlia.core.reporting import Table from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo +from mlia.target.common.reporters import report_tflite_compatiblity from mlia.target.cortex_a.config import CortexAConfiguration from mlia.target.cortex_a.operators import CortexACompatibilityInfo from mlia.utils.console import style_improvement @@ -34,57 +35,6 @@ def report_target(target_config: CortexAConfiguration) -> Report: ) -def report_tflite_compatiblity(compat_info: TFLiteCompatibilityInfo) -> Report: - """Generate report for the TensorFlow Lite compatibility information.""" - if compat_info.conversion_errors: - return Table( - [ - Column("#", only_for=["plain_text"]), - Column("Operator", alias="operator"), - Column( - "Operator location", - alias="operator_location", - fmt=Format(wrap_width=25), - ), - Column("Error code", alias="error_code"), - Column( - "Error message", alias="error_message", fmt=Format(wrap_width=25) - ), - ], - [ - ( - index + 1, - err.operator, - ", ".join(err.location), - err.code.name, - err.message, - ) - for index, err in enumerate(compat_info.conversion_errors) - ], - name="TensorFlow Lite conversion errors", - alias="tensorflow_lite_conversion_errors", - ) - - return Table( - columns=[ - Column("Reason", alias="reason"), - Column( - "Exception details", - alias="exception_details", - fmt=Format(wrap_width=40), - ), - ], - rows=[ - ( - "TensorFlow Lite compatibility check failed with exception", - str(compat_info.conversion_exception), - ), - ], - name="TensorFlow Lite compatibility errors", - alias="tflite_compatibility", - ) - - def report_cortex_a_operators(op_compat: CortexACompatibilityInfo) -> Report: """Generate report for the operators.""" return Table( @@ -132,7 +82,7 @@ def cortex_a_formatters(data: Any) -> Callable[[Any], Report]: return report_target if isinstance(data, TFLiteCompatibilityInfo): - return report_tflite_compatiblity + return report_tflite_compatiblity # type: ignore if isinstance(data, CortexACompatibilityInfo): return report_cortex_a_operators diff --git a/src/mlia/target/ethos_u/advice_generation.py b/src/mlia/target/ethos_u/advice_generation.py index daae4f4..a9f9eac 100644 --- a/src/mlia/target/ethos_u/advice_generation.py +++ b/src/mlia/target/ethos_u/advice_generation.py @@ -12,6 +12,10 @@ from mlia.core.advice_generation import FactBasedAdviceProducer from mlia.core.common import AdviceCategory from mlia.core.common import DataItem from mlia.nn.tensorflow.optimizations.select import OptimizationSettings +from mlia.target.common.reporters import handle_model_is_not_tflite_compatible_common +from mlia.target.common.reporters import handle_tflite_check_failed_common +from mlia.target.common.reporters import ModelIsNotTFLiteCompatible +from mlia.target.common.reporters import TFLiteCompatibilityCheckFailed from mlia.target.ethos_u.data_analysis import AllOperatorsSupportedOnNPU from mlia.target.ethos_u.data_analysis import HasCPUOnlyOperators from mlia.target.ethos_u.data_analysis import HasUnsupportedOnNPUOperators @@ -147,6 +151,22 @@ class EthosUAdviceProducer(FactBasedAdviceProducer): ] ) + @produce_advice.register + @advice_category(AdviceCategory.COMPATIBILITY) + def handle_model_is_not_tflite_compatible( + self, data_item: ModelIsNotTFLiteCompatible + ) -> None: + """Advice for TensorFlow Lite compatibility.""" + handle_model_is_not_tflite_compatible_common(self, data_item) + + @produce_advice.register + @advice_category(AdviceCategory.COMPATIBILITY) + def handle_tflite_check_failed( + self, _data_item: TFLiteCompatibilityCheckFailed + ) -> None: + """Advice for the failed TensorFlow Lite compatibility checks.""" + handle_tflite_check_failed_common(self, _data_item) + @staticmethod def get_next_optimization_targets( opt_type: list[OptimizationSettings], diff --git a/src/mlia/target/ethos_u/data_analysis.py b/src/mlia/target/ethos_u/data_analysis.py index 6b66734..3df4bff 100644 --- a/src/mlia/target/ethos_u/data_analysis.py +++ b/src/mlia/target/ethos_u/data_analysis.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Ethos-U data analysis module.""" from __future__ import annotations @@ -11,6 +11,8 @@ from mlia.core.common import DataItem from mlia.core.data_analysis import Fact from mlia.core.data_analysis import FactExtractor from mlia.nn.tensorflow.optimizations.select import OptimizationSettings +from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo +from mlia.target.common.reporters import analyze_tflite_compatibility_common from mlia.target.ethos_u.performance import OptimizationPerformanceMetrics @@ -151,3 +153,8 @@ class EthosUDataAnalyzer(FactExtractor): diffs.append(diff) self.add_fact(OptimizationResults(diffs)) + + @analyze_data.register + def analyze_tflite_compatibility(self, data_item: TFLiteCompatibilityInfo) -> None: + """Analyze TensorFlow Lite compatibility information.""" + analyze_tflite_compatibility_common(self, data_item) diff --git a/src/mlia/target/ethos_u/data_collection.py b/src/mlia/target/ethos_u/data_collection.py index 4fdfe96..8348393 100644 --- a/src/mlia/target/ethos_u/data_collection.py +++ b/src/mlia/target/ethos_u/data_collection.py @@ -17,6 +17,9 @@ from mlia.nn.tensorflow.config import get_tflite_model from mlia.nn.tensorflow.config import KerasModel from mlia.nn.tensorflow.optimizations.select import get_optimizer from mlia.nn.tensorflow.optimizations.select import OptimizationSettings +from mlia.nn.tensorflow.tflite_compat import TFLiteChecker +from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo +from mlia.nn.tensorflow.utils import is_tflite_model from mlia.nn.tensorflow.utils import save_keras_model from mlia.target.ethos_u.config import EthosUConfiguration from mlia.target.ethos_u.performance import EthosUPerformanceEstimator @@ -36,8 +39,16 @@ class EthosUOperatorCompatibility(ContextAwareDataCollector): self.model = model self.target_config = target_config - def collect_data(self) -> Operators: + def collect_data(self) -> Operators | TFLiteCompatibilityInfo | None: """Collect operator compatibility information.""" + if not is_tflite_model(self.model): + with log_action("Checking TensorFlow Lite compatibility ..."): + tflite_checker = TFLiteChecker() + tflite_compat = tflite_checker.check_compatibility(self.model) + + if not tflite_compat.compatible: + return tflite_compat + tflite_model = get_tflite_model(self.model, self.context) with log_action("Checking operator compatibility ..."): diff --git a/src/mlia/target/ethos_u/handlers.py b/src/mlia/target/ethos_u/handlers.py index b9c89e8..1b15c55 100644 --- a/src/mlia/target/ethos_u/handlers.py +++ b/src/mlia/target/ethos_u/handlers.py @@ -8,6 +8,7 @@ import logging from mlia.backend.vela.compat import Operators from mlia.core.events import CollectedDataEvent from mlia.core.handlers import WorkflowEventsHandler +from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo from mlia.target.ethos_u.events import EthosUAdvisorEventHandler from mlia.target.ethos_u.events import EthosUAdvisorStartedEvent from mlia.target.ethos_u.performance import OptimizationPerformanceMetrics @@ -49,6 +50,9 @@ class EthosUEventHandler(WorkflowEventsHandler, EthosUAdvisorEventHandler): space=True, ) + if isinstance(data_item, TFLiteCompatibilityInfo) and not data_item.compatible: + self.reporter.submit(data_item, delay_print=True) + def on_ethos_u_advisor_started(self, event: EthosUAdvisorStartedEvent) -> None: """Handle EthosUAdvisorStarted event.""" self.reporter.submit(event.target_config) diff --git a/src/mlia/target/ethos_u/reporters.py b/src/mlia/target/ethos_u/reporters.py index 2a5b5d3..4964462 100644 --- a/src/mlia/target/ethos_u/reporters.py +++ b/src/mlia/target/ethos_u/reporters.py @@ -23,6 +23,8 @@ from mlia.core.reporting import Report from mlia.core.reporting import ReportItem from mlia.core.reporting import SingleRow from mlia.core.reporting import Table +from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo +from mlia.target.common.reporters import report_tflite_compatiblity from mlia.target.ethos_u.config import EthosUConfiguration from mlia.target.ethos_u.performance import PerformanceMetrics from mlia.utils.console import style_improvement @@ -363,23 +365,31 @@ def report_perf_metrics( def ethos_u_formatters(data: Any) -> Callable[[Any], Report]: """Find appropriate formatter for the provided data.""" + report: Callable[[Any], Report] | None = None + if isinstance(data, PerformanceMetrics) or is_list_of(data, PerformanceMetrics, 2): - return report_perf_metrics + report = report_perf_metrics - if is_list_of(data, Advice): - return report_advice + elif is_list_of(data, Advice): + report = report_advice - if is_list_of(data, Operator): - return report_operators + elif is_list_of(data, Operator): + report = report_operators - if isinstance(data, Operators): - return report_operators_stat + elif isinstance(data, Operators): + report = report_operators_stat - if isinstance(data, EthosUConfiguration): - return report_target_details + elif isinstance(data, EthosUConfiguration): + report = report_target_details - if isinstance(data, (list, tuple)): + elif isinstance(data, (list, tuple)): formatters = [ethos_u_formatters(item) for item in data] - return CompoundFormatter(formatters) + report = CompoundFormatter(formatters) + + elif isinstance(data, TFLiteCompatibilityInfo): + report = report_tflite_compatiblity + + else: + raise Exception(f"Unable to find appropriate formatter for {data}") - raise Exception(f"Unable to find appropriate formatter for {data}") + return report # type: ignore diff --git a/src/mlia/target/tosa/advice_generation.py b/src/mlia/target/tosa/advice_generation.py index b8b9abf..ad321b2 100644 --- a/src/mlia/target/tosa/advice_generation.py +++ b/src/mlia/target/tosa/advice_generation.py @@ -7,6 +7,10 @@ from mlia.core.advice_generation import advice_category from mlia.core.advice_generation import FactBasedAdviceProducer from mlia.core.common import AdviceCategory from mlia.core.common import DataItem +from mlia.target.common.reporters import handle_model_is_not_tflite_compatible_common +from mlia.target.common.reporters import handle_tflite_check_failed_common +from mlia.target.common.reporters import ModelIsNotTFLiteCompatible +from mlia.target.common.reporters import TFLiteCompatibilityCheckFailed from mlia.target.tosa.data_analysis import ModelIsNotTOSACompatible from mlia.target.tosa.data_analysis import ModelIsTOSACompatible @@ -38,3 +42,19 @@ class TOSAAdviceProducer(FactBasedAdviceProducer): "Please, refer to the operators table for more information." ] ) + + @produce_advice.register + @advice_category(AdviceCategory.COMPATIBILITY) + def handle_model_is_not_tflite_compatible( + self, data_item: ModelIsNotTFLiteCompatible + ) -> None: + """Advice for TensorFlow Lite compatibility.""" + handle_model_is_not_tflite_compatible_common(self, data_item) + + @produce_advice.register + @advice_category(AdviceCategory.COMPATIBILITY) + def handle_tflite_check_failed( + self, _data_item: TFLiteCompatibilityCheckFailed + ) -> None: + """Advice for the failed TensorFlow Lite compatibility checks.""" + handle_tflite_check_failed_common(self, _data_item) diff --git a/src/mlia/target/tosa/data_analysis.py b/src/mlia/target/tosa/data_analysis.py index 7cbd61d..7b31441 100644 --- a/src/mlia/target/tosa/data_analysis.py +++ b/src/mlia/target/tosa/data_analysis.py @@ -1,6 +1,8 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """TOSA data analysis module.""" +from __future__ import annotations + from dataclasses import dataclass from functools import singledispatchmethod @@ -8,6 +10,8 @@ from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo from mlia.core.common import DataItem from mlia.core.data_analysis import Fact from mlia.core.data_analysis import FactExtractor +from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo +from mlia.target.common.reporters import analyze_tflite_compatibility_common @dataclass @@ -34,3 +38,8 @@ class TOSADataAnalyzer(FactExtractor): self.add_fact(ModelIsTOSACompatible()) else: self.add_fact(ModelIsNotTOSACompatible()) + + @analyze_data.register + def analyze_tflite_compatibility(self, data_item: TFLiteCompatibilityInfo) -> None: + """Analyze TensorFlow Lite compatibility information.""" + analyze_tflite_compatibility_common(self, data_item) diff --git a/src/mlia/target/tosa/data_collection.py b/src/mlia/target/tosa/data_collection.py index 105c501..19ea6eb 100644 --- a/src/mlia/target/tosa/data_collection.py +++ b/src/mlia/target/tosa/data_collection.py @@ -1,12 +1,17 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """TOSA data collection module.""" +from __future__ import annotations + from pathlib import Path from mlia.backend.tosa_checker.compat import get_tosa_compatibility_info from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo from mlia.core.data_collection import ContextAwareDataCollector from mlia.nn.tensorflow.config import get_tflite_model +from mlia.nn.tensorflow.tflite_compat import TFLiteChecker +from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo +from mlia.nn.tensorflow.utils import is_tflite_model from mlia.utils.logging import log_action @@ -17,8 +22,16 @@ class TOSAOperatorCompatibility(ContextAwareDataCollector): """Init the data collector.""" self.model = model - def collect_data(self) -> TOSACompatibilityInfo: + def collect_data(self) -> TFLiteCompatibilityInfo | TOSACompatibilityInfo | None: """Collect TOSA compatibility information.""" + if not is_tflite_model(self.model): + with log_action("Checking TensorFlow Lite compatibility ..."): + tflite_checker = TFLiteChecker() + tflite_compat = tflite_checker.check_compatibility(self.model) + + if not tflite_compat.compatible: + return tflite_compat + tflite_model = get_tflite_model(self.model, self.context) with log_action("Checking operator compatibility ..."): diff --git a/src/mlia/target/tosa/handlers.py b/src/mlia/target/tosa/handlers.py index 131afa7..26e7226 100644 --- a/src/mlia/target/tosa/handlers.py +++ b/src/mlia/target/tosa/handlers.py @@ -9,6 +9,7 @@ import logging from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo from mlia.core.events import CollectedDataEvent from mlia.core.handlers import WorkflowEventsHandler +from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo from mlia.target.tosa.events import TOSAAdvisorEventHandler from mlia.target.tosa.events import TOSAAdvisorStartedEvent from mlia.target.tosa.reporters import tosa_formatters @@ -34,3 +35,6 @@ class TOSAEventHandler(WorkflowEventsHandler, TOSAAdvisorEventHandler): if isinstance(data_item, TOSACompatibilityInfo): self.reporter.submit(data_item, delay_print=True) + + if isinstance(data_item, TFLiteCompatibilityInfo) and not data_item.compatible: + self.reporter.submit(data_item, delay_print=True) diff --git a/src/mlia/target/tosa/reporters.py b/src/mlia/target/tosa/reporters.py index f54c06b..decae0c 100644 --- a/src/mlia/target/tosa/reporters.py +++ b/src/mlia/target/tosa/reporters.py @@ -22,6 +22,8 @@ from mlia.core.reporting import NestedReport from mlia.core.reporting import Report from mlia.core.reporting import ReportItem from mlia.core.reporting import Table +from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo +from mlia.target.common.reporters import report_tflite_compatiblity from mlia.target.tosa.config import TOSAConfiguration from mlia.target.tosa.metadata import TOSAMetadata from mlia.utils.console import style_improvement @@ -163,4 +165,7 @@ def tosa_formatters(data: Any) -> Callable[[Any], Report]: if isinstance(data, TOSACompatibilityInfo): return report_tosa_compatibility + if isinstance(data, TFLiteCompatibilityInfo): + return report_tflite_compatiblity # type: ignore + raise Exception(f"Unable to find appropriate formatter for {data}") diff --git a/tests/test_target_cortex_a_advice_generation.py b/tests/test_target_cortex_a_advice_generation.py index 9596d47..916bdc1 100644 --- a/tests/test_target_cortex_a_advice_generation.py +++ b/tests/test_target_cortex_a_advice_generation.py @@ -13,13 +13,13 @@ from mlia.core.common import AdviceCategory from mlia.core.common import DataItem from mlia.core.context import ExecutionContext from mlia.nn.tensorflow.tflite_graph import TFL_ACTIVATION_FUNCTION +from mlia.target.common.reporters import ModelHasCustomOperators +from mlia.target.common.reporters import ModelIsNotTFLiteCompatible +from mlia.target.common.reporters import TFLiteCompatibilityCheckFailed from mlia.target.cortex_a.advice_generation import CortexAAdviceProducer from mlia.target.cortex_a.config import CortexAConfiguration -from mlia.target.cortex_a.data_analysis import ModelHasCustomOperators from mlia.target.cortex_a.data_analysis import ModelIsCortexACompatible from mlia.target.cortex_a.data_analysis import ModelIsNotCortexACompatible -from mlia.target.cortex_a.data_analysis import ModelIsNotTFLiteCompatible -from mlia.target.cortex_a.data_analysis import TFLiteCompatibilityCheckFailed VERSION = CortexAConfiguration.load_profile("cortex-a").armnn_tflite_delegate_version BACKEND_INFO = ( diff --git a/tests/test_target_cortex_a_data_analysis.py b/tests/test_target_cortex_a_data_analysis.py index 0a6b490..e033ef9 100644 --- a/tests/test_target_cortex_a_data_analysis.py +++ b/tests/test_target_cortex_a_data_analysis.py @@ -15,13 +15,13 @@ from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityStatus from mlia.nn.tensorflow.tflite_compat import TFLiteConversionError from mlia.nn.tensorflow.tflite_compat import TFLiteConversionErrorCode from mlia.nn.tensorflow.tflite_graph import TFL_ACTIVATION_FUNCTION +from mlia.target.common.reporters import ModelHasCustomOperators +from mlia.target.common.reporters import ModelIsNotTFLiteCompatible +from mlia.target.common.reporters import TFLiteCompatibilityCheckFailed from mlia.target.cortex_a.config import CortexAConfiguration from mlia.target.cortex_a.data_analysis import CortexADataAnalyzer -from mlia.target.cortex_a.data_analysis import ModelHasCustomOperators from mlia.target.cortex_a.data_analysis import ModelIsCortexACompatible from mlia.target.cortex_a.data_analysis import ModelIsNotCortexACompatible -from mlia.target.cortex_a.data_analysis import ModelIsNotTFLiteCompatible -from mlia.target.cortex_a.data_analysis import TFLiteCompatibilityCheckFailed from mlia.target.cortex_a.operators import CortexACompatibilityInfo from mlia.target.cortex_a.operators import Operator diff --git a/tests/test_target_ethos_u_data_analysis.py b/tests/test_target_ethos_u_data_analysis.py index 8e63946..80f0603 100644 --- a/tests/test_target_ethos_u_data_analysis.py +++ b/tests/test_target_ethos_u_data_analysis.py @@ -13,6 +13,13 @@ from mlia.backend.vela.compat import Operators from mlia.core.common import DataItem from mlia.core.data_analysis import Fact from mlia.nn.tensorflow.optimizations.select import OptimizationSettings +from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo +from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityStatus +from mlia.nn.tensorflow.tflite_compat import TFLiteConversionError +from mlia.nn.tensorflow.tflite_compat import TFLiteConversionErrorCode +from mlia.target.common.reporters import ModelHasCustomOperators +from mlia.target.common.reporters import ModelIsNotTFLiteCompatible +from mlia.target.common.reporters import TFLiteCompatibilityCheckFailed from mlia.target.ethos_u.config import EthosUConfiguration from mlia.target.ethos_u.data_analysis import AllOperatorsSupportedOnNPU from mlia.target.ethos_u.data_analysis import EthosUDataAnalyzer @@ -139,6 +146,51 @@ def test_perf_metrics_diff() -> None: ), [], ], + [ + TFLiteCompatibilityInfo(status=TFLiteCompatibilityStatus.COMPATIBLE), + [], + ], + [ + TFLiteCompatibilityInfo( + status=TFLiteCompatibilityStatus.MODEL_WITH_CUSTOM_OP_ERROR + ), + [ModelHasCustomOperators()], + ], + [ + TFLiteCompatibilityInfo(status=TFLiteCompatibilityStatus.UNKNOWN_ERROR), + [TFLiteCompatibilityCheckFailed()], + ], + [ + TFLiteCompatibilityInfo( + status=TFLiteCompatibilityStatus.TFLITE_CONVERSION_ERROR + ), + [ModelIsNotTFLiteCompatible(custom_ops=[], flex_ops=[])], + ], + [ + TFLiteCompatibilityInfo( + status=TFLiteCompatibilityStatus.TFLITE_CONVERSION_ERROR, + conversion_errors=[ + TFLiteConversionError( + "error", + TFLiteConversionErrorCode.NEEDS_CUSTOM_OPS, + "custom_op1", + [], + ), + TFLiteConversionError( + "error", + TFLiteConversionErrorCode.NEEDS_FLEX_OPS, + "flex_op1", + [], + ), + ], + ), + [ + ModelIsNotTFLiteCompatible( + custom_ops=["custom_op1"], + flex_ops=["flex_op1"], + ) + ], + ], ], ) def test_ethos_u_data_analyzer( diff --git a/tests/test_target_ethos_u_reporters.py b/tests/test_target_ethos_u_reporters.py index b8014e4..debeeb2 100644 --- a/tests/test_target_ethos_u_reporters.py +++ b/tests/test_target_ethos_u_reporters.py @@ -3,6 +3,7 @@ """Tests for reports module.""" from __future__ import annotations +from typing import Any from typing import cast import pytest @@ -11,7 +12,10 @@ from mlia.backend.vela.compat import NpuSupported from mlia.backend.vela.compat import Operator from mlia.core.reporting import Report from mlia.core.reporting import Table +from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo +from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityStatus from mlia.target.ethos_u.config import EthosUConfiguration +from mlia.target.ethos_u.reporters import ethos_u_formatters from mlia.target.ethos_u.reporters import report_operators from mlia.target.ethos_u.reporters import report_target_details from mlia.target.registry import profile @@ -231,3 +235,23 @@ def test_report_target_details( json_dict = report.to_json() assert json_dict == expected_json_dict + + +@pytest.mark.parametrize( + "data", + (TFLiteCompatibilityInfo(status=TFLiteCompatibilityStatus.COMPATIBLE),), +) +def test_ethos_u_formatters(data: Any) -> None: + """Test function ethos_u_formatters() with valid input.""" + formatter = ethos_u_formatters(data) + report = formatter(data) + assert isinstance(report, Report) + + +def test_ethos_u_formatters_invalid_data() -> None: + """Test function ethos_u_formatters() with invalid input.""" + with pytest.raises( + Exception, + match=r"^Unable to find appropriate formatter for .*", + ): + ethos_u_formatters(200) diff --git a/tests/test_target_tosa_data_analysis.py b/tests/test_target_tosa_data_analysis.py index 41e977f..23adcc8 100644 --- a/tests/test_target_tosa_data_analysis.py +++ b/tests/test_target_tosa_data_analysis.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for TOSA data analysis module.""" from __future__ import annotations @@ -8,6 +8,13 @@ import pytest from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo from mlia.core.common import DataItem from mlia.core.data_analysis import Fact +from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo +from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityStatus +from mlia.nn.tensorflow.tflite_compat import TFLiteConversionError +from mlia.nn.tensorflow.tflite_compat import TFLiteConversionErrorCode +from mlia.target.common.reporters import ModelHasCustomOperators +from mlia.target.common.reporters import ModelIsNotTFLiteCompatible +from mlia.target.common.reporters import TFLiteCompatibilityCheckFailed from mlia.target.tosa.data_analysis import ModelIsNotTOSACompatible from mlia.target.tosa.data_analysis import ModelIsTOSACompatible from mlia.target.tosa.data_analysis import TOSADataAnalyzer @@ -24,6 +31,51 @@ from mlia.target.tosa.data_analysis import TOSADataAnalyzer TOSACompatibilityInfo(False, []), [ModelIsNotTOSACompatible()], ], + [ + TFLiteCompatibilityInfo(status=TFLiteCompatibilityStatus.COMPATIBLE), + [], + ], + [ + TFLiteCompatibilityInfo( + status=TFLiteCompatibilityStatus.MODEL_WITH_CUSTOM_OP_ERROR + ), + [ModelHasCustomOperators()], + ], + [ + TFLiteCompatibilityInfo(status=TFLiteCompatibilityStatus.UNKNOWN_ERROR), + [TFLiteCompatibilityCheckFailed()], + ], + [ + TFLiteCompatibilityInfo( + status=TFLiteCompatibilityStatus.TFLITE_CONVERSION_ERROR + ), + [ModelIsNotTFLiteCompatible(custom_ops=[], flex_ops=[])], + ], + [ + TFLiteCompatibilityInfo( + status=TFLiteCompatibilityStatus.TFLITE_CONVERSION_ERROR, + conversion_errors=[ + TFLiteConversionError( + "error", + TFLiteConversionErrorCode.NEEDS_CUSTOM_OPS, + "custom_op1", + [], + ), + TFLiteConversionError( + "error", + TFLiteConversionErrorCode.NEEDS_FLEX_OPS, + "flex_op1", + [], + ), + ], + ), + [ + ModelIsNotTFLiteCompatible( + custom_ops=["custom_op1"], + flex_ops=["flex_op1"], + ) + ], + ], ], ) def test_tosa_data_analyzer(input_data: DataItem, expected_facts: list[Fact]) -> None: -- cgit v1.2.1