diff options
Diffstat (limited to 'src/mlia/nn/tensorflow/tflite_compat.py')
-rw-r--r-- | src/mlia/nn/tensorflow/tflite_compat.py | 132 |
1 files changed, 132 insertions, 0 deletions
diff --git a/src/mlia/nn/tensorflow/tflite_compat.py b/src/mlia/nn/tensorflow/tflite_compat.py new file mode 100644 index 0000000..960a5c3 --- /dev/null +++ b/src/mlia/nn/tensorflow/tflite_compat.py @@ -0,0 +1,132 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Functions for checking TensorFlow Lite compatibility.""" +from __future__ import annotations + +import logging +from dataclasses import dataclass +from enum import auto +from enum import Enum +from typing import Any +from typing import cast +from typing import List + +from tensorflow.lite.python import convert +from tensorflow.lite.python.metrics import converter_error_data_pb2 + +from mlia.nn.tensorflow.utils import get_tflite_converter +from mlia.utils.logging import redirect_raw_output + + +logger = logging.getLogger(__name__) + + +class TFLiteConversionErrorCode(Enum): + """TensorFlow Lite conversion error codes.""" + + NEEDS_FLEX_OPS = auto() + NEEDS_CUSTOM_OPS = auto() + UNSUPPORTED_CONTROL_FLOW_V1 = auto() + GPU_NOT_COMPATIBLE = auto() + UNKNOWN = auto() + + +@dataclass +class TFLiteConversionError: + """TensorFlow Lite conversion error details.""" + + message: str + code: TFLiteConversionErrorCode + operator: str + location: list[str] + + +@dataclass +class TFLiteCompatibilityInfo: + """TensorFlow Lite compatibility information.""" + + compatible: bool + conversion_exception: Exception | None = None + conversion_errors: list[TFLiteConversionError] | None = None + + def unsupported_ops_by_code(self, code: TFLiteConversionErrorCode) -> list[str]: + """Filter unsupported operators by error code.""" + if not self.conversion_errors: + return [] + + return [err.operator for err in self.conversion_errors if err.code == code] + + +class TFLiteChecker: + """Class for checking TensorFlow Lite compatibility.""" + + def __init__(self, quantized: bool = False) -> None: + """Init TensorFlow Lite checker.""" + self.quantized = quantized + + def check_compatibility(self, model: Any) -> TFLiteCompatibilityInfo: + """Check TensorFlow Lite compatibility for the provided model.""" + try: + logger.debug("Check TensorFlow Lite compatibility for %s", model) + converter = get_tflite_converter(model, quantized=self.quantized) + + # there is an issue with intercepting TensorFlow output + # not all output could be captured, for now just intercept + # stderr output + with redirect_raw_output( + logging.getLogger("tensorflow"), stdout_level=None + ): + converter.convert() + except convert.ConverterError as err: + return self._process_exception(err) + except Exception as err: # pylint: disable=broad-except + return TFLiteCompatibilityInfo(compatible=False, conversion_exception=err) + else: + return TFLiteCompatibilityInfo(compatible=True) + + def _process_exception( + self, err: convert.ConverterError + ) -> TFLiteCompatibilityInfo: + """Parse error details if possible.""" + conversion_errors = None + if hasattr(err, "errors"): + conversion_errors = [ + TFLiteConversionError( + message=error.error_message.splitlines()[0], + code=self._convert_error_code(error.error_code), + operator=error.operator.name, + location=cast( + List[str], + [loc.name for loc in error.location.call if loc.name] + if hasattr(error, "location") + else [], + ), + ) + for error in err.errors + ] + + return TFLiteCompatibilityInfo( + compatible=False, + conversion_exception=err, + conversion_errors=conversion_errors, + ) + + @staticmethod + def _convert_error_code(code: int) -> TFLiteConversionErrorCode: + """Convert internal error codes.""" + # pylint: disable=no-member + error_data = converter_error_data_pb2.ConverterErrorData + if code == error_data.ERROR_NEEDS_FLEX_OPS: + return TFLiteConversionErrorCode.NEEDS_FLEX_OPS + + if code == error_data.ERROR_NEEDS_CUSTOM_OPS: + return TFLiteConversionErrorCode.NEEDS_CUSTOM_OPS + + if code == error_data.ERROR_UNSUPPORTED_CONTROL_FLOW_V1: + return TFLiteConversionErrorCode.UNSUPPORTED_CONTROL_FLOW_V1 + + if code == converter_error_data_pb2.ConverterErrorData.ERROR_GPU_NOT_COMPATIBLE: + return TFLiteConversionErrorCode.GPU_NOT_COMPATIBLE + # pylint: enable=no-member + + return TFLiteConversionErrorCode.UNKNOWN |