path: root/src/mlia/nn
author    Dmitrii Agibov <dmitrii.agibov@arm.com>  2022-10-24 15:08:08 +0100
committer Dmitrii Agibov <dmitrii.agibov@arm.com>  2022-10-26 17:08:13 +0100
commit    58a65fee574c00329cf92b387a6d2513dcbf6100 (patch)
tree      47e3185f78b4298ab029785ddee68456e44cac10 /src/mlia/nn
parent    9d34cb72d45a6d0a2ec1063ebf32536c1efdba75 (diff)
download  mlia-58a65fee574c00329cf92b387a6d2513dcbf6100.tar.gz
MLIA-433 Add TensorFlow Lite compatibility check
- Add ability to intercept low-level TensorFlow output
- Produce advice for models that could not be converted to the TensorFlow Lite format
- Refactor utility functions for TensorFlow Lite conversion
- Add TensorFlow Lite compatibility checker

Change-Id: I47d120d2619ced7b143bc92c5184515b81c0220d
Diffstat (limited to 'src/mlia/nn')
-rw-r--r--  src/mlia/nn/tensorflow/config.py          14
-rw-r--r--  src/mlia/nn/tensorflow/tflite_compat.py  132
-rw-r--r--  src/mlia/nn/tensorflow/utils.py          159
3 files changed, 207 insertions, 98 deletions
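
Before the diff itself, a minimal usage sketch of the TFLiteChecker API this commit introduces. The toy Keras model and the call site are illustrative assumptions and not part of the commit; only TFLiteChecker, TFLiteCompatibilityInfo and TFLiteConversionErrorCode come from the new tflite_compat.py below.

import tensorflow as tf

from mlia.nn.tensorflow.tflite_compat import TFLiteChecker
from mlia.nn.tensorflow.tflite_compat import TFLiteConversionErrorCode

# Hypothetical model used only to exercise the checker.
model = tf.keras.Sequential(
    [
        tf.keras.layers.InputLayer(input_shape=(28, 28)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10),
    ]
)

checker = TFLiteChecker(quantized=False)
info = checker.check_compatibility(model)

if info.compatible:
    print("Model converts cleanly to TensorFlow Lite")
else:
    # Operators that would need TensorFlow (Flex) fallback kernels.
    flex_ops = info.unsupported_ops_by_code(TFLiteConversionErrorCode.NEEDS_FLEX_OPS)
    print("Operators requiring Flex ops:", flex_ops)
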
diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py
index 03d1d0f..0c3133a 100644
--- a/src/mlia/nn/tensorflow/config.py
+++ b/src/mlia/nn/tensorflow/config.py
@@ -11,12 +11,12 @@ from typing import List
import tensorflow as tf
from mlia.core.context import Context
-from mlia.nn.tensorflow.utils import convert_tf_to_tflite
from mlia.nn.tensorflow.utils import convert_to_tflite
from mlia.nn.tensorflow.utils import is_keras_model
-from mlia.nn.tensorflow.utils import is_tf_model
+from mlia.nn.tensorflow.utils import is_saved_model
from mlia.nn.tensorflow.utils import is_tflite_model
from mlia.nn.tensorflow.utils import save_tflite_model
+from mlia.utils.logging import log_action
logger = logging.getLogger(__name__)
@@ -53,10 +53,8 @@ class KerasModel(ModelConfiguration):
self, tflite_model_path: str | Path, quantized: bool = False
) -> TFLiteModel:
"""Convert model to TensorFlow Lite format."""
- logger.info("Converting Keras to TensorFlow Lite ...")
-
- converted_model = convert_to_tflite(self.get_keras_model(), quantized)
- logger.info("Done\n")
+ with log_action("Converting Keras to TensorFlow Lite ..."):
+ converted_model = convert_to_tflite(self.get_keras_model(), quantized)
save_tflite_model(converted_model, tflite_model_path)
logger.debug(
@@ -95,7 +93,7 @@ class TfModel(ModelConfiguration): # pylint: disable=abstract-method
self, tflite_model_path: str | Path, quantized: bool = False
) -> TFLiteModel:
"""Convert model to TensorFlow Lite format."""
- converted_model = convert_tf_to_tflite(self.model_path, quantized)
+ converted_model = convert_to_tflite(self.model_path, quantized)
save_tflite_model(converted_model, tflite_model_path)
return TFLiteModel(tflite_model_path)
@@ -109,7 +107,7 @@ def get_model(model: str | Path) -> ModelConfiguration:
if is_keras_model(model):
return KerasModel(model)
- if is_tf_model(model):
+ if is_saved_model(model):
return TfModel(model)
raise Exception(
diff --git a/src/mlia/nn/tensorflow/tflite_compat.py b/src/mlia/nn/tensorflow/tflite_compat.py
new file mode 100644
index 0000000..960a5c3
--- /dev/null
+++ b/src/mlia/nn/tensorflow/tflite_compat.py
@@ -0,0 +1,132 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Functions for checking TensorFlow Lite compatibility."""
+from __future__ import annotations
+
+import logging
+from dataclasses import dataclass
+from enum import auto
+from enum import Enum
+from typing import Any
+from typing import cast
+from typing import List
+
+from tensorflow.lite.python import convert
+from tensorflow.lite.python.metrics import converter_error_data_pb2
+
+from mlia.nn.tensorflow.utils import get_tflite_converter
+from mlia.utils.logging import redirect_raw_output
+
+
+logger = logging.getLogger(__name__)
+
+
+class TFLiteConversionErrorCode(Enum):
+ """TensorFlow Lite conversion error codes."""
+
+ NEEDS_FLEX_OPS = auto()
+ NEEDS_CUSTOM_OPS = auto()
+ UNSUPPORTED_CONTROL_FLOW_V1 = auto()
+ GPU_NOT_COMPATIBLE = auto()
+ UNKNOWN = auto()
+
+
+@dataclass
+class TFLiteConversionError:
+ """TensorFlow Lite conversion error details."""
+
+ message: str
+ code: TFLiteConversionErrorCode
+ operator: str
+ location: list[str]
+
+
+@dataclass
+class TFLiteCompatibilityInfo:
+ """TensorFlow Lite compatibility information."""
+
+ compatible: bool
+ conversion_exception: Exception | None = None
+ conversion_errors: list[TFLiteConversionError] | None = None
+
+ def unsupported_ops_by_code(self, code: TFLiteConversionErrorCode) -> list[str]:
+ """Filter unsupported operators by error code."""
+ if not self.conversion_errors:
+ return []
+
+ return [err.operator for err in self.conversion_errors if err.code == code]
+
+
+class TFLiteChecker:
+ """Class for checking TensorFlow Lite compatibility."""
+
+ def __init__(self, quantized: bool = False) -> None:
+ """Init TensorFlow Lite checker."""
+ self.quantized = quantized
+
+ def check_compatibility(self, model: Any) -> TFLiteCompatibilityInfo:
+ """Check TensorFlow Lite compatibility for the provided model."""
+ try:
+ logger.debug("Check TensorFlow Lite compatibility for %s", model)
+ converter = get_tflite_converter(model, quantized=self.quantized)
+
+ # there is an issue with intercepting TensorFlow output
+ # not all output could be captured, for now just intercept
+ # stderr output
+ with redirect_raw_output(
+ logging.getLogger("tensorflow"), stdout_level=None
+ ):
+ converter.convert()
+ except convert.ConverterError as err:
+ return self._process_exception(err)
+ except Exception as err: # pylint: disable=broad-except
+ return TFLiteCompatibilityInfo(compatible=False, conversion_exception=err)
+ else:
+ return TFLiteCompatibilityInfo(compatible=True)
+
+ def _process_exception(
+ self, err: convert.ConverterError
+ ) -> TFLiteCompatibilityInfo:
+ """Parse error details if possible."""
+ conversion_errors = None
+ if hasattr(err, "errors"):
+ conversion_errors = [
+ TFLiteConversionError(
+ message=error.error_message.splitlines()[0],
+ code=self._convert_error_code(error.error_code),
+ operator=error.operator.name,
+ location=cast(
+ List[str],
+ [loc.name for loc in error.location.call if loc.name]
+ if hasattr(error, "location")
+ else [],
+ ),
+ )
+ for error in err.errors
+ ]
+
+ return TFLiteCompatibilityInfo(
+ compatible=False,
+ conversion_exception=err,
+ conversion_errors=conversion_errors,
+ )
+
+ @staticmethod
+ def _convert_error_code(code: int) -> TFLiteConversionErrorCode:
+ """Convert internal error codes."""
+ # pylint: disable=no-member
+ error_data = converter_error_data_pb2.ConverterErrorData
+ if code == error_data.ERROR_NEEDS_FLEX_OPS:
+ return TFLiteConversionErrorCode.NEEDS_FLEX_OPS
+
+ if code == error_data.ERROR_NEEDS_CUSTOM_OPS:
+ return TFLiteConversionErrorCode.NEEDS_CUSTOM_OPS
+
+ if code == error_data.ERROR_UNSUPPORTED_CONTROL_FLOW_V1:
+ return TFLiteConversionErrorCode.UNSUPPORTED_CONTROL_FLOW_V1
+
+ if code == converter_error_data_pb2.ConverterErrorData.ERROR_GPU_NOT_COMPATIBLE:
+ return TFLiteConversionErrorCode.GPU_NOT_COMPATIBLE
+ # pylint: enable=no-member
+
+ return TFLiteConversionErrorCode.UNKNOWN
diff --git a/src/mlia/nn/tensorflow/utils.py b/src/mlia/nn/tensorflow/utils.py
index 7970329..287e6ff 100644
--- a/src/mlia/nn/tensorflow/utils.py
+++ b/src/mlia/nn/tensorflow/utils.py
@@ -6,143 +6,122 @@ from __future__ import annotations
import logging
from pathlib import Path
+from typing import Any
from typing import Callable
+from typing import cast
from typing import Iterable
import numpy as np
import tensorflow as tf
-from tensorflow.lite.python.interpreter import Interpreter
from mlia.utils.logging import redirect_output
-def representative_dataset(model: tf.keras.Model) -> Callable:
+def representative_dataset(
+ input_shape: Any, sample_count: int = 100, input_dtype: type = np.float32
+) -> Callable:
"""Sample dataset used for quantization."""
- input_shape = model.input_shape
+ if input_shape[0] != 1:
+ raise Exception("Only the input batch_size=1 is supported!")
def dataset() -> Iterable:
- for _ in range(100):
- if input_shape[0] != 1:
- raise Exception("Only the input batch_size=1 is supported!")
+ for _ in range(sample_count):
data = np.random.rand(*input_shape)
- yield [data.astype(np.float32)]
+ yield [data.astype(input_dtype)]
return dataset
def get_tf_tensor_shape(model: str) -> list:
"""Get input shape for the TensorFlow tensor model."""
- # Loading the model
loaded = tf.saved_model.load(model)
- # The model signature must have 'serving_default' as a key
- if "serving_default" not in loaded.signatures.keys():
- raise Exception(
- "Unsupported TensorFlow model signature, must have 'serving_default'"
- )
- # Get the signature inputs
- inputs_tensor_info = loaded.signatures["serving_default"].inputs
- dims = []
- # Build a list of all inputs shape sizes
- for input_key in inputs_tensor_info:
- if input_key.get_shape():
- dims.extend(list(input_key.get_shape()))
- return dims
-
-
-def representative_tf_dataset(model: str) -> Callable:
- """Sample dataset used for quantization."""
- if not (input_shape := get_tf_tensor_shape(model)):
- raise Exception("Unable to get input shape")
- def dataset() -> Iterable:
- for _ in range(100):
- data = np.random.rand(*input_shape)
- yield [data.astype(np.float32)]
+ try:
+ default_signature_key = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
+ default_signature = loaded.signatures[default_signature_key]
+ inputs_tensor_info = default_signature.inputs
+ except KeyError as err:
+ raise Exception(f"Signature '{default_signature_key}' not found") from err
- return dataset
+ return [
+ dim
+ for input_key in inputs_tensor_info
+ if (shape := input_key.get_shape())
+ for dim in shape
+ ]
-def convert_to_tflite(model: tf.keras.Model, quantized: bool = False) -> Interpreter:
+def convert_to_tflite(model: tf.keras.Model | str, quantized: bool = False) -> bytes:
"""Convert Keras model to TensorFlow Lite."""
- if not isinstance(model, tf.keras.Model):
- raise Exception("Invalid model type")
-
- converter = tf.lite.TFLiteConverter.from_keras_model(model)
-
- if quantized:
- converter.optimizations = [tf.lite.Optimize.DEFAULT]
- converter.representative_dataset = representative_dataset(model)
- converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
- converter.inference_input_type = tf.int8
- converter.inference_output_type = tf.int8
+ converter = get_tflite_converter(model, quantized)
with redirect_output(logging.getLogger("tensorflow")):
- tflite_model = converter.convert()
-
- return tflite_model
-
-
-def convert_tf_to_tflite(model: str, quantized: bool = False) -> Interpreter:
- """Convert TensorFlow model to TensorFlow Lite."""
- if not isinstance(model, str):
- raise Exception("Invalid model type")
-
- converter = tf.lite.TFLiteConverter.from_saved_model(model)
+ return cast(bytes, converter.convert())
- if quantized:
- converter.optimizations = [tf.lite.Optimize.DEFAULT]
- converter.representative_dataset = representative_tf_dataset(model)
- converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
- converter.inference_input_type = tf.int8
- converter.inference_output_type = tf.int8
- with redirect_output(logging.getLogger("tensorflow")):
- tflite_model = converter.convert()
-
- return tflite_model
-
-
-def save_keras_model(model: tf.keras.Model, save_path: str | Path) -> None:
+def save_keras_model(
+ model: tf.keras.Model, save_path: str | Path, include_optimizer: bool = True
+) -> None:
"""Save Keras model at provided path."""
- # Checkpoint: saving the optimizer is necessary.
- model.save(save_path, include_optimizer=True)
+ model.save(save_path, include_optimizer=include_optimizer)
-def save_tflite_model(model: tf.lite.TFLiteConverter, save_path: str | Path) -> None:
+def save_tflite_model(tflite_model: bytes, save_path: str | Path) -> None:
"""Save TensorFlow Lite model at provided path."""
with open(save_path, "wb") as file:
- file.write(model)
+ file.write(tflite_model)
def is_tflite_model(model: str | Path) -> bool:
- """Check if model type is supported by TensorFlow Lite API.
-
- TensorFlow Lite model is indicated by the model file extension .tflite
- """
+ """Check if path contains TensorFlow Lite model."""
model_path = Path(model)
+
return model_path.suffix == ".tflite"
def is_keras_model(model: str | Path) -> bool:
- """Check if model type is supported by Keras API.
-
- Keras model is indicated by:
- 1. if it's a directory (meaning saved model),
- it should contain keras_metadata.pb file
- 2. or if the model file extension is .h5/.hdf5
- """
+ """Check if path contains a Keras model."""
model_path = Path(model)
if model_path.is_dir():
- return (model_path / "keras_metadata.pb").exists()
- return model_path.suffix in (".h5", ".hdf5")
+ return model_path.joinpath("keras_metadata.pb").exists()
+ return model_path.suffix in (".h5", ".hdf5")
-def is_tf_model(model: str | Path) -> bool:
- """Check if model type is supported by TensorFlow API.
- TensorFlow model is indicated if its directory (meaning saved model)
- doesn't contain keras_metadata.pb file
- """
+def is_saved_model(model: str | Path) -> bool:
+ """Check if path contains SavedModel model."""
model_path = Path(model)
+
return model_path.is_dir() and not is_keras_model(model)
+
+
+def get_tflite_converter(
+ model: tf.keras.Model | str | Path, quantized: bool = False
+) -> tf.lite.TFLiteConverter:
+ """Configure TensorFlow Lite converter for the provided model."""
+ if isinstance(model, (str, Path)):
+ # converter's methods accept string as input parameter
+ model = str(model)
+
+ if isinstance(model, tf.keras.Model):
+ converter = tf.lite.TFLiteConverter.from_keras_model(model)
+ input_shape = model.input_shape
+ elif isinstance(model, str) and is_saved_model(model):
+ converter = tf.lite.TFLiteConverter.from_saved_model(model)
+ input_shape = get_tf_tensor_shape(model)
+ elif isinstance(model, str) and is_keras_model(model):
+ keras_model = tf.keras.models.load_model(model)
+ input_shape = keras_model.input_shape
+ converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
+ else:
+ raise ValueError(f"Unable to create TensorFlow Lite converter for {model}")
+
+ if quantized:
+ converter.optimizations = [tf.lite.Optimize.DEFAULT]
+ converter.representative_dataset = representative_dataset(input_shape)
+ converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
+ converter.inference_input_type = tf.int8
+ converter.inference_output_type = tf.int8
+
+ return converter
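
For completeness, a short sketch of the refactored conversion helpers in utils.py. The SavedModel path is a placeholder; convert_to_tflite() now returns the serialized model bytes rather than a converter object, and get_tflite_converter() selects the appropriate TFLiteConverter factory for the input type.

from mlia.nn.tensorflow.utils import convert_to_tflite
from mlia.nn.tensorflow.utils import save_tflite_model

# convert_to_tflite() accepts either a tf.keras.Model instance or a path
# to a SavedModel / Keras model. Passing quantized=True additionally
# enables int8 quantization with a random representative dataset, which
# requires the model's input batch size to be 1.
tflite_model = convert_to_tflite("path/to/saved_model", quantized=False)

# The result is plain bytes, written out verbatim.
save_tflite_model(tflite_model, "model.tflite")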