path: root/src/mlia/nn/tensorflow/utils.py
author    Dmitrii Agibov <dmitrii.agibov@arm.com>    2022-10-24 15:08:08 +0100
committer Dmitrii Agibov <dmitrii.agibov@arm.com>    2022-10-26 17:08:13 +0100
commit    58a65fee574c00329cf92b387a6d2513dcbf6100 (patch)
tree      47e3185f78b4298ab029785ddee68456e44cac10 /src/mlia/nn/tensorflow/utils.py
parent    9d34cb72d45a6d0a2ec1063ebf32536c1efdba75 (diff)
download  mlia-58a65fee574c00329cf92b387a6d2513dcbf6100.tar.gz
MLIA-433 Add TensorFlow Lite compatibility check
- Add ability to intercept low-level TensorFlow output
- Produce advice for models that could not be converted to the TensorFlow Lite format
- Refactor utility functions for TensorFlow Lite conversion
- Add TensorFlow Lite compatibility checker

Change-Id: I47d120d2619ced7b143bc92c5184515b81c0220d
Diffstat (limited to 'src/mlia/nn/tensorflow/utils.py')
-rw-r--r--    src/mlia/nn/tensorflow/utils.py    159
1 file changed, 69 insertions(+), 90 deletions(-)
diff --git a/src/mlia/nn/tensorflow/utils.py b/src/mlia/nn/tensorflow/utils.py
index 7970329..287e6ff 100644
--- a/src/mlia/nn/tensorflow/utils.py
+++ b/src/mlia/nn/tensorflow/utils.py
@@ -6,143 +6,122 @@ from __future__ import annotations
import logging
from pathlib import Path
+from typing import Any
from typing import Callable
+from typing import cast
from typing import Iterable
import numpy as np
import tensorflow as tf
-from tensorflow.lite.python.interpreter import Interpreter
from mlia.utils.logging import redirect_output
-def representative_dataset(model: tf.keras.Model) -> Callable:
+def representative_dataset(
+ input_shape: Any, sample_count: int = 100, input_dtype: type = np.float32
+) -> Callable:
"""Sample dataset used for quantization."""
- input_shape = model.input_shape
+ if input_shape[0] != 1:
+ raise Exception("Only the input batch_size=1 is supported!")
def dataset() -> Iterable:
- for _ in range(100):
- if input_shape[0] != 1:
- raise Exception("Only the input batch_size=1 is supported!")
+ for _ in range(sample_count):
data = np.random.rand(*input_shape)
- yield [data.astype(np.float32)]
+ yield [data.astype(input_dtype)]
return dataset
def get_tf_tensor_shape(model: str) -> list:
"""Get input shape for the TensorFlow tensor model."""
- # Loading the model
loaded = tf.saved_model.load(model)
- # The model signature must have 'serving_default' as a key
- if "serving_default" not in loaded.signatures.keys():
- raise Exception(
- "Unsupported TensorFlow model signature, must have 'serving_default'"
- )
- # Get the signature inputs
- inputs_tensor_info = loaded.signatures["serving_default"].inputs
- dims = []
- # Build a list of all inputs shape sizes
- for input_key in inputs_tensor_info:
- if input_key.get_shape():
- dims.extend(list(input_key.get_shape()))
- return dims
-
-
-def representative_tf_dataset(model: str) -> Callable:
- """Sample dataset used for quantization."""
- if not (input_shape := get_tf_tensor_shape(model)):
- raise Exception("Unable to get input shape")
- def dataset() -> Iterable:
- for _ in range(100):
- data = np.random.rand(*input_shape)
- yield [data.astype(np.float32)]
+ try:
+ default_signature_key = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
+ default_signature = loaded.signatures[default_signature_key]
+ inputs_tensor_info = default_signature.inputs
+ except KeyError as err:
+ raise Exception(f"Signature '{default_signature_key}' not found") from err
- return dataset
+ return [
+ dim
+ for input_key in inputs_tensor_info
+ if (shape := input_key.get_shape())
+ for dim in shape
+ ]
-def convert_to_tflite(model: tf.keras.Model, quantized: bool = False) -> Interpreter:
+def convert_to_tflite(model: tf.keras.Model | str, quantized: bool = False) -> bytes:
"""Convert Keras model to TensorFlow Lite."""
- if not isinstance(model, tf.keras.Model):
- raise Exception("Invalid model type")
-
- converter = tf.lite.TFLiteConverter.from_keras_model(model)
-
- if quantized:
- converter.optimizations = [tf.lite.Optimize.DEFAULT]
- converter.representative_dataset = representative_dataset(model)
- converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
- converter.inference_input_type = tf.int8
- converter.inference_output_type = tf.int8
+ converter = get_tflite_converter(model, quantized)
with redirect_output(logging.getLogger("tensorflow")):
- tflite_model = converter.convert()
-
- return tflite_model
-
-
-def convert_tf_to_tflite(model: str, quantized: bool = False) -> Interpreter:
- """Convert TensorFlow model to TensorFlow Lite."""
- if not isinstance(model, str):
- raise Exception("Invalid model type")
-
- converter = tf.lite.TFLiteConverter.from_saved_model(model)
+ return cast(bytes, converter.convert())
- if quantized:
- converter.optimizations = [tf.lite.Optimize.DEFAULT]
- converter.representative_dataset = representative_tf_dataset(model)
- converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
- converter.inference_input_type = tf.int8
- converter.inference_output_type = tf.int8
- with redirect_output(logging.getLogger("tensorflow")):
- tflite_model = converter.convert()
-
- return tflite_model
-
-
-def save_keras_model(model: tf.keras.Model, save_path: str | Path) -> None:
+def save_keras_model(
+ model: tf.keras.Model, save_path: str | Path, include_optimizer: bool = True
+) -> None:
"""Save Keras model at provided path."""
- # Checkpoint: saving the optimizer is necessary.
- model.save(save_path, include_optimizer=True)
+ model.save(save_path, include_optimizer=include_optimizer)
-def save_tflite_model(model: tf.lite.TFLiteConverter, save_path: str | Path) -> None:
+def save_tflite_model(tflite_model: bytes, save_path: str | Path) -> None:
"""Save TensorFlow Lite model at provided path."""
with open(save_path, "wb") as file:
- file.write(model)
+ file.write(tflite_model)
def is_tflite_model(model: str | Path) -> bool:
- """Check if model type is supported by TensorFlow Lite API.
-
- TensorFlow Lite model is indicated by the model file extension .tflite
- """
+ """Check if path contains TensorFlow Lite model."""
model_path = Path(model)
+
return model_path.suffix == ".tflite"
def is_keras_model(model: str | Path) -> bool:
- """Check if model type is supported by Keras API.
-
- Keras model is indicated by:
- 1. if it's a directory (meaning saved model),
- it should contain keras_metadata.pb file
- 2. or if the model file extension is .h5/.hdf5
- """
+ """Check if path contains a Keras model."""
model_path = Path(model)
if model_path.is_dir():
- return (model_path / "keras_metadata.pb").exists()
- return model_path.suffix in (".h5", ".hdf5")
+ return model_path.joinpath("keras_metadata.pb").exists()
+ return model_path.suffix in (".h5", ".hdf5")
-def is_tf_model(model: str | Path) -> bool:
- """Check if model type is supported by TensorFlow API.
- TensorFlow model is indicated if its directory (meaning saved model)
- doesn't contain keras_metadata.pb file
- """
+def is_saved_model(model: str | Path) -> bool:
+ """Check if path contains SavedModel model."""
model_path = Path(model)
+
return model_path.is_dir() and not is_keras_model(model)
+
+
+def get_tflite_converter(
+ model: tf.keras.Model | str | Path, quantized: bool = False
+) -> tf.lite.TFLiteConverter:
+ """Configure TensorFlow Lite converter for the provided model."""
+ if isinstance(model, (str, Path)):
+ # converter's methods accept string as input parameter
+ model = str(model)
+
+ if isinstance(model, tf.keras.Model):
+ converter = tf.lite.TFLiteConverter.from_keras_model(model)
+ input_shape = model.input_shape
+ elif isinstance(model, str) and is_saved_model(model):
+ converter = tf.lite.TFLiteConverter.from_saved_model(model)
+ input_shape = get_tf_tensor_shape(model)
+ elif isinstance(model, str) and is_keras_model(model):
+ keras_model = tf.keras.models.load_model(model)
+ input_shape = keras_model.input_shape
+ converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
+ else:
+ raise ValueError(f"Unable to create TensorFlow Lite converter for {model}")
+
+ if quantized:
+ converter.optimizations = [tf.lite.Optimize.DEFAULT]
+ converter.representative_dataset = representative_dataset(input_shape)
+ converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
+ converter.inference_input_type = tf.int8
+ converter.inference_output_type = tf.int8
+
+ return converter
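
For reference, a minimal usage sketch (not part of the commit) of the refactored helpers above: with this change, convert_to_tflite() returns the serialized TensorFlow Lite flatbuffer as bytes rather than an Interpreter, so its result can be written out directly with save_tflite_model(). The toy Keras model and the output path below are placeholders chosen for illustration; the fixed batch dimension reflects the batch_size=1 requirement enforced by representative_dataset().

import tensorflow as tf

from mlia.nn.tensorflow.utils import convert_to_tflite
from mlia.nn.tensorflow.utils import save_tflite_model

# Placeholder model for illustration only; the batch dimension is fixed to 1
# because representative_dataset() rejects any other batch size during
# post-training quantization.
model = tf.keras.Sequential(
    [
        tf.keras.layers.InputLayer(batch_input_shape=(1, 28, 28, 1)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10),
    ]
)

# convert_to_tflite() now returns the flatbuffer bytes, ready to be saved.
tflite_model = convert_to_tflite(model, quantized=True)
save_tflite_model(tflite_model, "model.tflite")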