diff options
Diffstat (limited to 'src/mlia/nn/tensorflow/utils.py')
-rw-r--r-- | src/mlia/nn/tensorflow/utils.py | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/src/mlia/nn/tensorflow/utils.py b/src/mlia/nn/tensorflow/utils.py index 1612447..3ac5064 100644 --- a/src/mlia/nn/tensorflow/utils.py +++ b/src/mlia/nn/tensorflow/utils.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-FileCopyrightText: Copyright The TensorFlow Authors. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 """Collection of useful functions for optimizations.""" @@ -8,6 +8,7 @@ from pathlib import Path from typing import Any import tensorflow as tf +from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 def get_tf_tensor_shape(model: str) -> list: @@ -30,7 +31,7 @@ def get_tf_tensor_shape(model: str) -> list: def save_keras_model( - model: tf.keras.Model, save_path: str | Path, include_optimizer: bool = True + model: keras.Model, save_path: str | Path, include_optimizer: bool = True ) -> None: """Save Keras model at provided path.""" model.save(save_path, include_optimizer=include_optimizer) |