diff options
Diffstat (limited to 'src/mlia/nn')
-rw-r--r-- | src/mlia/nn/tensorflow/tflite_compat.py | 10 |
1 files changed, 9 insertions, 1 deletions
diff --git a/src/mlia/nn/tensorflow/tflite_compat.py b/src/mlia/nn/tensorflow/tflite_compat.py index 960a5c3..6f183ca 100644 --- a/src/mlia/nn/tensorflow/tflite_compat.py +++ b/src/mlia/nn/tensorflow/tflite_compat.py @@ -11,12 +11,20 @@ from typing import Any from typing import cast from typing import List +import tensorflow as tf from tensorflow.lite.python import convert -from tensorflow.lite.python.metrics import converter_error_data_pb2 from mlia.nn.tensorflow.utils import get_tflite_converter from mlia.utils.logging import redirect_raw_output +TF_VERSION_MAJOR, TF_VERSION_MINOR, _ = (int(s) for s in tf.version.VERSION.split(".")) +# pylint: disable=import-error,ungrouped-imports +if (TF_VERSION_MAJOR == 2 and TF_VERSION_MINOR > 7) or TF_VERSION_MAJOR > 2: + from tensorflow.lite.python.metrics import converter_error_data_pb2 +else: + from tensorflow.lite.python.metrics_wrapper import converter_error_data_pb2 +# pylint: enable=import-error,ungrouped-imports + logger = logging.getLogger(__name__) |