diff options
Diffstat (limited to 'tests/test_nn_tensorflow_tflite_compat.py')
-rw-r--r-- | tests/test_nn_tensorflow_tflite_compat.py | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/tests/test_nn_tensorflow_tflite_compat.py b/tests/test_nn_tensorflow_tflite_compat.py index 4ca387c..ee60ff7 100644 --- a/tests/test_nn_tensorflow_tflite_compat.py +++ b/tests/test_nn_tensorflow_tflite_compat.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for tflite_compat module.""" from __future__ import annotations @@ -6,7 +6,7 @@ from __future__ import annotations from unittest.mock import MagicMock import pytest -import tensorflow as tf +from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 from tensorflow.lite.python import convert from mlia.nn.tensorflow.tflite_compat import converter_error_data_pb2 @@ -19,11 +19,11 @@ from mlia.nn.tensorflow.tflite_compat import TFLiteConversionErrorCode def test_not_fully_compatible_model_flex_ops() -> None: """Test models that requires TF_SELECT_OPS.""" - model = tf.keras.models.Sequential( + model = keras.models.Sequential( [ - tf.keras.layers.Dense(units=1, input_shape=[1], batch_size=1), - tf.keras.layers.Dense(units=16, activation="softsign"), - tf.keras.layers.Dense(units=1), + keras.layers.Dense(units=1, input_shape=[1], batch_size=1), + keras.layers.Dense(units=16, activation="softsign"), + keras.layers.Dense(units=1), ] ) |