From 58a65fee574c00329cf92b387a6d2513dcbf6100 Mon Sep 17 00:00:00 2001
From: Dmitrii Agibov
Date: Mon, 24 Oct 2022 15:08:08 +0100
Subject: MLIA-433 Add TensorFlow Lite compatibility check

- Add ability to intercept low-level TensorFlow output
- Produce advice for the models that could not be converted to the
  TensorFlow Lite format
- Refactor utility functions for TensorFlow Lite conversion
- Add TensorFlow Lite compatibility checker

Change-Id: I47d120d2619ced7b143bc92c5184515b81c0220d
---
 src/mlia/nn/tensorflow/config.py        |  14 ++-
 src/mlia/nn/tensorflow/tflite_compat.py | 132 ++++++++++++++++++++++++++
 src/mlia/nn/tensorflow/utils.py         | 159 ++++++++++++++------------------
 3 files changed, 207 insertions(+), 98 deletions(-)
 create mode 100644 src/mlia/nn/tensorflow/tflite_compat.py

(limited to 'src/mlia/nn')

diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py
index 03d1d0f..0c3133a 100644
--- a/src/mlia/nn/tensorflow/config.py
+++ b/src/mlia/nn/tensorflow/config.py
@@ -11,12 +11,12 @@ from typing import List
 import tensorflow as tf
 
 from mlia.core.context import Context
-from mlia.nn.tensorflow.utils import convert_tf_to_tflite
 from mlia.nn.tensorflow.utils import convert_to_tflite
 from mlia.nn.tensorflow.utils import is_keras_model
-from mlia.nn.tensorflow.utils import is_tf_model
+from mlia.nn.tensorflow.utils import is_saved_model
 from mlia.nn.tensorflow.utils import is_tflite_model
 from mlia.nn.tensorflow.utils import save_tflite_model
+from mlia.utils.logging import log_action
 
 logger = logging.getLogger(__name__)
 
@@ -53,10 +53,8 @@ class KerasModel(ModelConfiguration):
         self, tflite_model_path: str | Path, quantized: bool = False
     ) -> TFLiteModel:
         """Convert model to TensorFlow Lite format."""
-        logger.info("Converting Keras to TensorFlow Lite ...")
-
-        converted_model = convert_to_tflite(self.get_keras_model(), quantized)
-        logger.info("Done\n")
+        with log_action("Converting Keras to TensorFlow Lite ..."):
+            converted_model = convert_to_tflite(self.get_keras_model(), quantized)
 
         save_tflite_model(converted_model, tflite_model_path)
         logger.debug(
@@ -95,7 +93,7 @@ class TfModel(ModelConfiguration):  # pylint: disable=abstract-method
         self, tflite_model_path: str | Path, quantized: bool = False
     ) -> TFLiteModel:
         """Convert model to TensorFlow Lite format."""
-        converted_model = convert_tf_to_tflite(self.model_path, quantized)
+        converted_model = convert_to_tflite(self.model_path, quantized)
         save_tflite_model(converted_model, tflite_model_path)
 
         return TFLiteModel(tflite_model_path)
@@ -109,7 +107,7 @@ def get_model(model: str | Path) -> ModelConfiguration:
     if is_keras_model(model):
         return KerasModel(model)
 
-    if is_tf_model(model):
+    if is_saved_model(model):
         return TfModel(model)
 
     raise Exception(
diff --git a/src/mlia/nn/tensorflow/tflite_compat.py b/src/mlia/nn/tensorflow/tflite_compat.py
new file mode 100644
index 0000000..960a5c3
--- /dev/null
+++ b/src/mlia/nn/tensorflow/tflite_compat.py
@@ -0,0 +1,132 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Functions for checking TensorFlow Lite compatibility."""
+from __future__ import annotations
+
+import logging
+from dataclasses import dataclass
+from enum import auto
+from enum import Enum
+from typing import Any
+from typing import cast
+from typing import List
+
+from tensorflow.lite.python import convert
+from tensorflow.lite.python.metrics import converter_error_data_pb2
+
+from mlia.nn.tensorflow.utils import get_tflite_converter
+from mlia.utils.logging import redirect_raw_output
+
+
+logger = logging.getLogger(__name__)
+
+
+class TFLiteConversionErrorCode(Enum):
+    """TensorFlow Lite conversion error codes."""
+
+    NEEDS_FLEX_OPS = auto()
+    NEEDS_CUSTOM_OPS = auto()
+    UNSUPPORTED_CONTROL_FLOW_V1 = auto()
+    GPU_NOT_COMPATIBLE = auto()
+    UNKNOWN = auto()
+
+
+@dataclass
+class TFLiteConversionError:
+    """TensorFlow Lite conversion error details."""
+
+    message: str
+    code: TFLiteConversionErrorCode
+    operator: str
+    location: list[str]
+
+
+@dataclass
+class TFLiteCompatibilityInfo:
+    """TensorFlow Lite compatibility information."""
+
+    compatible: bool
+    conversion_exception: Exception | None = None
+    conversion_errors: list[TFLiteConversionError] | None = None
+
+    def unsupported_ops_by_code(self, code: TFLiteConversionErrorCode) -> list[str]:
+        """Filter unsupported operators by error code."""
+        if not self.conversion_errors:
+            return []
+
+        return [err.operator for err in self.conversion_errors if err.code == code]
+
+
+class TFLiteChecker:
+    """Class for checking TensorFlow Lite compatibility."""
+
+    def __init__(self, quantized: bool = False) -> None:
+        """Init TensorFlow Lite checker."""
+        self.quantized = quantized
+
+    def check_compatibility(self, model: Any) -> TFLiteCompatibilityInfo:
+        """Check TensorFlow Lite compatibility for the provided model."""
+        try:
+            logger.debug("Check TensorFlow Lite compatibility for %s", model)
+            converter = get_tflite_converter(model, quantized=self.quantized)
+
+            # There is an issue with intercepting TensorFlow output:
+            # not all output can be captured, so for now just intercept
+            # the stderr output.
+            with redirect_raw_output(
+                logging.getLogger("tensorflow"), stdout_level=None
+            ):
+                converter.convert()
+        except convert.ConverterError as err:
+            return self._process_exception(err)
+        except Exception as err:  # pylint: disable=broad-except
+            return TFLiteCompatibilityInfo(compatible=False, conversion_exception=err)
+        else:
+            return TFLiteCompatibilityInfo(compatible=True)
+
+    def _process_exception(
+        self, err: convert.ConverterError
+    ) -> TFLiteCompatibilityInfo:
+        """Parse error details if possible."""
+        conversion_errors = None
+        if hasattr(err, "errors"):
+            conversion_errors = [
+                TFLiteConversionError(
+                    message=error.error_message.splitlines()[0],
+                    code=self._convert_error_code(error.error_code),
+                    operator=error.operator.name,
+                    location=cast(
+                        List[str],
+                        [loc.name for loc in error.location.call if loc.name]
+                        if hasattr(error, "location")
+                        else [],
+                    ),
+                )
+                for error in err.errors
+            ]
+
+        return TFLiteCompatibilityInfo(
+            compatible=False,
+            conversion_exception=err,
+            conversion_errors=conversion_errors,
+        )
+
+    @staticmethod
+    def _convert_error_code(code: int) -> TFLiteConversionErrorCode:
+        """Convert internal error codes."""
+        # pylint: disable=no-member
+        error_data = converter_error_data_pb2.ConverterErrorData
+        if code == error_data.ERROR_NEEDS_FLEX_OPS:
+            return TFLiteConversionErrorCode.NEEDS_FLEX_OPS
+
+        if code == error_data.ERROR_NEEDS_CUSTOM_OPS:
+            return TFLiteConversionErrorCode.NEEDS_CUSTOM_OPS
+
+        if code == error_data.ERROR_UNSUPPORTED_CONTROL_FLOW_V1:
+            return TFLiteConversionErrorCode.UNSUPPORTED_CONTROL_FLOW_V1
+
+        if code == error_data.ERROR_GPU_NOT_COMPATIBLE:
+            return TFLiteConversionErrorCode.GPU_NOT_COMPATIBLE
+        # pylint: enable=no-member
+
+        return TFLiteConversionErrorCode.UNKNOWN
diff --git a/src/mlia/nn/tensorflow/utils.py b/src/mlia/nn/tensorflow/utils.py
index 7970329..287e6ff 100644
--- a/src/mlia/nn/tensorflow/utils.py
+++ b/src/mlia/nn/tensorflow/utils.py
@@ -6,143 +6,122 @@ from __future__ import annotations
 import logging
 from pathlib import Path
+from typing import Any
 from typing import Callable
+from typing import cast
 from typing import Iterable
 
 import numpy as np
 import tensorflow as tf
-from tensorflow.lite.python.interpreter import Interpreter
 
 from mlia.utils.logging import redirect_output
 
 
-def representative_dataset(model: tf.keras.Model) -> Callable:
+def representative_dataset(
+    input_shape: Any, sample_count: int = 100, input_dtype: type = np.float32
+) -> Callable:
     """Sample dataset used for quantization."""
-    input_shape = model.input_shape
+    if input_shape[0] != 1:
+        raise Exception("Only the input batch_size=1 is supported!")
 
     def dataset() -> Iterable:
-        for _ in range(100):
-            if input_shape[0] != 1:
-                raise Exception("Only the input batch_size=1 is supported!")
+        for _ in range(sample_count):
             data = np.random.rand(*input_shape)
-            yield [data.astype(np.float32)]
+            yield [data.astype(input_dtype)]
 
     return dataset
 
 
 def get_tf_tensor_shape(model: str) -> list:
     """Get input shape for the TensorFlow tensor model."""
-    # Loading the model
     loaded = tf.saved_model.load(model)
-    # The model signature must have 'serving_default' as a key
-    if "serving_default" not in loaded.signatures.keys():
-        raise Exception(
-            "Unsupported TensorFlow model signature, must have 'serving_default'"
-        )
-    # Get the signature inputs
-    inputs_tensor_info = loaded.signatures["serving_default"].inputs
-    dims = []
-    # Build a list of all inputs shape sizes
-    for input_key in inputs_tensor_info:
-        if input_key.get_shape():
-            dims.extend(list(input_key.get_shape()))
-    return dims
-
-
-def representative_tf_dataset(model: str) -> Callable:
-    """Sample dataset used for quantization."""
-    if not (input_shape := get_tf_tensor_shape(model)):
-        raise Exception("Unable to get input shape")
 
-    def dataset() -> Iterable:
-        for _ in range(100):
-            data = np.random.rand(*input_shape)
-            yield [data.astype(np.float32)]
+    try:
+        default_signature_key = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
+        default_signature = loaded.signatures[default_signature_key]
+        inputs_tensor_info = default_signature.inputs
+    except KeyError as err:
+        raise Exception(f"Signature '{default_signature_key}' not found") from err
 
-    return dataset
+    return [
+        dim
+        for input_key in inputs_tensor_info
+        if (shape := input_key.get_shape())
+        for dim in shape
+    ]
 
 
-def convert_to_tflite(model: tf.keras.Model, quantized: bool = False) -> Interpreter:
+def convert_to_tflite(model: tf.keras.Model | str, quantized: bool = False) -> bytes:
     """Convert Keras model to TensorFlow Lite."""
-    if not isinstance(model, tf.keras.Model):
-        raise Exception("Invalid model type")
-
-    converter = tf.lite.TFLiteConverter.from_keras_model(model)
-
-    if quantized:
-        converter.optimizations = [tf.lite.Optimize.DEFAULT]
-        converter.representative_dataset = representative_dataset(model)
-        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
-        converter.inference_input_type = tf.int8
-        converter.inference_output_type = tf.int8
+    converter = get_tflite_converter(model, quantized)
 
     with redirect_output(logging.getLogger("tensorflow")):
-        tflite_model = converter.convert()
-
-    return tflite_model
-
-
-def convert_tf_to_tflite(model: str, quantized: bool = False) -> Interpreter:
-    """Convert TensorFlow model to TensorFlow Lite."""
-    if not isinstance(model, str):
-        raise Exception("Invalid model type")
-
-    converter = tf.lite.TFLiteConverter.from_saved_model(model)
+        return cast(bytes, converter.convert())
 
-    if quantized:
-        converter.optimizations = [tf.lite.Optimize.DEFAULT]
-        converter.representative_dataset = representative_tf_dataset(model)
-        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
-        converter.inference_input_type = tf.int8
-        converter.inference_output_type = tf.int8
-    with redirect_output(logging.getLogger("tensorflow")):
-        tflite_model = converter.convert()
-
-    return tflite_model
-
-
-def save_keras_model(model: tf.keras.Model, save_path: str | Path) -> None:
+def save_keras_model(
+    model: tf.keras.Model, save_path: str | Path, include_optimizer: bool = True
+) -> None:
     """Save Keras model at provided path."""
-    # Checkpoint: saving the optimizer is necessary.
-    model.save(save_path, include_optimizer=True)
+    model.save(save_path, include_optimizer=include_optimizer)
 
 
-def save_tflite_model(model: tf.lite.TFLiteConverter, save_path: str | Path) -> None:
+def save_tflite_model(tflite_model: bytes, save_path: str | Path) -> None:
     """Save TensorFlow Lite model at provided path."""
     with open(save_path, "wb") as file:
-        file.write(model)
+        file.write(tflite_model)
 
 
 def is_tflite_model(model: str | Path) -> bool:
-    """Check if model type is supported by TensorFlow Lite API.
-
-    TensorFlow Lite model is indicated by the model file extension .tflite
-    """
+    """Check if path contains TensorFlow Lite model."""
     model_path = Path(model)
+
     return model_path.suffix == ".tflite"
 
 
 def is_keras_model(model: str | Path) -> bool:
-    """Check if model type is supported by Keras API.
-
-    Keras model is indicated by:
-    1. if it's a directory (meaning saved model),
-    it should contain keras_metadata.pb file
-    2. or if the model file extension is .h5/.hdf5
-    """
+    """Check if path contains a Keras model."""
     model_path = Path(model)
 
     if model_path.is_dir():
-        return (model_path / "keras_metadata.pb").exists()
-    return model_path.suffix in (".h5", ".hdf5")
+        return model_path.joinpath("keras_metadata.pb").exists()
+
+    return model_path.suffix in (".h5", ".hdf5")
 
 
-def is_tf_model(model: str | Path) -> bool:
-    """Check if model type is supported by TensorFlow API.
-
-    TensorFlow model is indicated if its directory (meaning saved model)
-    doesn't contain keras_metadata.pb file
-    """
+def is_saved_model(model: str | Path) -> bool:
+    """Check if path contains SavedModel model."""
     model_path = Path(model)
+
+    return model_path.is_dir() and not is_keras_model(model)
+
+
+def get_tflite_converter(
+    model: tf.keras.Model | str | Path, quantized: bool = False
+) -> tf.lite.TFLiteConverter:
+    """Configure TensorFlow Lite converter for the provided model."""
+    if isinstance(model, (str, Path)):
+        # converter's methods accept string as input parameter
+        model = str(model)
+
+    if isinstance(model, tf.keras.Model):
+        converter = tf.lite.TFLiteConverter.from_keras_model(model)
+        input_shape = model.input_shape
+    elif isinstance(model, str) and is_saved_model(model):
+        converter = tf.lite.TFLiteConverter.from_saved_model(model)
+        input_shape = get_tf_tensor_shape(model)
+    elif isinstance(model, str) and is_keras_model(model):
+        keras_model = tf.keras.models.load_model(model)
+        input_shape = keras_model.input_shape
+        converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
+    else:
+        raise ValueError(f"Unable to create TensorFlow Lite converter for {model}")
+
+    if quantized:
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset(input_shape)
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+
+    return converter
--
cgit v1.2.1
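
A minimal usage sketch of the checker added by this patch, for anyone trying the change locally. The model path below is hypothetical; any input accepted by get_tflite_converter() (a tf.keras.Model instance, a SavedModel directory, or an .h5/.hdf5 file) would work the same way:

    from mlia.nn.tensorflow.tflite_compat import TFLiteChecker
    from mlia.nn.tensorflow.tflite_compat import TFLiteConversionErrorCode

    # Check whether the model converts to TensorFlow Lite (optionally quantized).
    checker = TFLiteChecker(quantized=False)
    compat_info = checker.check_compatibility("models/my_saved_model")  # hypothetical path

    if compat_info.compatible:
        print("Model is TensorFlow Lite compatible")
    else:
        # Operators that would require the Flex (select TF ops) delegate
        flex_ops = compat_info.unsupported_ops_by_code(
            TFLiteConversionErrorCode.NEEDS_FLEX_OPS
        )
        print(f"Operators requiring Flex ops: {flex_ops}")

When conversion fails with a ConverterError, check_compatibility() returns the parsed per-operator errors in conversion_errors; any other exception is preserved in conversion_exception so the caller can still produce advice.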