aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn
diff options
context:
space:
mode:
authorDmitrii Agibov <dmitrii.agibov@arm.com>2022-11-09 11:23:50 +0000
committerDmitrii Agibov <dmitrii.agibov@arm.com>2022-11-15 12:55:49 +0000
commitef73bb773df214f3f33f8e4ca7d276041106cad2 (patch)
tree313d5bbcea9574dd4fa026639443548766cf2b91 /src/mlia/nn
parentbb20d22509a304c76f849486fe15e3acd7667fb8 (diff)
downloadmlia-ef73bb773df214f3f33f8e4ca7d276041106cad2.tar.gz
MLIA-685 Warn about custom operators in SavedModel/Keras models
- Add new error types for the TensorFlow Lite compatibility check - Try to detect custom operators in SavedModel/Keras models - Add warning to the advice about models with custom operators Change-Id: I2f65474eecf2788110acc43585fa300eda80e21b
Diffstat (limited to 'src/mlia/nn')
-rw-r--r--src/mlia/nn/tensorflow/tflite_compat.py92
1 files changed, 85 insertions, 7 deletions
diff --git a/src/mlia/nn/tensorflow/tflite_compat.py b/src/mlia/nn/tensorflow/tflite_compat.py
index 6f183ca..2b29879 100644
--- a/src/mlia/nn/tensorflow/tflite_compat.py
+++ b/src/mlia/nn/tensorflow/tflite_compat.py
@@ -49,11 +49,20 @@ class TFLiteConversionError:
location: list[str]
+class TFLiteCompatibilityStatus(Enum):
+ """TensorFlow lite compatiblity status."""
+
+ COMPATIBLE = auto()
+ TFLITE_CONVERSION_ERROR = auto()
+ MODEL_WITH_CUSTOM_OP_ERROR = auto()
+ UNKNOWN_ERROR = auto()
+
+
@dataclass
class TFLiteCompatibilityInfo:
"""TensorFlow Lite compatibility information."""
- compatible: bool
+ status: TFLiteCompatibilityStatus
conversion_exception: Exception | None = None
conversion_errors: list[TFLiteConversionError] | None = None
@@ -64,6 +73,36 @@ class TFLiteCompatibilityInfo:
return [err.operator for err in self.conversion_errors if err.code == code]
+ @property
+ def compatible(self) -> bool:
+ """Return true if model compatible with the TensorFlow Lite format."""
+ return self.status == TFLiteCompatibilityStatus.COMPATIBLE
+
+ @property
+ def conversion_failed_with_errors(self) -> bool:
+ """Return true if conversion to TensorFlow Lite format failed."""
+ return self.status == TFLiteCompatibilityStatus.TFLITE_CONVERSION_ERROR
+
+ @property
+ def conversion_failed_for_model_with_custom_ops(self) -> bool:
+ """Return true if conversion failed due to custom ops in the model."""
+ return self.status == TFLiteCompatibilityStatus.MODEL_WITH_CUSTOM_OP_ERROR
+
+ @property
+ def check_failed_with_unknown_error(self) -> bool:
+ """Return true if check failed with unknown error."""
+ return self.status == TFLiteCompatibilityStatus.UNKNOWN_ERROR
+
+ @property
+ def required_custom_ops(self) -> list[str]:
+ """Return list of the custom ops reported during conversion."""
+ return self.unsupported_ops_by_code(TFLiteConversionErrorCode.NEEDS_CUSTOM_OPS)
+
+ @property
+ def required_flex_ops(self) -> list[str]:
+ """Return list of the flex ops reported during conversion."""
+ return self.unsupported_ops_by_code(TFLiteConversionErrorCode.NEEDS_FLEX_OPS)
+
class TFLiteChecker:
"""Class for checking TensorFlow Lite compatibility."""
@@ -86,13 +125,15 @@ class TFLiteChecker:
):
converter.convert()
except convert.ConverterError as err:
- return self._process_exception(err)
+ return self._process_convert_error(err)
except Exception as err: # pylint: disable=broad-except
- return TFLiteCompatibilityInfo(compatible=False, conversion_exception=err)
- else:
- return TFLiteCompatibilityInfo(compatible=True)
+ return self._process_exception(err)
- def _process_exception(
+ return TFLiteCompatibilityInfo(
+ status=TFLiteCompatibilityStatus.COMPATIBLE,
+ )
+
+ def _process_convert_error(
self, err: convert.ConverterError
) -> TFLiteCompatibilityInfo:
"""Parse error details if possible."""
@@ -114,11 +155,48 @@ class TFLiteChecker:
]
return TFLiteCompatibilityInfo(
- compatible=False,
+ status=TFLiteCompatibilityStatus.TFLITE_CONVERSION_ERROR,
conversion_exception=err,
conversion_errors=conversion_errors,
)
+ def _process_exception(self, err: Exception) -> TFLiteCompatibilityInfo:
+ """Process exception during conversion."""
+ status = TFLiteCompatibilityStatus.UNKNOWN_ERROR
+
+ if self._model_with_custom_op(err):
+ status = TFLiteCompatibilityStatus.MODEL_WITH_CUSTOM_OP_ERROR
+
+ return TFLiteCompatibilityInfo(
+ status=status,
+ conversion_exception=err,
+ )
+
+ @staticmethod
+ def _model_with_custom_op(err: Exception) -> bool:
+ """Check if model could not be loaded because of custom ops."""
+ exc_attrs = [
+ (
+ ValueError,
+ [
+ "Unable to restore custom object",
+ "passed to the `custom_objects`",
+ ],
+ ),
+ (
+ FileNotFoundError,
+ [
+ "Op type not registered",
+ ],
+ ),
+ ]
+
+ return any(
+ any(msg in str(err) for msg in messages)
+ for exc_type, messages in exc_attrs
+ if isinstance(err, exc_type)
+ )
+
@staticmethod
def _convert_error_code(code: int) -> TFLiteConversionErrorCode:
"""Convert internal error codes."""