Diffstat (limited to 'src/mlia/nn/tensorflow/utils.py')
-rw-r--r--  src/mlia/nn/tensorflow/utils.py  59
1 file changed, 0 insertions(+), 59 deletions(-)
diff --git a/src/mlia/nn/tensorflow/utils.py b/src/mlia/nn/tensorflow/utils.py
index b8d45c6..1612447 100644
--- a/src/mlia/nn/tensorflow/utils.py
+++ b/src/mlia/nn/tensorflow/utils.py
@@ -4,31 +4,11 @@
"""Collection of useful functions for optimizations."""
from __future__ import annotations
-import logging
from pathlib import Path
from typing import Any
-from typing import Callable
-from typing import cast
-from typing import Iterable
-import numpy as np
import tensorflow as tf
-from mlia.utils.logging import redirect_output
-
-
-def representative_dataset(
-    input_shape: Any, sample_count: int = 100, input_dtype: type = np.float32
-) -> Callable:
-    """Sample dataset used for quantization."""
-
-    def dataset() -> Iterable:
-        for _ in range(sample_count):
-            data = np.random.rand(1, *input_shape[1:])
-            yield [data.astype(input_dtype)]
-
-    return dataset
-
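Note: the removed representative_dataset helper produced the zero-argument
callable that tf.lite.TFLiteConverter expects on its representative_dataset
hook, yielding one list of input arrays per sample. A minimal, self-contained
sketch of equivalent usage (the toy model and input shape are illustrative
assumptions, not part of this change):

    import numpy as np
    import tensorflow as tf

    # Toy model, standing in for whatever model is being quantized.
    model = tf.keras.Sequential(
        [tf.keras.Input(shape=(8,)), tf.keras.layers.Dense(4)]
    )

    def dataset():
        # 100 random float32 samples, matching the removed defaults.
        for _ in range(100):
            yield [np.random.rand(1, 8).astype(np.float32)]

    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.representative_dataset = dataset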
def get_tf_tensor_shape(model: str) -> list:
"""Get input shape for the TensorFlow tensor model."""
@@ -49,14 +29,6 @@ def get_tf_tensor_shape(model: str) -> list:
    ]
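Note: get_tf_tensor_shape, which this change keeps, takes a model path and
returns the model's input dimensions as a flat list. An illustrative call
(the path is hypothetical):

    shape = get_tf_tensor_shape("/path/to/saved_model")  # hypothetical path
    # e.g. [1, 224, 224, 3] for a typical image model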
-def convert_to_tflite(model: tf.keras.Model | str, quantized: bool = False) -> bytes:
- """Convert Keras model to TensorFlow Lite."""
- converter = get_tflite_converter(model, quantized)
-
- with redirect_output(logging.getLogger("tensorflow")):
- return cast(bytes, converter.convert())
-
-
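Note: with convert_to_tflite gone, callers can reach the same result through
the converter API directly. A sketch, assuming `model` is a tf.keras.Model;
the redirection of TensorFlow's logger done by the removed wrapper is omitted:

    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    tflite_bytes = converter.convert()  # serialized .tflite flatbuffer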
def save_keras_model(
    model: tf.keras.Model, save_path: str | Path, include_optimizer: bool = True
) -> None:
@@ -94,37 +66,6 @@ def is_saved_model(model: str | Path) -> bool:
    return model_path.is_dir() and not is_keras_model(model)
-def get_tflite_converter(
-    model: tf.keras.Model | str | Path, quantized: bool = False
-) -> tf.lite.TFLiteConverter:
-    """Configure TensorFlow Lite converter for the provided model."""
-    if isinstance(model, (str, Path)):
-        # converter's methods accept string as input parameter
-        model = str(model)
-
-    if isinstance(model, tf.keras.Model):
-        converter = tf.lite.TFLiteConverter.from_keras_model(model)
-        input_shape = model.input_shape
-    elif isinstance(model, str) and is_saved_model(model):
-        converter = tf.lite.TFLiteConverter.from_saved_model(model)
-        input_shape = get_tf_tensor_shape(model)
-    elif isinstance(model, str) and is_keras_model(model):
-        keras_model = tf.keras.models.load_model(model)
-        input_shape = keras_model.input_shape
-        converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
-    else:
-        raise ValueError(f"Unable to create TensorFlow Lite converter for {model}")
-
-    if quantized:
-        converter.optimizations = [tf.lite.Optimize.DEFAULT]
-        converter.representative_dataset = representative_dataset(input_shape)
-        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
-        converter.inference_input_type = tf.int8
-        converter.inference_output_type = tf.int8
-
-    return converter
-
-
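Note: the removed get_tflite_converter configured standard full-integer
post-training quantization when quantized=True. A self-contained sketch of
the same converter settings (the toy model and representative data are
assumptions made for illustration):

    import numpy as np
    import tensorflow as tf

    model = tf.keras.Sequential(
        [tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(2)]
    )

    def rep_data():
        # Random calibration samples for the quantizer.
        for _ in range(100):
            yield [np.random.rand(1, 4).astype(np.float32)]

    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = rep_data
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
    converter.inference_input_type = tf.int8
    converter.inference_output_type = tf.int8
    tflite_model = converter.convert()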
def get_tflite_model_type_map(model_filename: str | Path) -> dict[str, type]:
"""Get type map from tflite model."""
model_type_map: dict[str, Any] = {}
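Note: the rest of get_tflite_model_type_map is cut off in this view. The
per-input dtype mapping it builds can also be read through the TFLite
interpreter; a sketch under that assumption, not necessarily the function's
actual implementation (model_filename is the parameter from the kept
signature):

    interpreter = tf.lite.Interpreter(model_path=str(model_filename))
    model_type_map = {
        detail["name"]: detail["dtype"]
        for detail in interpreter.get_input_details()
    }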