diff options
author | Benjamin Klimczak <benjamin.klimczak@arm.com> | 2022-11-07 12:57:15 +0000 |
---|---|---|
committer | Benjamin Klimczak <benjamin.klimczak@arm.com> | 2022-11-11 12:10:26 +0000 |
commit | ce9b17650d024886b24ad820f0f1815fc23b19f3 (patch) | |
tree | a7d113f751b8856aabcd021464edec16e23ba6f8 /tests/test_nn_tensorflow_tflite_compat.py | |
parent | e40a7adadd254e29d71af38f69a0a20ff4871eef (diff) | |
download | mlia-ce9b17650d024886b24ad820f0f1815fc23b19f3.tar.gz |
MLIA-701 Update dependencies
- Update TensorFlow dependencies for x86_64
- Adapt unit tests to new TensorFlow version
- Update linters (including pre-commit hooks) and fix issues
- Use conditional import to fix tflite compat code for aarch64
Change-Id: I1a9b080b900ab65e38f7f2552562822bbfdcd259
Diffstat (limited to 'tests/test_nn_tensorflow_tflite_compat.py')
-rw-r--r-- | tests/test_nn_tensorflow_tflite_compat.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/tests/test_nn_tensorflow_tflite_compat.py b/tests/test_nn_tensorflow_tflite_compat.py index c330fdb..1bd4c34 100644 --- a/tests/test_nn_tensorflow_tflite_compat.py +++ b/tests/test_nn_tensorflow_tflite_compat.py @@ -8,8 +8,8 @@ from unittest.mock import MagicMock import pytest import tensorflow as tf from tensorflow.lite.python import convert -from tensorflow.lite.python.metrics import converter_error_data_pb2 +from mlia.nn.tensorflow.tflite_compat import converter_error_data_pb2 from mlia.nn.tensorflow.tflite_compat import TFLiteChecker from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo from mlia.nn.tensorflow.tflite_compat import TFLiteConversionError @@ -21,7 +21,7 @@ def test_not_fully_compatible_model_flex_ops() -> None: model = tf.keras.models.Sequential( [ tf.keras.layers.Dense(units=1, input_shape=[1], batch_size=1), - tf.keras.layers.Dense(units=16, activation="gelu"), + tf.keras.layers.Dense(units=16, activation="softsign"), tf.keras.layers.Dense(units=1), ] ) @@ -36,9 +36,9 @@ def test_not_fully_compatible_model_flex_ops() -> None: conv_err = result.conversion_errors[0] assert isinstance(conv_err, TFLiteConversionError) - assert conv_err.message == "'tf.Erf' op is neither a custom op nor a flex op" + assert conv_err.message == "'tf.Softsign' op is neither a custom op nor a flex op" assert conv_err.code == TFLiteConversionErrorCode.NEEDS_FLEX_OPS - assert conv_err.operator == "tf.Erf" + assert conv_err.operator == "tf.Softsign" assert len(conv_err.location) == 3 |