diff options
author | Nathan Bailey <nathan.bailey@arm.com> | 2024-03-20 08:13:39 +0000 |
---|---|---|
committer | Nathan Bailey <nathan.bailey@arm.com> | 2024-03-28 07:17:32 +0000 |
commit | f3f3ab451968350b8f6df2de7c60b2c2b9320b59 (patch) | |
tree | 05d56c8e41de9b32f8054019a21b78628151310d /src/mlia/nn/tensorflow/tflite_convert.py | |
parent | 5f063ae1cfbfa2568d2858af0a0ccaf192bb1e8d (diff) | |
download | mlia-f3f3ab451968350b8f6df2de7c60b2c2b9320b59.tar.gz |
feat: Update Vela version
Updates Vela Version to 3.11.0 and TensorFlow version to 2.15.1
Required keras import to change:
from keras.api._v2 import keras needed instead of calling tf.keras
Subsequently tf.keras.X needed to change to keras.X
Resolves: MLIA-1107
Signed-off-by: Nathan Bailey <nathan.bailey@arm.com>
Change-Id: I53bcaa9cdad58b0e6c311c8c6490393d33cb18bc
Diffstat (limited to 'src/mlia/nn/tensorflow/tflite_convert.py')
-rw-r--r-- | src/mlia/nn/tensorflow/tflite_convert.py | 18 |
1 files changed, 9 insertions, 9 deletions
diff --git a/src/mlia/nn/tensorflow/tflite_convert.py b/src/mlia/nn/tensorflow/tflite_convert.py index d3a833a..29839d6 100644 --- a/src/mlia/nn/tensorflow/tflite_convert.py +++ b/src/mlia/nn/tensorflow/tflite_convert.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Support module to call TFLiteConverter.""" from __future__ import annotations @@ -14,6 +14,7 @@ from typing import Iterable import numpy as np import tensorflow as tf +from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 from mlia.nn.tensorflow.utils import get_tf_tensor_shape from mlia.nn.tensorflow.utils import is_keras_model @@ -23,6 +24,7 @@ from mlia.utils.logging import redirect_output from mlia.utils.proc import Command from mlia.utils.proc import command_output + logger = logging.getLogger(__name__) @@ -40,21 +42,21 @@ def representative_dataset( def get_tflite_converter( - model: tf.keras.Model | str | Path, quantized: bool = False + model: keras.Model | str | Path, quantized: bool = False ) -> tf.lite.TFLiteConverter: """Configure TensorFlow Lite converter for the provided model.""" if isinstance(model, (str, Path)): # converter's methods accept string as input parameter model = str(model) - if isinstance(model, tf.keras.Model): + if isinstance(model, keras.Model): converter = tf.lite.TFLiteConverter.from_keras_model(model) input_shape = model.input_shape elif isinstance(model, str) and is_saved_model(model): converter = tf.lite.TFLiteConverter.from_saved_model(model) input_shape = get_tf_tensor_shape(model) elif isinstance(model, str) and is_keras_model(model): - keras_model = tf.keras.models.load_model(model) + keras_model = keras.models.load_model(model) input_shape = keras_model.input_shape converter = tf.lite.TFLiteConverter.from_keras_model(keras_model) else: @@ -70,9 +72,7 @@ def get_tflite_converter( return converter -def convert_to_tflite_bytes( - model: tf.keras.Model | str, quantized: bool = False -) -> bytes: +def convert_to_tflite_bytes(model: keras.Model | str, quantized: bool = False) -> bytes: """Convert Keras model to TensorFlow Lite.""" converter = get_tflite_converter(model, quantized) @@ -83,7 +83,7 @@ def convert_to_tflite_bytes( def _convert_to_tflite( - model: tf.keras.Model | str, + model: keras.Model | str, quantized: bool = False, output_path: Path | None = None, ) -> bytes: @@ -97,7 +97,7 @@ def _convert_to_tflite( def convert_to_tflite( - model: tf.keras.Model | str, + model: keras.Model | str, quantized: bool = False, output_path: Path | None = None, input_path: Path | None = None, |