aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDmitrii Agibov <dmitrii.agibov@arm.com>2022-11-24 08:34:38 +0000
committerDmitrii Agibov <dmitrii.agibov@arm.com>2022-11-29 14:44:13 +0000
commita34163c9d9a5cc0416bcaea2ebf8383bda9d505c (patch)
tree304c01c607b3a93c250a38df53c417f62196b5fa
parent37959522a805a5e23c930ed79aac84920c3cb208 (diff)
downloadmlia-a34163c9d9a5cc0416bcaea2ebf8383bda9d505c.tar.gz
Move TOSA checker functions into separate module
- Create module "compat" for tosa_checker backend - Move TOSA checker functions into new module - Update tests Change-Id: Ia07034515fe43b2061b8892535067d21315cc721
-rw-r--r--src/mlia/backend/tosa_checker/compat.py69
-rw-r--r--src/mlia/devices/tosa/data_analysis.py2
-rw-r--r--src/mlia/devices/tosa/data_collection.py4
-rw-r--r--src/mlia/devices/tosa/handlers.py2
-rw-r--r--src/mlia/devices/tosa/operators.py66
-rw-r--r--src/mlia/devices/tosa/reporters.py2
-rw-r--r--tests/test_backend_tosa_compat.py (renamed from tests/test_devices_tosa_operators.py)9
-rw-r--r--tests/test_devices_tosa_data_analysis.py2
-rw-r--r--tests/test_devices_tosa_data_collection.py2
9 files changed, 81 insertions, 77 deletions
diff --git a/src/mlia/backend/tosa_checker/compat.py b/src/mlia/backend/tosa_checker/compat.py
new file mode 100644
index 0000000..e1bcb24
--- /dev/null
+++ b/src/mlia/backend/tosa_checker/compat.py
@@ -0,0 +1,69 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""TOSA compatibility module."""
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any
+from typing import cast
+from typing import Protocol
+
+from mlia.core.typing import PathOrFileLike
+
+
+class TOSAChecker(Protocol):
+ """TOSA checker protocol."""
+
+ def is_tosa_compatible(self) -> bool:
+ """Return true if model is TOSA compatible."""
+
+ def _get_tosa_compatibility_for_ops(self) -> list[Any]:
+ """Return list of operators."""
+
+
+@dataclass
+class Operator:
+ """Operator's TOSA compatibility info."""
+
+ location: str
+ name: str
+ is_tosa_compatible: bool
+
+
+@dataclass
+class TOSACompatibilityInfo:
+ """Models' TOSA compatibility information."""
+
+ tosa_compatible: bool
+ operators: list[Operator]
+
+
+def get_tosa_compatibility_info(
+ tflite_model_path: PathOrFileLike,
+) -> TOSACompatibilityInfo:
+ """Return list of the operators."""
+ checker = get_tosa_checker(tflite_model_path)
+
+ if checker is None:
+ raise Exception(
+ "TOSA checker is not available. "
+ "Please make sure that 'tosa-checker' backend is installed."
+ )
+
+ ops = [
+ Operator(item.location, item.name, item.is_tosa_compatible)
+ for item in checker._get_tosa_compatibility_for_ops() # pylint: disable=protected-access
+ ]
+
+ return TOSACompatibilityInfo(checker.is_tosa_compatible(), ops)
+
+
+def get_tosa_checker(tflite_model_path: PathOrFileLike) -> TOSAChecker | None:
+ """Return instance of the TOSA checker."""
+ try:
+ import tosa_checker as tc # pylint: disable=import-outside-toplevel
+ except ImportError:
+ return None
+
+ checker = tc.TOSAChecker(str(tflite_model_path))
+ return cast(TOSAChecker, checker)
diff --git a/src/mlia/devices/tosa/data_analysis.py b/src/mlia/devices/tosa/data_analysis.py
index c18ac02..7cbd61d 100644
--- a/src/mlia/devices/tosa/data_analysis.py
+++ b/src/mlia/devices/tosa/data_analysis.py
@@ -4,10 +4,10 @@
from dataclasses import dataclass
from functools import singledispatchmethod
+from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo
from mlia.core.common import DataItem
from mlia.core.data_analysis import Fact
from mlia.core.data_analysis import FactExtractor
-from mlia.devices.tosa.operators import TOSACompatibilityInfo
@dataclass
diff --git a/src/mlia/devices/tosa/data_collection.py b/src/mlia/devices/tosa/data_collection.py
index 3809903..105c501 100644
--- a/src/mlia/devices/tosa/data_collection.py
+++ b/src/mlia/devices/tosa/data_collection.py
@@ -3,9 +3,9 @@
"""TOSA data collection module."""
from pathlib import Path
+from mlia.backend.tosa_checker.compat import get_tosa_compatibility_info
+from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo
from mlia.core.data_collection import ContextAwareDataCollector
-from mlia.devices.tosa.operators import get_tosa_compatibility_info
-from mlia.devices.tosa.operators import TOSACompatibilityInfo
from mlia.nn.tensorflow.config import get_tflite_model
from mlia.utils.logging import log_action
diff --git a/src/mlia/devices/tosa/handlers.py b/src/mlia/devices/tosa/handlers.py
index 5f015c4..fc82657 100644
--- a/src/mlia/devices/tosa/handlers.py
+++ b/src/mlia/devices/tosa/handlers.py
@@ -6,12 +6,12 @@ from __future__ import annotations
import logging
+from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo
from mlia.core.events import CollectedDataEvent
from mlia.core.handlers import WorkflowEventsHandler
from mlia.core.typing import PathOrFileLike
from mlia.devices.tosa.events import TOSAAdvisorEventHandler
from mlia.devices.tosa.events import TOSAAdvisorStartedEvent
-from mlia.devices.tosa.operators import TOSACompatibilityInfo
from mlia.devices.tosa.reporters import tosa_formatters
logger = logging.getLogger(__name__)
diff --git a/src/mlia/devices/tosa/operators.py b/src/mlia/devices/tosa/operators.py
index 1e4581a..b75ceb0 100644
--- a/src/mlia/devices/tosa/operators.py
+++ b/src/mlia/devices/tosa/operators.py
@@ -1,72 +1,6 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Operators module."""
-from __future__ import annotations
-
-from dataclasses import dataclass
-from typing import Any
-from typing import cast
-from typing import Protocol
-
-from mlia.core.typing import PathOrFileLike
-
-
-class TOSAChecker(Protocol):
- """TOSA checker protocol."""
-
- def is_tosa_compatible(self) -> bool:
- """Return true if model is TOSA compatible."""
-
- def _get_tosa_compatibility_for_ops(self) -> list[Any]:
- """Return list of operators."""
-
-
-@dataclass
-class Operator:
- """Operator's TOSA compatibility info."""
-
- location: str
- name: str
- is_tosa_compatible: bool
-
-
-@dataclass
-class TOSACompatibilityInfo:
- """Models' TOSA compatibility information."""
-
- tosa_compatible: bool
- operators: list[Operator]
-
-
-def get_tosa_compatibility_info(
- tflite_model_path: PathOrFileLike,
-) -> TOSACompatibilityInfo:
- """Return list of the operators."""
- checker = get_tosa_checker(tflite_model_path)
-
- if checker is None:
- raise Exception(
- "TOSA checker is not available. "
- "Please make sure that 'tosa-checker' backend is installed."
- )
-
- ops = [
- Operator(item.location, item.name, item.is_tosa_compatible)
- for item in checker._get_tosa_compatibility_for_ops() # pylint: disable=protected-access
- ]
-
- return TOSACompatibilityInfo(checker.is_tosa_compatible(), ops)
-
-
-def get_tosa_checker(tflite_model_path: PathOrFileLike) -> TOSAChecker | None:
- """Return instance of the TOSA checker."""
- try:
- import tosa_checker as tc # pylint: disable=import-outside-toplevel
- except ImportError:
- return None
-
- checker = tc.TOSAChecker(str(tflite_model_path))
- return cast(TOSAChecker, checker)
def report() -> None:
diff --git a/src/mlia/devices/tosa/reporters.py b/src/mlia/devices/tosa/reporters.py
index 26c93fd..e5559ee 100644
--- a/src/mlia/devices/tosa/reporters.py
+++ b/src/mlia/devices/tosa/reporters.py
@@ -6,6 +6,7 @@ from __future__ import annotations
from typing import Any
from typing import Callable
+from mlia.backend.tosa_checker.compat import Operator
from mlia.core.advice_generation import Advice
from mlia.core.reporters import report_advice
from mlia.core.reporting import Cell
@@ -16,7 +17,6 @@ from mlia.core.reporting import Report
from mlia.core.reporting import ReportItem
from mlia.core.reporting import Table
from mlia.devices.tosa.config import TOSAConfiguration
-from mlia.devices.tosa.operators import Operator
from mlia.utils.console import style_improvement
from mlia.utils.types import is_list_of
diff --git a/tests/test_devices_tosa_operators.py b/tests/test_backend_tosa_compat.py
index d4372aa..4c4dc5a 100644
--- a/tests/test_devices_tosa_operators.py
+++ b/tests/test_backend_tosa_compat.py
@@ -10,9 +10,9 @@ from unittest.mock import MagicMock
import pytest
-from mlia.devices.tosa.operators import get_tosa_compatibility_info
-from mlia.devices.tosa.operators import Operator
-from mlia.devices.tosa.operators import TOSACompatibilityInfo
+from mlia.backend.tosa_checker.compat import get_tosa_compatibility_info
+from mlia.backend.tosa_checker.compat import Operator
+from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo
def replace_get_tosa_checker_with_mock(
@@ -20,7 +20,8 @@ def replace_get_tosa_checker_with_mock(
) -> None:
"""Replace TOSA checker with mock."""
monkeypatch.setattr(
- "mlia.devices.tosa.operators.get_tosa_checker", MagicMock(return_value=mock)
+ "mlia.backend.tosa_checker.compat.get_tosa_checker",
+ MagicMock(return_value=mock),
)
diff --git a/tests/test_devices_tosa_data_analysis.py b/tests/test_devices_tosa_data_analysis.py
index ff95978..f2da691 100644
--- a/tests/test_devices_tosa_data_analysis.py
+++ b/tests/test_devices_tosa_data_analysis.py
@@ -5,12 +5,12 @@ from __future__ import annotations
import pytest
+from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo
from mlia.core.common import DataItem
from mlia.core.data_analysis import Fact
from mlia.devices.tosa.data_analysis import ModelIsNotTOSACompatible
from mlia.devices.tosa.data_analysis import ModelIsTOSACompatible
from mlia.devices.tosa.data_analysis import TOSADataAnalyzer
-from mlia.devices.tosa.operators import TOSACompatibilityInfo
@pytest.mark.parametrize(
diff --git a/tests/test_devices_tosa_data_collection.py b/tests/test_devices_tosa_data_collection.py
index b9c0b4c..0c1eda1 100644
--- a/tests/test_devices_tosa_data_collection.py
+++ b/tests/test_devices_tosa_data_collection.py
@@ -6,9 +6,9 @@ from unittest.mock import MagicMock
import pytest
+from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo
from mlia.core.context import ExecutionContext
from mlia.devices.tosa.data_collection import TOSAOperatorCompatibility
-from mlia.devices.tosa.operators import TOSACompatibilityInfo
def test_tosa_data_collection(