aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorBenjamin Klimczak <benjamin.klimczak@arm.com>2022-11-07 12:57:15 +0000
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2022-11-11 12:10:26 +0000
commitce9b17650d024886b24ad820f0f1815fc23b19f3 (patch)
treea7d113f751b8856aabcd021464edec16e23ba6f8 /src
parente40a7adadd254e29d71af38f69a0a20ff4871eef (diff)
downloadmlia-ce9b17650d024886b24ad820f0f1815fc23b19f3.tar.gz
MLIA-701 Update dependencies
- Update TensorFlow dependencies for x86_64 - Adapt unit tests to new TensorFlow version - Update linters (including pre-commit hooks) and fix issues - Use conditional import to fix tflite compat code for aarch64 Change-Id: I1a9b080b900ab65e38f7f2552562822bbfdcd259
Diffstat (limited to 'src')
-rw-r--r--src/mlia/backend/common.py2
-rw-r--r--src/mlia/core/common.py8
-rw-r--r--src/mlia/core/helpers.py2
-rw-r--r--src/mlia/devices/cortexa/advice_generation.py2
-rw-r--r--src/mlia/devices/cortexa/data_analysis.py2
-rw-r--r--src/mlia/devices/ethosu/advice_generation.py2
-rw-r--r--src/mlia/devices/ethosu/data_analysis.py2
-rw-r--r--src/mlia/devices/tosa/advice_generation.py2
-rw-r--r--src/mlia/devices/tosa/data_analysis.py2
-rw-r--r--src/mlia/nn/tensorflow/tflite_compat.py10
-rw-r--r--src/mlia/utils/download.py2
11 files changed, 25 insertions, 11 deletions
diff --git a/src/mlia/backend/common.py b/src/mlia/backend/common.py
index 697c2a0..0f04553 100644
--- a/src/mlia/backend/common.py
+++ b/src/mlia/backend/common.py
@@ -205,7 +205,7 @@ class Backend(ABC):
return self.variables[var_name]
- return var_pattern.sub(var_value, str_val) # type: ignore
+ return var_pattern.sub(var_value, str_val)
@classmethod
def _parse_params(
diff --git a/src/mlia/core/common.py b/src/mlia/core/common.py
index 63fb324..6c9dde1 100644
--- a/src/mlia/core/common.py
+++ b/src/mlia/core/common.py
@@ -29,7 +29,13 @@ class AdviceCategory(Flag):
OPERATORS = auto()
PERFORMANCE = auto()
OPTIMIZATION = auto()
- ALL = OPERATORS | PERFORMANCE | OPTIMIZATION
+ ALL = (
+ # pylint: disable=unsupported-binary-operation
+ OPERATORS
+ | PERFORMANCE
+ | OPTIMIZATION
+ # pylint: enable=unsupported-binary-operation
+ )
@classmethod
def from_string(cls, value: str) -> AdviceCategory:
diff --git a/src/mlia/core/helpers.py b/src/mlia/core/helpers.py
index f0c4474..f4a9df6 100644
--- a/src/mlia/core/helpers.py
+++ b/src/mlia/core/helpers.py
@@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for various helper classes."""
-# pylint: disable=no-self-use, unused-argument
+# pylint: disable=unused-argument
from __future__ import annotations
from typing import Any
diff --git a/src/mlia/devices/cortexa/advice_generation.py b/src/mlia/devices/cortexa/advice_generation.py
index 34c51f8..186f489 100644
--- a/src/mlia/devices/cortexa/advice_generation.py
+++ b/src/mlia/devices/cortexa/advice_generation.py
@@ -23,7 +23,7 @@ class CortexAAdviceProducer(FactBasedAdviceProducer):
)
@singledispatchmethod
- def produce_advice(self, _data_item: DataItem) -> None:
+ def produce_advice(self, _data_item: DataItem) -> None: # type: ignore
"""Produce advice."""
@produce_advice.register
diff --git a/src/mlia/devices/cortexa/data_analysis.py b/src/mlia/devices/cortexa/data_analysis.py
index 9f6d82b..6a82dd0 100644
--- a/src/mlia/devices/cortexa/data_analysis.py
+++ b/src/mlia/devices/cortexa/data_analysis.py
@@ -21,7 +21,7 @@ class CortexADataAnalyzer(FactExtractor):
"""Cortex-A data analyzer."""
@singledispatchmethod
- def analyze_data(self, data_item: DataItem) -> None:
+ def analyze_data(self, data_item: DataItem) -> None: # type: ignore
"""Analyse the data."""
@analyze_data.register
diff --git a/src/mlia/devices/ethosu/advice_generation.py b/src/mlia/devices/ethosu/advice_generation.py
index 8a38d2c..1910460 100644
--- a/src/mlia/devices/ethosu/advice_generation.py
+++ b/src/mlia/devices/ethosu/advice_generation.py
@@ -22,7 +22,7 @@ class EthosUAdviceProducer(FactBasedAdviceProducer):
"""Ethos-U advice producer."""
@singledispatchmethod
- def produce_advice(self, data_item: DataItem) -> None:
+ def produce_advice(self, data_item: DataItem) -> None: # type: ignore
"""Produce advice."""
@produce_advice.register
diff --git a/src/mlia/devices/ethosu/data_analysis.py b/src/mlia/devices/ethosu/data_analysis.py
index 8d88cf7..70b6f65 100644
--- a/src/mlia/devices/ethosu/data_analysis.py
+++ b/src/mlia/devices/ethosu/data_analysis.py
@@ -83,7 +83,7 @@ class EthosUDataAnalyzer(FactExtractor):
"""Ethos-U data analyzer."""
@singledispatchmethod
- def analyze_data(self, data_item: DataItem) -> None:
+ def analyze_data(self, data_item: DataItem) -> None: # type: ignore
"""Analyse the data."""
@analyze_data.register
diff --git a/src/mlia/devices/tosa/advice_generation.py b/src/mlia/devices/tosa/advice_generation.py
index 7adfcb9..a3d8011 100644
--- a/src/mlia/devices/tosa/advice_generation.py
+++ b/src/mlia/devices/tosa/advice_generation.py
@@ -15,7 +15,7 @@ class TOSAAdviceProducer(FactBasedAdviceProducer):
"""TOSA advice producer."""
@singledispatchmethod
- def produce_advice(self, _data_item: DataItem) -> None:
+ def produce_advice(self, _data_item: DataItem) -> None: # type: ignore
"""Produce advice."""
@produce_advice.register
diff --git a/src/mlia/devices/tosa/data_analysis.py b/src/mlia/devices/tosa/data_analysis.py
index aa696a5..c18ac02 100644
--- a/src/mlia/devices/tosa/data_analysis.py
+++ b/src/mlia/devices/tosa/data_analysis.py
@@ -24,7 +24,7 @@ class TOSADataAnalyzer(FactExtractor):
"""TOSA data analyzer."""
@singledispatchmethod
- def analyze_data(self, data_item: DataItem) -> None:
+ def analyze_data(self, data_item: DataItem) -> None: # type: ignore
"""Analyse the data."""
@analyze_data.register
diff --git a/src/mlia/nn/tensorflow/tflite_compat.py b/src/mlia/nn/tensorflow/tflite_compat.py
index 960a5c3..6f183ca 100644
--- a/src/mlia/nn/tensorflow/tflite_compat.py
+++ b/src/mlia/nn/tensorflow/tflite_compat.py
@@ -11,12 +11,20 @@ from typing import Any
from typing import cast
from typing import List
+import tensorflow as tf
from tensorflow.lite.python import convert
-from tensorflow.lite.python.metrics import converter_error_data_pb2
from mlia.nn.tensorflow.utils import get_tflite_converter
from mlia.utils.logging import redirect_raw_output
+TF_VERSION_MAJOR, TF_VERSION_MINOR, _ = (int(s) for s in tf.version.VERSION.split("."))
+# pylint: disable=import-error,ungrouped-imports
+if (TF_VERSION_MAJOR == 2 and TF_VERSION_MINOR > 7) or TF_VERSION_MAJOR > 2:
+ from tensorflow.lite.python.metrics import converter_error_data_pb2
+else:
+ from tensorflow.lite.python.metrics_wrapper import converter_error_data_pb2
+# pylint: enable=import-error,ungrouped-imports
+
logger = logging.getLogger(__name__)
diff --git a/src/mlia/utils/download.py b/src/mlia/utils/download.py
index 9ef2d9e..c8d0b69 100644
--- a/src/mlia/utils/download.py
+++ b/src/mlia/utils/download.py
@@ -48,7 +48,7 @@ def download(
chunk_size: int = 8192,
) -> None:
"""Download the file."""
- with requests.get(url, stream=True) as resp:
+ with requests.get(url, stream=True, timeout=10.0) as resp:
resp.raise_for_status()
content_chunks = resp.iter_content(chunk_size=chunk_size)