aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn')
-rw-r--r--src/mlia/nn/tensorflow/tflite_compat.py10
1 files changed, 9 insertions, 1 deletions
diff --git a/src/mlia/nn/tensorflow/tflite_compat.py b/src/mlia/nn/tensorflow/tflite_compat.py
index 960a5c3..6f183ca 100644
--- a/src/mlia/nn/tensorflow/tflite_compat.py
+++ b/src/mlia/nn/tensorflow/tflite_compat.py
@@ -11,12 +11,20 @@ from typing import Any
from typing import cast
from typing import List
+import tensorflow as tf
from tensorflow.lite.python import convert
-from tensorflow.lite.python.metrics import converter_error_data_pb2
from mlia.nn.tensorflow.utils import get_tflite_converter
from mlia.utils.logging import redirect_raw_output
+TF_VERSION_MAJOR, TF_VERSION_MINOR, _ = (int(s) for s in tf.version.VERSION.split("."))
+# pylint: disable=import-error,ungrouped-imports
+if (TF_VERSION_MAJOR == 2 and TF_VERSION_MINOR > 7) or TF_VERSION_MAJOR > 2:
+ from tensorflow.lite.python.metrics import converter_error_data_pb2
+else:
+ from tensorflow.lite.python.metrics_wrapper import converter_error_data_pb2
+# pylint: enable=import-error,ungrouped-imports
+
logger = logging.getLogger(__name__)