From ce9b17650d024886b24ad820f0f1815fc23b19f3 Mon Sep 17 00:00:00 2001 From: Benjamin Klimczak Date: Mon, 7 Nov 2022 12:57:15 +0000 Subject: MLIA-701 Update dependencies - Update TensorFlow dependencies for x86_64 - Adapt unit tests to new TensorFlow version - Update linters (including pre-commit hooks) and fix issues - Use conditional import to fix tflite compat code for aarch64 Change-Id: I1a9b080b900ab65e38f7f2552562822bbfdcd259 --- src/mlia/backend/common.py | 2 +- src/mlia/core/common.py | 8 +++++++- src/mlia/core/helpers.py | 2 +- src/mlia/devices/cortexa/advice_generation.py | 2 +- src/mlia/devices/cortexa/data_analysis.py | 2 +- src/mlia/devices/ethosu/advice_generation.py | 2 +- src/mlia/devices/ethosu/data_analysis.py | 2 +- src/mlia/devices/tosa/advice_generation.py | 2 +- src/mlia/devices/tosa/data_analysis.py | 2 +- src/mlia/nn/tensorflow/tflite_compat.py | 10 +++++++++- src/mlia/utils/download.py | 2 +- 11 files changed, 25 insertions(+), 11 deletions(-) (limited to 'src') diff --git a/src/mlia/backend/common.py b/src/mlia/backend/common.py index 697c2a0..0f04553 100644 --- a/src/mlia/backend/common.py +++ b/src/mlia/backend/common.py @@ -205,7 +205,7 @@ class Backend(ABC): return self.variables[var_name] - return var_pattern.sub(var_value, str_val) # type: ignore + return var_pattern.sub(var_value, str_val) @classmethod def _parse_params( diff --git a/src/mlia/core/common.py b/src/mlia/core/common.py index 63fb324..6c9dde1 100644 --- a/src/mlia/core/common.py +++ b/src/mlia/core/common.py @@ -29,7 +29,13 @@ class AdviceCategory(Flag): OPERATORS = auto() PERFORMANCE = auto() OPTIMIZATION = auto() - ALL = OPERATORS | PERFORMANCE | OPTIMIZATION + ALL = ( + # pylint: disable=unsupported-binary-operation + OPERATORS + | PERFORMANCE + | OPTIMIZATION + # pylint: enable=unsupported-binary-operation + ) @classmethod def from_string(cls, value: str) -> AdviceCategory: diff --git a/src/mlia/core/helpers.py b/src/mlia/core/helpers.py index f0c4474..f4a9df6 100644 --- a/src/mlia/core/helpers.py +++ b/src/mlia/core/helpers.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Module for various helper classes.""" -# pylint: disable=no-self-use, unused-argument +# pylint: disable=unused-argument from __future__ import annotations from typing import Any diff --git a/src/mlia/devices/cortexa/advice_generation.py b/src/mlia/devices/cortexa/advice_generation.py index 34c51f8..186f489 100644 --- a/src/mlia/devices/cortexa/advice_generation.py +++ b/src/mlia/devices/cortexa/advice_generation.py @@ -23,7 +23,7 @@ class CortexAAdviceProducer(FactBasedAdviceProducer): ) @singledispatchmethod - def produce_advice(self, _data_item: DataItem) -> None: + def produce_advice(self, _data_item: DataItem) -> None: # type: ignore """Produce advice.""" @produce_advice.register diff --git a/src/mlia/devices/cortexa/data_analysis.py b/src/mlia/devices/cortexa/data_analysis.py index 9f6d82b..6a82dd0 100644 --- a/src/mlia/devices/cortexa/data_analysis.py +++ b/src/mlia/devices/cortexa/data_analysis.py @@ -21,7 +21,7 @@ class CortexADataAnalyzer(FactExtractor): """Cortex-A data analyzer.""" @singledispatchmethod - def analyze_data(self, data_item: DataItem) -> None: + def analyze_data(self, data_item: DataItem) -> None: # type: ignore """Analyse the data.""" @analyze_data.register diff --git a/src/mlia/devices/ethosu/advice_generation.py b/src/mlia/devices/ethosu/advice_generation.py index 8a38d2c..1910460 100644 --- a/src/mlia/devices/ethosu/advice_generation.py +++ b/src/mlia/devices/ethosu/advice_generation.py @@ -22,7 +22,7 @@ class EthosUAdviceProducer(FactBasedAdviceProducer): """Ethos-U advice producer.""" @singledispatchmethod - def produce_advice(self, data_item: DataItem) -> None: + def produce_advice(self, data_item: DataItem) -> None: # type: ignore """Produce advice.""" @produce_advice.register diff --git a/src/mlia/devices/ethosu/data_analysis.py b/src/mlia/devices/ethosu/data_analysis.py index 8d88cf7..70b6f65 100644 --- a/src/mlia/devices/ethosu/data_analysis.py +++ b/src/mlia/devices/ethosu/data_analysis.py @@ -83,7 +83,7 @@ class EthosUDataAnalyzer(FactExtractor): """Ethos-U data analyzer.""" @singledispatchmethod - def analyze_data(self, data_item: DataItem) -> None: + def analyze_data(self, data_item: DataItem) -> None: # type: ignore """Analyse the data.""" @analyze_data.register diff --git a/src/mlia/devices/tosa/advice_generation.py b/src/mlia/devices/tosa/advice_generation.py index 7adfcb9..a3d8011 100644 --- a/src/mlia/devices/tosa/advice_generation.py +++ b/src/mlia/devices/tosa/advice_generation.py @@ -15,7 +15,7 @@ class TOSAAdviceProducer(FactBasedAdviceProducer): """TOSA advice producer.""" @singledispatchmethod - def produce_advice(self, _data_item: DataItem) -> None: + def produce_advice(self, _data_item: DataItem) -> None: # type: ignore """Produce advice.""" @produce_advice.register diff --git a/src/mlia/devices/tosa/data_analysis.py b/src/mlia/devices/tosa/data_analysis.py index aa696a5..c18ac02 100644 --- a/src/mlia/devices/tosa/data_analysis.py +++ b/src/mlia/devices/tosa/data_analysis.py @@ -24,7 +24,7 @@ class TOSADataAnalyzer(FactExtractor): """TOSA data analyzer.""" @singledispatchmethod - def analyze_data(self, data_item: DataItem) -> None: + def analyze_data(self, data_item: DataItem) -> None: # type: ignore """Analyse the data.""" @analyze_data.register diff --git a/src/mlia/nn/tensorflow/tflite_compat.py b/src/mlia/nn/tensorflow/tflite_compat.py index 960a5c3..6f183ca 100644 --- a/src/mlia/nn/tensorflow/tflite_compat.py +++ b/src/mlia/nn/tensorflow/tflite_compat.py @@ -11,12 +11,20 @@ from typing import Any from typing import cast from typing import List +import tensorflow as tf from tensorflow.lite.python import convert -from tensorflow.lite.python.metrics import converter_error_data_pb2 from mlia.nn.tensorflow.utils import get_tflite_converter from mlia.utils.logging import redirect_raw_output +TF_VERSION_MAJOR, TF_VERSION_MINOR, _ = (int(s) for s in tf.version.VERSION.split(".")) +# pylint: disable=import-error,ungrouped-imports +if (TF_VERSION_MAJOR == 2 and TF_VERSION_MINOR > 7) or TF_VERSION_MAJOR > 2: + from tensorflow.lite.python.metrics import converter_error_data_pb2 +else: + from tensorflow.lite.python.metrics_wrapper import converter_error_data_pb2 +# pylint: enable=import-error,ungrouped-imports + logger = logging.getLogger(__name__) diff --git a/src/mlia/utils/download.py b/src/mlia/utils/download.py index 9ef2d9e..c8d0b69 100644 --- a/src/mlia/utils/download.py +++ b/src/mlia/utils/download.py @@ -48,7 +48,7 @@ def download( chunk_size: int = 8192, ) -> None: """Download the file.""" - with requests.get(url, stream=True) as resp: + with requests.get(url, stream=True, timeout=10.0) as resp: resp.raise_for_status() content_chunks = resp.iter_content(chunk_size=chunk_size) -- cgit v1.2.1