From 54eec806272b7574a0757c77a913a369a9ecdc70 Mon Sep 17 00:00:00 2001
From: Gergely Nagy
Date: Tue, 21 Nov 2023 12:29:38 +0000
Subject: MLIA-835 Invalid JSON output

TFLiteConverter was producing log messages in the output that could not
be captured and redirected to logging. The solution/workaround is to run
it as a subprocess. This change required some refactoring around
existing invocations of the converter.

Change-Id: I394bd0d49d36e6686cfcb9d658e4aad05326cb87
Signed-off-by: Gergely Nagy
---
 src/mlia/nn/tensorflow/utils.py | 59 -----------------------------------------
 1 file changed, 59 deletions(-)

(limited to 'src/mlia/nn/tensorflow/utils.py')

diff --git a/src/mlia/nn/tensorflow/utils.py b/src/mlia/nn/tensorflow/utils.py
index b8d45c6..1612447 100644
--- a/src/mlia/nn/tensorflow/utils.py
+++ b/src/mlia/nn/tensorflow/utils.py
@@ -4,31 +4,11 @@ """Collection of useful functions for optimizations."""
 from __future__ import annotations
 
-import logging
 from pathlib import Path
 from typing import Any
-from typing import Callable
-from typing import cast
-from typing import Iterable
 
-import numpy as np
 import tensorflow as tf
 
-from mlia.utils.logging import redirect_output
-
-
-def representative_dataset(
-    input_shape: Any, sample_count: int = 100, input_dtype: type = np.float32
-) -> Callable:
-    """Sample dataset used for quantization."""
-
-    def dataset() -> Iterable:
-        for _ in range(sample_count):
-            data = np.random.rand(1, *input_shape[1:])
-            yield [data.astype(input_dtype)]
-
-    return dataset
-
 
 def get_tf_tensor_shape(model: str) -> list:
     """Get input shape for the TensorFlow tensor model."""
@@ -49,14 +29,6 @@ def get_tf_tensor_shape(model: str) -> list:
     ]
 
 
-def convert_to_tflite(model: tf.keras.Model | str, quantized: bool = False) -> bytes:
-    """Convert Keras model to TensorFlow Lite."""
-    converter = get_tflite_converter(model, quantized)
-
-    with redirect_output(logging.getLogger("tensorflow")):
-        return cast(bytes, converter.convert())
-
-
 def save_keras_model(
     model: tf.keras.Model, save_path: str | Path, include_optimizer: bool = True
 ) -> None:
@@ -94,37 +66,6 @@ def is_saved_model(model: str | Path) -> bool:
     return model_path.is_dir() and not is_keras_model(model)
 
 
-def get_tflite_converter(
-    model: tf.keras.Model | str | Path, quantized: bool = False
-) -> tf.lite.TFLiteConverter:
-    """Configure TensorFlow Lite converter for the provided model."""
-    if isinstance(model, (str, Path)):
-        # converter's methods accept string as input parameter
-        model = str(model)
-
-    if isinstance(model, tf.keras.Model):
-        converter = tf.lite.TFLiteConverter.from_keras_model(model)
-        input_shape = model.input_shape
-    elif isinstance(model, str) and is_saved_model(model):
-        converter = tf.lite.TFLiteConverter.from_saved_model(model)
-        input_shape = get_tf_tensor_shape(model)
-    elif isinstance(model, str) and is_keras_model(model):
-        keras_model = tf.keras.models.load_model(model)
-        input_shape = keras_model.input_shape
-        converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
-    else:
-        raise ValueError(f"Unable to create TensorFlow Lite converter for {model}")
-
-    if quantized:
-        converter.optimizations = [tf.lite.Optimize.DEFAULT]
-        converter.representative_dataset = representative_dataset(input_shape)
-        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
-        converter.inference_input_type = tf.int8
-        converter.inference_output_type = tf.int8
-
-    return converter
-
-
 def get_tflite_model_type_map(model_filename: str | Path) -> dict[str, type]:
     """Get type map from tflite model."""
     model_type_map: dict[str, Any] = {}
--
cgit v1.2.1
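
Note: the diff above shows only the code removed from utils.py; the subprocess-based replacement described in the commit message lands in other files of this change and is not shown here. For illustration, the workaround works because TFLiteConverter emits some messages from native code directly to stdout/stderr, bypassing Python's logging (the deleted redirect_output approach), which corrupts machine-readable output such as JSON reports. Running the conversion in a child process lets the parent capture those messages from the child's pipes. The sketch below is a minimal hypothetical version of that idea, not the code added by this commit; the helper name convert_in_subprocess and the inline child script are invented for illustration.

from __future__ import annotations

import logging
import subprocess
import sys
from pathlib import Path

logger = logging.getLogger(__name__)


def convert_in_subprocess(saved_model_dir: Path, output_path: Path) -> None:
    """Convert a SavedModel to TFLite in a child process, capturing all output."""
    # Conversion script run by the child interpreter. With `python -c`,
    # sys.argv[0] is "-c" and the extra arguments follow as argv[1:].
    script = (
        "import sys, tensorflow as tf\n"
        "converter = tf.lite.TFLiteConverter.from_saved_model(sys.argv[1])\n"
        "tflite_model = converter.convert()\n"
        "open(sys.argv[2], 'wb').write(tflite_model)\n"
    )
    completed = subprocess.run(
        [sys.executable, "-c", script, str(saved_model_dir), str(output_path)],
        capture_output=True,
        text=True,
        check=False,
    )
    # Forward everything the converter printed -- including native-code log
    # lines that bypass Python logging -- into the logging framework, so the
    # parent's own stdout stays clean for machine-readable (JSON) output.
    for line in (completed.stdout + completed.stderr).splitlines():
        if line:
            logger.debug("TFLiteConverter: %s", line)
    if completed.returncode != 0:
        raise RuntimeError(
            f"TFLite conversion failed with exit code {completed.returncode}"
        )

A call such as convert_in_subprocess(Path("my_saved_model"), Path("model.tflite")) then produces the .tflite file while every converter message ends up in the debug log rather than interleaved with the tool's JSON output.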