diff options
Diffstat (limited to 'src/mlia/nn/tensorflow/tflite_convert.py')
-rw-r--r-- | src/mlia/nn/tensorflow/tflite_convert.py | 18 |
1 files changed, 9 insertions, 9 deletions
diff --git a/src/mlia/nn/tensorflow/tflite_convert.py b/src/mlia/nn/tensorflow/tflite_convert.py index d3a833a..29839d6 100644 --- a/src/mlia/nn/tensorflow/tflite_convert.py +++ b/src/mlia/nn/tensorflow/tflite_convert.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Support module to call TFLiteConverter.""" from __future__ import annotations @@ -14,6 +14,7 @@ from typing import Iterable import numpy as np import tensorflow as tf +from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 from mlia.nn.tensorflow.utils import get_tf_tensor_shape from mlia.nn.tensorflow.utils import is_keras_model @@ -23,6 +24,7 @@ from mlia.utils.logging import redirect_output from mlia.utils.proc import Command from mlia.utils.proc import command_output + logger = logging.getLogger(__name__) @@ -40,21 +42,21 @@ def representative_dataset( def get_tflite_converter( - model: tf.keras.Model | str | Path, quantized: bool = False + model: keras.Model | str | Path, quantized: bool = False ) -> tf.lite.TFLiteConverter: """Configure TensorFlow Lite converter for the provided model.""" if isinstance(model, (str, Path)): # converter's methods accept string as input parameter model = str(model) - if isinstance(model, tf.keras.Model): + if isinstance(model, keras.Model): converter = tf.lite.TFLiteConverter.from_keras_model(model) input_shape = model.input_shape elif isinstance(model, str) and is_saved_model(model): converter = tf.lite.TFLiteConverter.from_saved_model(model) input_shape = get_tf_tensor_shape(model) elif isinstance(model, str) and is_keras_model(model): - keras_model = tf.keras.models.load_model(model) + keras_model = keras.models.load_model(model) input_shape = keras_model.input_shape converter = tf.lite.TFLiteConverter.from_keras_model(keras_model) else: @@ -70,9 +72,7 @@ def get_tflite_converter( return converter -def convert_to_tflite_bytes( - model: tf.keras.Model | str, quantized: bool = False -) -> bytes: +def convert_to_tflite_bytes(model: keras.Model | str, quantized: bool = False) -> bytes: """Convert Keras model to TensorFlow Lite.""" converter = get_tflite_converter(model, quantized) @@ -83,7 +83,7 @@ def convert_to_tflite_bytes( def _convert_to_tflite( - model: tf.keras.Model | str, + model: keras.Model | str, quantized: bool = False, output_path: Path | None = None, ) -> bytes: @@ -97,7 +97,7 @@ def _convert_to_tflite( def convert_to_tflite( - model: tf.keras.Model | str, + model: keras.Model | str, quantized: bool = False, output_path: Path | None = None, input_path: Path | None = None, |