diff options
Diffstat (limited to 'tests/test_nn_tensorflow_tflite_compat.py')
-rw-r--r-- | tests/test_nn_tensorflow_tflite_compat.py | 210 |
1 files changed, 210 insertions, 0 deletions
diff --git a/tests/test_nn_tensorflow_tflite_compat.py b/tests/test_nn_tensorflow_tflite_compat.py new file mode 100644 index 0000000..c330fdb --- /dev/null +++ b/tests/test_nn_tensorflow_tflite_compat.py @@ -0,0 +1,210 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for tflite_compat module.""" +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest +import tensorflow as tf +from tensorflow.lite.python import convert +from tensorflow.lite.python.metrics import converter_error_data_pb2 + +from mlia.nn.tensorflow.tflite_compat import TFLiteChecker +from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo +from mlia.nn.tensorflow.tflite_compat import TFLiteConversionError +from mlia.nn.tensorflow.tflite_compat import TFLiteConversionErrorCode + + +def test_not_fully_compatible_model_flex_ops() -> None: + """Test models that requires TF_SELECT_OPS.""" + model = tf.keras.models.Sequential( + [ + tf.keras.layers.Dense(units=1, input_shape=[1], batch_size=1), + tf.keras.layers.Dense(units=16, activation="gelu"), + tf.keras.layers.Dense(units=1), + ] + ) + + checker = TFLiteChecker() + result = checker.check_compatibility(model) + + assert result.compatible is False + assert isinstance(result.conversion_exception, convert.ConverterError) + assert result.conversion_errors is not None + assert len(result.conversion_errors) == 1 + + conv_err = result.conversion_errors[0] + assert isinstance(conv_err, TFLiteConversionError) + assert conv_err.message == "'tf.Erf' op is neither a custom op nor a flex op" + assert conv_err.code == TFLiteConversionErrorCode.NEEDS_FLEX_OPS + assert conv_err.operator == "tf.Erf" + assert len(conv_err.location) == 3 + + +def _get_tflite_conversion_error( + error_message: str = "Conversion error", + custom_op: bool = False, + flex_op: bool = False, + unsupported_flow_v1: bool = False, + gpu_not_compatible: bool = False, + unknown_reason: bool = False, +) -> convert.ConverterError: + """Create TensorFlow Lite conversion error.""" + error_data = converter_error_data_pb2.ConverterErrorData + convert_error = convert.ConverterError(error_message) + + # pylint: disable=no-member + def _add_error(operator: str, error_code: int) -> None: + convert_error.append_error( + error_data( + operator=error_data.Operator(name=operator), + error_code=error_code, + error_message=error_message, + ) + ) + + if custom_op: + _add_error("custom_op", error_data.ERROR_NEEDS_CUSTOM_OPS) + + if flex_op: + _add_error("flex_op", error_data.ERROR_NEEDS_FLEX_OPS) + + if unsupported_flow_v1: + _add_error("flow_op", error_data.ERROR_UNSUPPORTED_CONTROL_FLOW_V1) + + if gpu_not_compatible: + _add_error("non_gpu_op", error_data.ERROR_GPU_NOT_COMPATIBLE) + + if unknown_reason: + _add_error("unknown_op", None) # type: ignore + # pylint: enable=no-member + + return convert_error + + +# pylint: disable=undefined-variable,unused-variable +@pytest.mark.parametrize( + "conversion_error, expected_result", + [ + (None, TFLiteCompatibilityInfo(compatible=True)), + ( + err := _get_tflite_conversion_error(custom_op=True), + TFLiteCompatibilityInfo( + compatible=False, + conversion_exception=err, + conversion_errors=[ + TFLiteConversionError( + message="Conversion error", + code=TFLiteConversionErrorCode.NEEDS_CUSTOM_OPS, + operator="custom_op", + location=[], + ) + ], + ), + ), + ( + err := _get_tflite_conversion_error(flex_op=True), + TFLiteCompatibilityInfo( + compatible=False, + conversion_exception=err, + conversion_errors=[ + TFLiteConversionError( + message="Conversion error", + code=TFLiteConversionErrorCode.NEEDS_FLEX_OPS, + operator="flex_op", + location=[], + ) + ], + ), + ), + ( + err := _get_tflite_conversion_error(unknown_reason=True), + TFLiteCompatibilityInfo( + compatible=False, + conversion_exception=err, + conversion_errors=[ + TFLiteConversionError( + message="Conversion error", + code=TFLiteConversionErrorCode.UNKNOWN, + operator="unknown_op", + location=[], + ) + ], + ), + ), + ( + err := _get_tflite_conversion_error( + flex_op=True, + custom_op=True, + gpu_not_compatible=True, + unsupported_flow_v1=True, + ), + TFLiteCompatibilityInfo( + compatible=False, + conversion_exception=err, + conversion_errors=[ + TFLiteConversionError( + message="Conversion error", + code=TFLiteConversionErrorCode.NEEDS_CUSTOM_OPS, + operator="custom_op", + location=[], + ), + TFLiteConversionError( + message="Conversion error", + code=TFLiteConversionErrorCode.NEEDS_FLEX_OPS, + operator="flex_op", + location=[], + ), + TFLiteConversionError( + message="Conversion error", + code=TFLiteConversionErrorCode.UNSUPPORTED_CONTROL_FLOW_V1, + operator="flow_op", + location=[], + ), + TFLiteConversionError( + message="Conversion error", + code=TFLiteConversionErrorCode.GPU_NOT_COMPATIBLE, + operator="non_gpu_op", + location=[], + ), + ], + ), + ), + ( + err := _get_tflite_conversion_error(), + TFLiteCompatibilityInfo( + compatible=False, + conversion_exception=err, + conversion_errors=[], + ), + ), + ( + err := ValueError("Some unknown issue"), + TFLiteCompatibilityInfo( + compatible=False, + conversion_exception=err, + ), + ), + ], +) +# pylint: enable=undefined-variable,unused-variable +def test_tflite_compatibility( + conversion_error: convert.ConverterError | ValueError | None, + expected_result: TFLiteCompatibilityInfo, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test TensorFlow Lite compatibility.""" + converter_mock = MagicMock() + + if conversion_error is not None: + converter_mock.convert.side_effect = conversion_error + + monkeypatch.setattr( + "mlia.nn.tensorflow.tflite_compat.get_tflite_converter", + lambda *args, **kwargs: converter_mock, + ) + + checker = TFLiteChecker() + result = checker.check_compatibility(MagicMock()) + assert result == expected_result |