aboutsummaryrefslogtreecommitdiff
path: root/tests/test_nn_tensorflow_tflite_compat.py
diff options
context:
space:
mode:
authorDmitrii Agibov <dmitrii.agibov@arm.com>2022-10-24 15:08:08 +0100
committerDmitrii Agibov <dmitrii.agibov@arm.com>2022-10-26 17:08:13 +0100
commit58a65fee574c00329cf92b387a6d2513dcbf6100 (patch)
tree47e3185f78b4298ab029785ddee68456e44cac10 /tests/test_nn_tensorflow_tflite_compat.py
parent9d34cb72d45a6d0a2ec1063ebf32536c1efdba75 (diff)
downloadmlia-58a65fee574c00329cf92b387a6d2513dcbf6100.tar.gz
MLIA-433 Add TensorFlow Lite compatibility check
- Add ability to intercept low level TensorFlow output - Produce advice for the models that could not be converted to the TensorFlow Lite format - Refactor utility functions for TensorFlow Lite conversion - Add TensorFlow Lite compatibility checker Change-Id: I47d120d2619ced7b143bc92c5184515b81c0220d
Diffstat (limited to 'tests/test_nn_tensorflow_tflite_compat.py')
-rw-r--r--tests/test_nn_tensorflow_tflite_compat.py210
1 files changed, 210 insertions, 0 deletions
diff --git a/tests/test_nn_tensorflow_tflite_compat.py b/tests/test_nn_tensorflow_tflite_compat.py
new file mode 100644
index 0000000..c330fdb
--- /dev/null
+++ b/tests/test_nn_tensorflow_tflite_compat.py
@@ -0,0 +1,210 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for tflite_compat module."""
+from __future__ import annotations
+
+from unittest.mock import MagicMock
+
+import pytest
+import tensorflow as tf
+from tensorflow.lite.python import convert
+from tensorflow.lite.python.metrics import converter_error_data_pb2
+
+from mlia.nn.tensorflow.tflite_compat import TFLiteChecker
+from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
+from mlia.nn.tensorflow.tflite_compat import TFLiteConversionError
+from mlia.nn.tensorflow.tflite_compat import TFLiteConversionErrorCode
+
+
+def test_not_fully_compatible_model_flex_ops() -> None:
+ """Test models that requires TF_SELECT_OPS."""
+ model = tf.keras.models.Sequential(
+ [
+ tf.keras.layers.Dense(units=1, input_shape=[1], batch_size=1),
+ tf.keras.layers.Dense(units=16, activation="gelu"),
+ tf.keras.layers.Dense(units=1),
+ ]
+ )
+
+ checker = TFLiteChecker()
+ result = checker.check_compatibility(model)
+
+ assert result.compatible is False
+ assert isinstance(result.conversion_exception, convert.ConverterError)
+ assert result.conversion_errors is not None
+ assert len(result.conversion_errors) == 1
+
+ conv_err = result.conversion_errors[0]
+ assert isinstance(conv_err, TFLiteConversionError)
+ assert conv_err.message == "'tf.Erf' op is neither a custom op nor a flex op"
+ assert conv_err.code == TFLiteConversionErrorCode.NEEDS_FLEX_OPS
+ assert conv_err.operator == "tf.Erf"
+ assert len(conv_err.location) == 3
+
+
+def _get_tflite_conversion_error(
+ error_message: str = "Conversion error",
+ custom_op: bool = False,
+ flex_op: bool = False,
+ unsupported_flow_v1: bool = False,
+ gpu_not_compatible: bool = False,
+ unknown_reason: bool = False,
+) -> convert.ConverterError:
+ """Create TensorFlow Lite conversion error."""
+ error_data = converter_error_data_pb2.ConverterErrorData
+ convert_error = convert.ConverterError(error_message)
+
+ # pylint: disable=no-member
+ def _add_error(operator: str, error_code: int) -> None:
+ convert_error.append_error(
+ error_data(
+ operator=error_data.Operator(name=operator),
+ error_code=error_code,
+ error_message=error_message,
+ )
+ )
+
+ if custom_op:
+ _add_error("custom_op", error_data.ERROR_NEEDS_CUSTOM_OPS)
+
+ if flex_op:
+ _add_error("flex_op", error_data.ERROR_NEEDS_FLEX_OPS)
+
+ if unsupported_flow_v1:
+ _add_error("flow_op", error_data.ERROR_UNSUPPORTED_CONTROL_FLOW_V1)
+
+ if gpu_not_compatible:
+ _add_error("non_gpu_op", error_data.ERROR_GPU_NOT_COMPATIBLE)
+
+ if unknown_reason:
+ _add_error("unknown_op", None) # type: ignore
+ # pylint: enable=no-member
+
+ return convert_error
+
+
+# pylint: disable=undefined-variable,unused-variable
+@pytest.mark.parametrize(
+ "conversion_error, expected_result",
+ [
+ (None, TFLiteCompatibilityInfo(compatible=True)),
+ (
+ err := _get_tflite_conversion_error(custom_op=True),
+ TFLiteCompatibilityInfo(
+ compatible=False,
+ conversion_exception=err,
+ conversion_errors=[
+ TFLiteConversionError(
+ message="Conversion error",
+ code=TFLiteConversionErrorCode.NEEDS_CUSTOM_OPS,
+ operator="custom_op",
+ location=[],
+ )
+ ],
+ ),
+ ),
+ (
+ err := _get_tflite_conversion_error(flex_op=True),
+ TFLiteCompatibilityInfo(
+ compatible=False,
+ conversion_exception=err,
+ conversion_errors=[
+ TFLiteConversionError(
+ message="Conversion error",
+ code=TFLiteConversionErrorCode.NEEDS_FLEX_OPS,
+ operator="flex_op",
+ location=[],
+ )
+ ],
+ ),
+ ),
+ (
+ err := _get_tflite_conversion_error(unknown_reason=True),
+ TFLiteCompatibilityInfo(
+ compatible=False,
+ conversion_exception=err,
+ conversion_errors=[
+ TFLiteConversionError(
+ message="Conversion error",
+ code=TFLiteConversionErrorCode.UNKNOWN,
+ operator="unknown_op",
+ location=[],
+ )
+ ],
+ ),
+ ),
+ (
+ err := _get_tflite_conversion_error(
+ flex_op=True,
+ custom_op=True,
+ gpu_not_compatible=True,
+ unsupported_flow_v1=True,
+ ),
+ TFLiteCompatibilityInfo(
+ compatible=False,
+ conversion_exception=err,
+ conversion_errors=[
+ TFLiteConversionError(
+ message="Conversion error",
+ code=TFLiteConversionErrorCode.NEEDS_CUSTOM_OPS,
+ operator="custom_op",
+ location=[],
+ ),
+ TFLiteConversionError(
+ message="Conversion error",
+ code=TFLiteConversionErrorCode.NEEDS_FLEX_OPS,
+ operator="flex_op",
+ location=[],
+ ),
+ TFLiteConversionError(
+ message="Conversion error",
+ code=TFLiteConversionErrorCode.UNSUPPORTED_CONTROL_FLOW_V1,
+ operator="flow_op",
+ location=[],
+ ),
+ TFLiteConversionError(
+ message="Conversion error",
+ code=TFLiteConversionErrorCode.GPU_NOT_COMPATIBLE,
+ operator="non_gpu_op",
+ location=[],
+ ),
+ ],
+ ),
+ ),
+ (
+ err := _get_tflite_conversion_error(),
+ TFLiteCompatibilityInfo(
+ compatible=False,
+ conversion_exception=err,
+ conversion_errors=[],
+ ),
+ ),
+ (
+ err := ValueError("Some unknown issue"),
+ TFLiteCompatibilityInfo(
+ compatible=False,
+ conversion_exception=err,
+ ),
+ ),
+ ],
+)
+# pylint: enable=undefined-variable,unused-variable
+def test_tflite_compatibility(
+ conversion_error: convert.ConverterError | ValueError | None,
+ expected_result: TFLiteCompatibilityInfo,
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """Test TensorFlow Lite compatibility."""
+ converter_mock = MagicMock()
+
+ if conversion_error is not None:
+ converter_mock.convert.side_effect = conversion_error
+
+ monkeypatch.setattr(
+ "mlia.nn.tensorflow.tflite_compat.get_tflite_converter",
+ lambda *args, **kwargs: converter_mock,
+ )
+
+ checker = TFLiteChecker()
+ result = checker.check_compatibility(MagicMock())
+ assert result == expected_result