aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/tensorflow/tflite_compat.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/tensorflow/tflite_compat.py')
-rw-r--r--src/mlia/nn/tensorflow/tflite_compat.py132
1 files changed, 132 insertions, 0 deletions
diff --git a/src/mlia/nn/tensorflow/tflite_compat.py b/src/mlia/nn/tensorflow/tflite_compat.py
new file mode 100644
index 0000000..960a5c3
--- /dev/null
+++ b/src/mlia/nn/tensorflow/tflite_compat.py
@@ -0,0 +1,132 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Functions for checking TensorFlow Lite compatibility."""
+from __future__ import annotations
+
+import logging
+from dataclasses import dataclass
+from enum import auto
+from enum import Enum
+from typing import Any
+from typing import cast
+from typing import List
+
+from tensorflow.lite.python import convert
+from tensorflow.lite.python.metrics import converter_error_data_pb2
+
+from mlia.nn.tensorflow.utils import get_tflite_converter
+from mlia.utils.logging import redirect_raw_output
+
+
+logger = logging.getLogger(__name__)
+
+
+class TFLiteConversionErrorCode(Enum):
+ """TensorFlow Lite conversion error codes."""
+
+ NEEDS_FLEX_OPS = auto()
+ NEEDS_CUSTOM_OPS = auto()
+ UNSUPPORTED_CONTROL_FLOW_V1 = auto()
+ GPU_NOT_COMPATIBLE = auto()
+ UNKNOWN = auto()
+
+
+@dataclass
+class TFLiteConversionError:
+ """TensorFlow Lite conversion error details."""
+
+ message: str
+ code: TFLiteConversionErrorCode
+ operator: str
+ location: list[str]
+
+
+@dataclass
+class TFLiteCompatibilityInfo:
+ """TensorFlow Lite compatibility information."""
+
+ compatible: bool
+ conversion_exception: Exception | None = None
+ conversion_errors: list[TFLiteConversionError] | None = None
+
+ def unsupported_ops_by_code(self, code: TFLiteConversionErrorCode) -> list[str]:
+ """Filter unsupported operators by error code."""
+ if not self.conversion_errors:
+ return []
+
+ return [err.operator for err in self.conversion_errors if err.code == code]
+
+
+class TFLiteChecker:
+ """Class for checking TensorFlow Lite compatibility."""
+
+ def __init__(self, quantized: bool = False) -> None:
+ """Init TensorFlow Lite checker."""
+ self.quantized = quantized
+
+ def check_compatibility(self, model: Any) -> TFLiteCompatibilityInfo:
+ """Check TensorFlow Lite compatibility for the provided model."""
+ try:
+ logger.debug("Check TensorFlow Lite compatibility for %s", model)
+ converter = get_tflite_converter(model, quantized=self.quantized)
+
+ # there is an issue with intercepting TensorFlow output
+ # not all output could be captured, for now just intercept
+ # stderr output
+ with redirect_raw_output(
+ logging.getLogger("tensorflow"), stdout_level=None
+ ):
+ converter.convert()
+ except convert.ConverterError as err:
+ return self._process_exception(err)
+ except Exception as err: # pylint: disable=broad-except
+ return TFLiteCompatibilityInfo(compatible=False, conversion_exception=err)
+ else:
+ return TFLiteCompatibilityInfo(compatible=True)
+
+ def _process_exception(
+ self, err: convert.ConverterError
+ ) -> TFLiteCompatibilityInfo:
+ """Parse error details if possible."""
+ conversion_errors = None
+ if hasattr(err, "errors"):
+ conversion_errors = [
+ TFLiteConversionError(
+ message=error.error_message.splitlines()[0],
+ code=self._convert_error_code(error.error_code),
+ operator=error.operator.name,
+ location=cast(
+ List[str],
+ [loc.name for loc in error.location.call if loc.name]
+ if hasattr(error, "location")
+ else [],
+ ),
+ )
+ for error in err.errors
+ ]
+
+ return TFLiteCompatibilityInfo(
+ compatible=False,
+ conversion_exception=err,
+ conversion_errors=conversion_errors,
+ )
+
+ @staticmethod
+ def _convert_error_code(code: int) -> TFLiteConversionErrorCode:
+ """Convert internal error codes."""
+ # pylint: disable=no-member
+ error_data = converter_error_data_pb2.ConverterErrorData
+ if code == error_data.ERROR_NEEDS_FLEX_OPS:
+ return TFLiteConversionErrorCode.NEEDS_FLEX_OPS
+
+ if code == error_data.ERROR_NEEDS_CUSTOM_OPS:
+ return TFLiteConversionErrorCode.NEEDS_CUSTOM_OPS
+
+ if code == error_data.ERROR_UNSUPPORTED_CONTROL_FLOW_V1:
+ return TFLiteConversionErrorCode.UNSUPPORTED_CONTROL_FLOW_V1
+
+ if code == converter_error_data_pb2.ConverterErrorData.ERROR_GPU_NOT_COMPATIBLE:
+ return TFLiteConversionErrorCode.GPU_NOT_COMPATIBLE
+ # pylint: enable=no-member
+
+ return TFLiteConversionErrorCode.UNKNOWN