aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn
diff options
context:
space:
mode:
authorBenjamin Klimczak <benjamin.klimczak@arm.com>2022-11-07 12:57:15 +0000
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2022-11-11 12:10:26 +0000
commitce9b17650d024886b24ad820f0f1815fc23b19f3 (patch)
treea7d113f751b8856aabcd021464edec16e23ba6f8 /src/mlia/nn
parente40a7adadd254e29d71af38f69a0a20ff4871eef (diff)
downloadmlia-ce9b17650d024886b24ad820f0f1815fc23b19f3.tar.gz
MLIA-701 Update dependencies
- Update TensorFlow dependencies for x86_64 - Adapt unit tests to new TensorFlow version - Update linters (including pre-commit hooks) and fix issues - Use conditional import to fix tflite compat code for aarch64 Change-Id: I1a9b080b900ab65e38f7f2552562822bbfdcd259
Diffstat (limited to 'src/mlia/nn')
-rw-r--r--src/mlia/nn/tensorflow/tflite_compat.py10
1 files changed, 9 insertions, 1 deletions
diff --git a/src/mlia/nn/tensorflow/tflite_compat.py b/src/mlia/nn/tensorflow/tflite_compat.py
index 960a5c3..6f183ca 100644
--- a/src/mlia/nn/tensorflow/tflite_compat.py
+++ b/src/mlia/nn/tensorflow/tflite_compat.py
@@ -11,12 +11,20 @@ from typing import Any
from typing import cast
from typing import List
+import tensorflow as tf
from tensorflow.lite.python import convert
-from tensorflow.lite.python.metrics import converter_error_data_pb2
from mlia.nn.tensorflow.utils import get_tflite_converter
from mlia.utils.logging import redirect_raw_output
+TF_VERSION_MAJOR, TF_VERSION_MINOR, _ = (int(s) for s in tf.version.VERSION.split("."))
+# pylint: disable=import-error,ungrouped-imports
+if (TF_VERSION_MAJOR == 2 and TF_VERSION_MINOR > 7) or TF_VERSION_MAJOR > 2:
+ from tensorflow.lite.python.metrics import converter_error_data_pb2
+else:
+ from tensorflow.lite.python.metrics_wrapper import converter_error_data_pb2
+# pylint: enable=import-error,ungrouped-imports
+
logger = logging.getLogger(__name__)