aboutsummaryrefslogtreecommitdiff
path: root/tests/test_nn_tensorflow_tflite_compat.py
diff options
context:
space:
mode:
authorDmitrii Agibov <dmitrii.agibov@arm.com>2022-11-09 11:23:50 +0000
committerDmitrii Agibov <dmitrii.agibov@arm.com>2022-11-15 12:55:49 +0000
commitef73bb773df214f3f33f8e4ca7d276041106cad2 (patch)
tree313d5bbcea9574dd4fa026639443548766cf2b91 /tests/test_nn_tensorflow_tflite_compat.py
parentbb20d22509a304c76f849486fe15e3acd7667fb8 (diff)
downloadmlia-ef73bb773df214f3f33f8e4ca7d276041106cad2.tar.gz
MLIA-685 Warn about custom operators in SavedModel/Keras models
- Add new error types for the TensorFlow Lite compatibility check - Try to detect custom operators in SavedModel/Keras models - Add warning to the advice about models with custom operators Change-Id: I2f65474eecf2788110acc43585fa300eda80e21b
Diffstat (limited to 'tests/test_nn_tensorflow_tflite_compat.py')
-rw-r--r--tests/test_nn_tensorflow_tflite_compat.py32
1 files changed, 25 insertions, 7 deletions
diff --git a/tests/test_nn_tensorflow_tflite_compat.py b/tests/test_nn_tensorflow_tflite_compat.py
index 1bd4c34..f203125 100644
--- a/tests/test_nn_tensorflow_tflite_compat.py
+++ b/tests/test_nn_tensorflow_tflite_compat.py
@@ -12,6 +12,7 @@ from tensorflow.lite.python import convert
from mlia.nn.tensorflow.tflite_compat import converter_error_data_pb2
from mlia.nn.tensorflow.tflite_compat import TFLiteChecker
from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
+from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityStatus
from mlia.nn.tensorflow.tflite_compat import TFLiteConversionError
from mlia.nn.tensorflow.tflite_compat import TFLiteConversionErrorCode
@@ -87,11 +88,14 @@ def _get_tflite_conversion_error(
@pytest.mark.parametrize(
"conversion_error, expected_result",
[
- (None, TFLiteCompatibilityInfo(compatible=True)),
+ (
+ None,
+ TFLiteCompatibilityInfo(status=TFLiteCompatibilityStatus.COMPATIBLE),
+ ),
(
err := _get_tflite_conversion_error(custom_op=True),
TFLiteCompatibilityInfo(
- compatible=False,
+ status=TFLiteCompatibilityStatus.TFLITE_CONVERSION_ERROR,
conversion_exception=err,
conversion_errors=[
TFLiteConversionError(
@@ -106,7 +110,7 @@ def _get_tflite_conversion_error(
(
err := _get_tflite_conversion_error(flex_op=True),
TFLiteCompatibilityInfo(
- compatible=False,
+ status=TFLiteCompatibilityStatus.TFLITE_CONVERSION_ERROR,
conversion_exception=err,
conversion_errors=[
TFLiteConversionError(
@@ -121,7 +125,7 @@ def _get_tflite_conversion_error(
(
err := _get_tflite_conversion_error(unknown_reason=True),
TFLiteCompatibilityInfo(
- compatible=False,
+ status=TFLiteCompatibilityStatus.TFLITE_CONVERSION_ERROR,
conversion_exception=err,
conversion_errors=[
TFLiteConversionError(
@@ -141,7 +145,7 @@ def _get_tflite_conversion_error(
unsupported_flow_v1=True,
),
TFLiteCompatibilityInfo(
- compatible=False,
+ status=TFLiteCompatibilityStatus.TFLITE_CONVERSION_ERROR,
conversion_exception=err,
conversion_errors=[
TFLiteConversionError(
@@ -174,7 +178,7 @@ def _get_tflite_conversion_error(
(
err := _get_tflite_conversion_error(),
TFLiteCompatibilityInfo(
- compatible=False,
+ status=TFLiteCompatibilityStatus.TFLITE_CONVERSION_ERROR,
conversion_exception=err,
conversion_errors=[],
),
@@ -182,7 +186,21 @@ def _get_tflite_conversion_error(
(
err := ValueError("Some unknown issue"),
TFLiteCompatibilityInfo(
- compatible=False,
+ status=TFLiteCompatibilityStatus.UNKNOWN_ERROR,
+ conversion_exception=err,
+ ),
+ ),
+ (
+ err := ValueError("Unable to restore custom object"),
+ TFLiteCompatibilityInfo(
+ status=TFLiteCompatibilityStatus.MODEL_WITH_CUSTOM_OP_ERROR,
+ conversion_exception=err,
+ ),
+ ),
+ (
+ err := FileNotFoundError("Op type not registered"),
+ TFLiteCompatibilityInfo(
+ status=TFLiteCompatibilityStatus.MODEL_WITH_CUSTOM_OP_ERROR,
conversion_exception=err,
),
),