From ce9b17650d024886b24ad820f0f1815fc23b19f3 Mon Sep 17 00:00:00 2001 From: Benjamin Klimczak Date: Mon, 7 Nov 2022 12:57:15 +0000 Subject: MLIA-701 Update dependencies - Update TensorFlow dependencies for x86_64 - Adapt unit tests to new TensorFlow version - Update linters (including pre-commit hooks) and fix issues - Use conditional import to fix tflite compat code for aarch64 Change-Id: I1a9b080b900ab65e38f7f2552562822bbfdcd259 --- src/mlia/nn/tensorflow/tflite_compat.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) (limited to 'src/mlia/nn') diff --git a/src/mlia/nn/tensorflow/tflite_compat.py b/src/mlia/nn/tensorflow/tflite_compat.py index 960a5c3..6f183ca 100644 --- a/src/mlia/nn/tensorflow/tflite_compat.py +++ b/src/mlia/nn/tensorflow/tflite_compat.py @@ -11,12 +11,20 @@ from typing import Any from typing import cast from typing import List +import tensorflow as tf from tensorflow.lite.python import convert -from tensorflow.lite.python.metrics import converter_error_data_pb2 from mlia.nn.tensorflow.utils import get_tflite_converter from mlia.utils.logging import redirect_raw_output +TF_VERSION_MAJOR, TF_VERSION_MINOR, _ = (int(s) for s in tf.version.VERSION.split(".")) +# pylint: disable=import-error,ungrouped-imports +if (TF_VERSION_MAJOR == 2 and TF_VERSION_MINOR > 7) or TF_VERSION_MAJOR > 2: + from tensorflow.lite.python.metrics import converter_error_data_pb2 +else: + from tensorflow.lite.python.metrics_wrapper import converter_error_data_pb2 +# pylint: enable=import-error,ungrouped-imports + logger = logging.getLogger(__name__) -- cgit v1.2.1