diff options
author | Dmitrii Agibov <dmitrii.agibov@arm.com> | 2022-11-09 11:23:50 +0000 |
---|---|---|
committer | Dmitrii Agibov <dmitrii.agibov@arm.com> | 2022-11-15 12:55:49 +0000 |
commit | ef73bb773df214f3f33f8e4ca7d276041106cad2 (patch) | |
tree | 313d5bbcea9574dd4fa026639443548766cf2b91 /tests/test_nn_tensorflow_tflite_compat.py | |
parent | bb20d22509a304c76f849486fe15e3acd7667fb8 (diff) | |
download | mlia-ef73bb773df214f3f33f8e4ca7d276041106cad2.tar.gz |
MLIA-685 Warn about custom operators in SavedModel/Keras models
- Add new error types for the TensorFlow Lite compatibility
check
- Try to detect custom operators in SavedModel/Keras models
- Add warning to the advice about models with custom operators
Change-Id: I2f65474eecf2788110acc43585fa300eda80e21b
Diffstat (limited to 'tests/test_nn_tensorflow_tflite_compat.py')
-rw-r--r-- | tests/test_nn_tensorflow_tflite_compat.py | 32 |
1 files changed, 25 insertions, 7 deletions
diff --git a/tests/test_nn_tensorflow_tflite_compat.py b/tests/test_nn_tensorflow_tflite_compat.py index 1bd4c34..f203125 100644 --- a/tests/test_nn_tensorflow_tflite_compat.py +++ b/tests/test_nn_tensorflow_tflite_compat.py @@ -12,6 +12,7 @@ from tensorflow.lite.python import convert from mlia.nn.tensorflow.tflite_compat import converter_error_data_pb2 from mlia.nn.tensorflow.tflite_compat import TFLiteChecker from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo +from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityStatus from mlia.nn.tensorflow.tflite_compat import TFLiteConversionError from mlia.nn.tensorflow.tflite_compat import TFLiteConversionErrorCode @@ -87,11 +88,14 @@ def _get_tflite_conversion_error( @pytest.mark.parametrize( "conversion_error, expected_result", [ - (None, TFLiteCompatibilityInfo(compatible=True)), + ( + None, + TFLiteCompatibilityInfo(status=TFLiteCompatibilityStatus.COMPATIBLE), + ), ( err := _get_tflite_conversion_error(custom_op=True), TFLiteCompatibilityInfo( - compatible=False, + status=TFLiteCompatibilityStatus.TFLITE_CONVERSION_ERROR, conversion_exception=err, conversion_errors=[ TFLiteConversionError( @@ -106,7 +110,7 @@ def _get_tflite_conversion_error( ( err := _get_tflite_conversion_error(flex_op=True), TFLiteCompatibilityInfo( - compatible=False, + status=TFLiteCompatibilityStatus.TFLITE_CONVERSION_ERROR, conversion_exception=err, conversion_errors=[ TFLiteConversionError( @@ -121,7 +125,7 @@ def _get_tflite_conversion_error( ( err := _get_tflite_conversion_error(unknown_reason=True), TFLiteCompatibilityInfo( - compatible=False, + status=TFLiteCompatibilityStatus.TFLITE_CONVERSION_ERROR, conversion_exception=err, conversion_errors=[ TFLiteConversionError( @@ -141,7 +145,7 @@ def _get_tflite_conversion_error( unsupported_flow_v1=True, ), TFLiteCompatibilityInfo( - compatible=False, + status=TFLiteCompatibilityStatus.TFLITE_CONVERSION_ERROR, conversion_exception=err, conversion_errors=[ TFLiteConversionError( @@ -174,7 +178,7 @@ def _get_tflite_conversion_error( ( err := _get_tflite_conversion_error(), TFLiteCompatibilityInfo( - compatible=False, + status=TFLiteCompatibilityStatus.TFLITE_CONVERSION_ERROR, conversion_exception=err, conversion_errors=[], ), @@ -182,7 +186,21 @@ def _get_tflite_conversion_error( ( err := ValueError("Some unknown issue"), TFLiteCompatibilityInfo( - compatible=False, + status=TFLiteCompatibilityStatus.UNKNOWN_ERROR, + conversion_exception=err, + ), + ), + ( + err := ValueError("Unable to restore custom object"), + TFLiteCompatibilityInfo( + status=TFLiteCompatibilityStatus.MODEL_WITH_CUSTOM_OP_ERROR, + conversion_exception=err, + ), + ), + ( + err := FileNotFoundError("Op type not registered"), + TFLiteCompatibilityInfo( + status=TFLiteCompatibilityStatus.MODEL_WITH_CUSTOM_OP_ERROR, conversion_exception=err, ), ), |