aboutsummaryrefslogtreecommitdiff
path: root/tests/test_nn_tensorflow_tflite_compat.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_nn_tensorflow_tflite_compat.py')
-rw-r--r--tests/test_nn_tensorflow_tflite_compat.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/tests/test_nn_tensorflow_tflite_compat.py b/tests/test_nn_tensorflow_tflite_compat.py
index c330fdb..1bd4c34 100644
--- a/tests/test_nn_tensorflow_tflite_compat.py
+++ b/tests/test_nn_tensorflow_tflite_compat.py
@@ -8,8 +8,8 @@ from unittest.mock import MagicMock
import pytest
import tensorflow as tf
from tensorflow.lite.python import convert
-from tensorflow.lite.python.metrics import converter_error_data_pb2
+from mlia.nn.tensorflow.tflite_compat import converter_error_data_pb2
from mlia.nn.tensorflow.tflite_compat import TFLiteChecker
from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
from mlia.nn.tensorflow.tflite_compat import TFLiteConversionError
@@ -21,7 +21,7 @@ def test_not_fully_compatible_model_flex_ops() -> None:
model = tf.keras.models.Sequential(
[
tf.keras.layers.Dense(units=1, input_shape=[1], batch_size=1),
- tf.keras.layers.Dense(units=16, activation="gelu"),
+ tf.keras.layers.Dense(units=16, activation="softsign"),
tf.keras.layers.Dense(units=1),
]
)
@@ -36,9 +36,9 @@ def test_not_fully_compatible_model_flex_ops() -> None:
conv_err = result.conversion_errors[0]
assert isinstance(conv_err, TFLiteConversionError)
- assert conv_err.message == "'tf.Erf' op is neither a custom op nor a flex op"
+ assert conv_err.message == "'tf.Softsign' op is neither a custom op nor a flex op"
assert conv_err.code == TFLiteConversionErrorCode.NEEDS_FLEX_OPS
- assert conv_err.operator == "tf.Erf"
+ assert conv_err.operator == "tf.Softsign"
assert len(conv_err.location) == 3