diff options
Diffstat (limited to 'src/mlia/nn/tensorflow/tflite_compat.py')
-rw-r--r-- | src/mlia/nn/tensorflow/tflite_compat.py | 92 |
1 files changed, 85 insertions, 7 deletions
diff --git a/src/mlia/nn/tensorflow/tflite_compat.py b/src/mlia/nn/tensorflow/tflite_compat.py index 6f183ca..2b29879 100644 --- a/src/mlia/nn/tensorflow/tflite_compat.py +++ b/src/mlia/nn/tensorflow/tflite_compat.py @@ -49,11 +49,20 @@ class TFLiteConversionError: location: list[str] +class TFLiteCompatibilityStatus(Enum): + """TensorFlow lite compatiblity status.""" + + COMPATIBLE = auto() + TFLITE_CONVERSION_ERROR = auto() + MODEL_WITH_CUSTOM_OP_ERROR = auto() + UNKNOWN_ERROR = auto() + + @dataclass class TFLiteCompatibilityInfo: """TensorFlow Lite compatibility information.""" - compatible: bool + status: TFLiteCompatibilityStatus conversion_exception: Exception | None = None conversion_errors: list[TFLiteConversionError] | None = None @@ -64,6 +73,36 @@ class TFLiteCompatibilityInfo: return [err.operator for err in self.conversion_errors if err.code == code] + @property + def compatible(self) -> bool: + """Return true if model compatible with the TensorFlow Lite format.""" + return self.status == TFLiteCompatibilityStatus.COMPATIBLE + + @property + def conversion_failed_with_errors(self) -> bool: + """Return true if conversion to TensorFlow Lite format failed.""" + return self.status == TFLiteCompatibilityStatus.TFLITE_CONVERSION_ERROR + + @property + def conversion_failed_for_model_with_custom_ops(self) -> bool: + """Return true if conversion failed due to custom ops in the model.""" + return self.status == TFLiteCompatibilityStatus.MODEL_WITH_CUSTOM_OP_ERROR + + @property + def check_failed_with_unknown_error(self) -> bool: + """Return true if check failed with unknown error.""" + return self.status == TFLiteCompatibilityStatus.UNKNOWN_ERROR + + @property + def required_custom_ops(self) -> list[str]: + """Return list of the custom ops reported during conversion.""" + return self.unsupported_ops_by_code(TFLiteConversionErrorCode.NEEDS_CUSTOM_OPS) + + @property + def required_flex_ops(self) -> list[str]: + """Return list of the flex ops reported during conversion.""" + return self.unsupported_ops_by_code(TFLiteConversionErrorCode.NEEDS_FLEX_OPS) + class TFLiteChecker: """Class for checking TensorFlow Lite compatibility.""" @@ -86,13 +125,15 @@ class TFLiteChecker: ): converter.convert() except convert.ConverterError as err: - return self._process_exception(err) + return self._process_convert_error(err) except Exception as err: # pylint: disable=broad-except - return TFLiteCompatibilityInfo(compatible=False, conversion_exception=err) - else: - return TFLiteCompatibilityInfo(compatible=True) + return self._process_exception(err) - def _process_exception( + return TFLiteCompatibilityInfo( + status=TFLiteCompatibilityStatus.COMPATIBLE, + ) + + def _process_convert_error( self, err: convert.ConverterError ) -> TFLiteCompatibilityInfo: """Parse error details if possible.""" @@ -114,11 +155,48 @@ class TFLiteChecker: ] return TFLiteCompatibilityInfo( - compatible=False, + status=TFLiteCompatibilityStatus.TFLITE_CONVERSION_ERROR, conversion_exception=err, conversion_errors=conversion_errors, ) + def _process_exception(self, err: Exception) -> TFLiteCompatibilityInfo: + """Process exception during conversion.""" + status = TFLiteCompatibilityStatus.UNKNOWN_ERROR + + if self._model_with_custom_op(err): + status = TFLiteCompatibilityStatus.MODEL_WITH_CUSTOM_OP_ERROR + + return TFLiteCompatibilityInfo( + status=status, + conversion_exception=err, + ) + + @staticmethod + def _model_with_custom_op(err: Exception) -> bool: + """Check if model could not be loaded because of custom ops.""" + exc_attrs = [ + ( + ValueError, + [ + "Unable to restore custom object", + "passed to the `custom_objects`", + ], + ), + ( + FileNotFoundError, + [ + "Op type not registered", + ], + ), + ] + + return any( + any(msg in str(err) for msg in messages) + for exc_type, messages in exc_attrs + if isinstance(err, exc_type) + ) + @staticmethod def _convert_error_code(code: int) -> TFLiteConversionErrorCode: """Convert internal error codes.""" |