aboutsummaryrefslogtreecommitdiff
path: root/tests/test_nn_tensorflow_tflite_compat.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_nn_tensorflow_tflite_compat.py')
-rw-r--r--tests/test_nn_tensorflow_tflite_compat.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/tests/test_nn_tensorflow_tflite_compat.py b/tests/test_nn_tensorflow_tflite_compat.py
index 4ca387c..ee60ff7 100644
--- a/tests/test_nn_tensorflow_tflite_compat.py
+++ b/tests/test_nn_tensorflow_tflite_compat.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for tflite_compat module."""
from __future__ import annotations
@@ -6,7 +6,7 @@ from __future__ import annotations
from unittest.mock import MagicMock
import pytest
-import tensorflow as tf
+from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
from tensorflow.lite.python import convert
from mlia.nn.tensorflow.tflite_compat import converter_error_data_pb2
@@ -19,11 +19,11 @@ from mlia.nn.tensorflow.tflite_compat import TFLiteConversionErrorCode
def test_not_fully_compatible_model_flex_ops() -> None:
"""Test models that requires TF_SELECT_OPS."""
- model = tf.keras.models.Sequential(
+ model = keras.models.Sequential(
[
- tf.keras.layers.Dense(units=1, input_shape=[1], batch_size=1),
- tf.keras.layers.Dense(units=16, activation="softsign"),
- tf.keras.layers.Dense(units=1),
+ keras.layers.Dense(units=1, input_shape=[1], batch_size=1),
+ keras.layers.Dense(units=16, activation="softsign"),
+ keras.layers.Dense(units=1),
]
)