From a34163c9d9a5cc0416bcaea2ebf8383bda9d505c Mon Sep 17 00:00:00 2001 From: Dmitrii Agibov Date: Thu, 24 Nov 2022 08:34:38 +0000 Subject: Move TOSA checker functions into separate module - Create module "compat" for tosa_checker backend - Move TOSA checker functions into new module - Update tests Change-Id: Ia07034515fe43b2061b8892535067d21315cc721 --- src/mlia/backend/tosa_checker/compat.py | 69 ++++++++++++++++++++++++ src/mlia/devices/tosa/data_analysis.py | 2 +- src/mlia/devices/tosa/data_collection.py | 4 +- src/mlia/devices/tosa/handlers.py | 2 +- src/mlia/devices/tosa/operators.py | 66 ----------------------- src/mlia/devices/tosa/reporters.py | 2 +- tests/test_backend_tosa_compat.py | 86 ++++++++++++++++++++++++++++++ tests/test_devices_tosa_data_analysis.py | 2 +- tests/test_devices_tosa_data_collection.py | 2 +- tests/test_devices_tosa_operators.py | 85 ----------------------------- 10 files changed, 162 insertions(+), 158 deletions(-) create mode 100644 src/mlia/backend/tosa_checker/compat.py create mode 100644 tests/test_backend_tosa_compat.py delete mode 100644 tests/test_devices_tosa_operators.py diff --git a/src/mlia/backend/tosa_checker/compat.py b/src/mlia/backend/tosa_checker/compat.py new file mode 100644 index 0000000..e1bcb24 --- /dev/null +++ b/src/mlia/backend/tosa_checker/compat.py @@ -0,0 +1,69 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""TOSA compatibility module.""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any +from typing import cast +from typing import Protocol + +from mlia.core.typing import PathOrFileLike + + +class TOSAChecker(Protocol): + """TOSA checker protocol.""" + + def is_tosa_compatible(self) -> bool: + """Return true if model is TOSA compatible.""" + + def _get_tosa_compatibility_for_ops(self) -> list[Any]: + """Return list of operators.""" + + +@dataclass +class Operator: + """Operator's TOSA compatibility info.""" + + location: str + name: str + is_tosa_compatible: bool + + +@dataclass +class TOSACompatibilityInfo: + """Models' TOSA compatibility information.""" + + tosa_compatible: bool + operators: list[Operator] + + +def get_tosa_compatibility_info( + tflite_model_path: PathOrFileLike, +) -> TOSACompatibilityInfo: + """Return list of the operators.""" + checker = get_tosa_checker(tflite_model_path) + + if checker is None: + raise Exception( + "TOSA checker is not available. " + "Please make sure that 'tosa-checker' backend is installed." + ) + + ops = [ + Operator(item.location, item.name, item.is_tosa_compatible) + for item in checker._get_tosa_compatibility_for_ops() # pylint: disable=protected-access + ] + + return TOSACompatibilityInfo(checker.is_tosa_compatible(), ops) + + +def get_tosa_checker(tflite_model_path: PathOrFileLike) -> TOSAChecker | None: + """Return instance of the TOSA checker.""" + try: + import tosa_checker as tc # pylint: disable=import-outside-toplevel + except ImportError: + return None + + checker = tc.TOSAChecker(str(tflite_model_path)) + return cast(TOSAChecker, checker) diff --git a/src/mlia/devices/tosa/data_analysis.py b/src/mlia/devices/tosa/data_analysis.py index c18ac02..7cbd61d 100644 --- a/src/mlia/devices/tosa/data_analysis.py +++ b/src/mlia/devices/tosa/data_analysis.py @@ -4,10 +4,10 @@ from dataclasses import dataclass from functools import singledispatchmethod +from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo from mlia.core.common import DataItem from mlia.core.data_analysis import Fact from mlia.core.data_analysis import FactExtractor -from mlia.devices.tosa.operators import TOSACompatibilityInfo @dataclass diff --git a/src/mlia/devices/tosa/data_collection.py b/src/mlia/devices/tosa/data_collection.py index 3809903..105c501 100644 --- a/src/mlia/devices/tosa/data_collection.py +++ b/src/mlia/devices/tosa/data_collection.py @@ -3,9 +3,9 @@ """TOSA data collection module.""" from pathlib import Path +from mlia.backend.tosa_checker.compat import get_tosa_compatibility_info +from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo from mlia.core.data_collection import ContextAwareDataCollector -from mlia.devices.tosa.operators import get_tosa_compatibility_info -from mlia.devices.tosa.operators import TOSACompatibilityInfo from mlia.nn.tensorflow.config import get_tflite_model from mlia.utils.logging import log_action diff --git a/src/mlia/devices/tosa/handlers.py b/src/mlia/devices/tosa/handlers.py index 5f015c4..fc82657 100644 --- a/src/mlia/devices/tosa/handlers.py +++ b/src/mlia/devices/tosa/handlers.py @@ -6,12 +6,12 @@ from __future__ import annotations import logging +from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo from mlia.core.events import CollectedDataEvent from mlia.core.handlers import WorkflowEventsHandler from mlia.core.typing import PathOrFileLike from mlia.devices.tosa.events import TOSAAdvisorEventHandler from mlia.devices.tosa.events import TOSAAdvisorStartedEvent -from mlia.devices.tosa.operators import TOSACompatibilityInfo from mlia.devices.tosa.reporters import tosa_formatters logger = logging.getLogger(__name__) diff --git a/src/mlia/devices/tosa/operators.py b/src/mlia/devices/tosa/operators.py index 1e4581a..b75ceb0 100644 --- a/src/mlia/devices/tosa/operators.py +++ b/src/mlia/devices/tosa/operators.py @@ -1,72 +1,6 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Operators module.""" -from __future__ import annotations - -from dataclasses import dataclass -from typing import Any -from typing import cast -from typing import Protocol - -from mlia.core.typing import PathOrFileLike - - -class TOSAChecker(Protocol): - """TOSA checker protocol.""" - - def is_tosa_compatible(self) -> bool: - """Return true if model is TOSA compatible.""" - - def _get_tosa_compatibility_for_ops(self) -> list[Any]: - """Return list of operators.""" - - -@dataclass -class Operator: - """Operator's TOSA compatibility info.""" - - location: str - name: str - is_tosa_compatible: bool - - -@dataclass -class TOSACompatibilityInfo: - """Models' TOSA compatibility information.""" - - tosa_compatible: bool - operators: list[Operator] - - -def get_tosa_compatibility_info( - tflite_model_path: PathOrFileLike, -) -> TOSACompatibilityInfo: - """Return list of the operators.""" - checker = get_tosa_checker(tflite_model_path) - - if checker is None: - raise Exception( - "TOSA checker is not available. " - "Please make sure that 'tosa-checker' backend is installed." - ) - - ops = [ - Operator(item.location, item.name, item.is_tosa_compatible) - for item in checker._get_tosa_compatibility_for_ops() # pylint: disable=protected-access - ] - - return TOSACompatibilityInfo(checker.is_tosa_compatible(), ops) - - -def get_tosa_checker(tflite_model_path: PathOrFileLike) -> TOSAChecker | None: - """Return instance of the TOSA checker.""" - try: - import tosa_checker as tc # pylint: disable=import-outside-toplevel - except ImportError: - return None - - checker = tc.TOSAChecker(str(tflite_model_path)) - return cast(TOSAChecker, checker) def report() -> None: diff --git a/src/mlia/devices/tosa/reporters.py b/src/mlia/devices/tosa/reporters.py index 26c93fd..e5559ee 100644 --- a/src/mlia/devices/tosa/reporters.py +++ b/src/mlia/devices/tosa/reporters.py @@ -6,6 +6,7 @@ from __future__ import annotations from typing import Any from typing import Callable +from mlia.backend.tosa_checker.compat import Operator from mlia.core.advice_generation import Advice from mlia.core.reporters import report_advice from mlia.core.reporting import Cell @@ -16,7 +17,6 @@ from mlia.core.reporting import Report from mlia.core.reporting import ReportItem from mlia.core.reporting import Table from mlia.devices.tosa.config import TOSAConfiguration -from mlia.devices.tosa.operators import Operator from mlia.utils.console import style_improvement from mlia.utils.types import is_list_of diff --git a/tests/test_backend_tosa_compat.py b/tests/test_backend_tosa_compat.py new file mode 100644 index 0000000..4c4dc5a --- /dev/null +++ b/tests/test_backend_tosa_compat.py @@ -0,0 +1,86 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for TOSA compatibility.""" +from __future__ import annotations + +from pathlib import Path +from types import SimpleNamespace +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from mlia.backend.tosa_checker.compat import get_tosa_compatibility_info +from mlia.backend.tosa_checker.compat import Operator +from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo + + +def replace_get_tosa_checker_with_mock( + monkeypatch: pytest.MonkeyPatch, mock: MagicMock | None +) -> None: + """Replace TOSA checker with mock.""" + monkeypatch.setattr( + "mlia.backend.tosa_checker.compat.get_tosa_checker", + MagicMock(return_value=mock), + ) + + +def test_compatibility_check_should_fail_if_checker_not_available( + monkeypatch: pytest.MonkeyPatch, test_tflite_model: Path +) -> None: + """Test that compatibility check should fail if TOSA checker is not available.""" + replace_get_tosa_checker_with_mock(monkeypatch, None) + + with pytest.raises(Exception, match="TOSA checker is not available"): + get_tosa_compatibility_info(test_tflite_model) + + +@pytest.mark.parametrize( + "is_tosa_compatible, operators, expected_result", + [ + [ + True, + [], + TOSACompatibilityInfo(True, []), + ], + [ + True, + [ + SimpleNamespace( + location="op_location", + name="op_name", + is_tosa_compatible=True, + ) + ], + TOSACompatibilityInfo(True, [Operator("op_location", "op_name", True)]), + ], + [ + False, + [ + SimpleNamespace( + location="op_location", + name="op_name", + is_tosa_compatible=False, + ) + ], + TOSACompatibilityInfo(False, [Operator("op_location", "op_name", False)]), + ], + ], +) +def test_get_tosa_compatibility_info( + monkeypatch: pytest.MonkeyPatch, + test_tflite_model: Path, + is_tosa_compatible: bool, + operators: Any, + expected_result: TOSACompatibilityInfo, +) -> None: + """Test getting TOSA compatibility information.""" + mock_checker = MagicMock() + mock_checker.is_tosa_compatible.return_value = is_tosa_compatible + mock_checker._get_tosa_compatibility_for_ops.return_value = ( # pylint: disable=protected-access + operators + ) + + replace_get_tosa_checker_with_mock(monkeypatch, mock_checker) + + assert get_tosa_compatibility_info(test_tflite_model) == expected_result diff --git a/tests/test_devices_tosa_data_analysis.py b/tests/test_devices_tosa_data_analysis.py index ff95978..f2da691 100644 --- a/tests/test_devices_tosa_data_analysis.py +++ b/tests/test_devices_tosa_data_analysis.py @@ -5,12 +5,12 @@ from __future__ import annotations import pytest +from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo from mlia.core.common import DataItem from mlia.core.data_analysis import Fact from mlia.devices.tosa.data_analysis import ModelIsNotTOSACompatible from mlia.devices.tosa.data_analysis import ModelIsTOSACompatible from mlia.devices.tosa.data_analysis import TOSADataAnalyzer -from mlia.devices.tosa.operators import TOSACompatibilityInfo @pytest.mark.parametrize( diff --git a/tests/test_devices_tosa_data_collection.py b/tests/test_devices_tosa_data_collection.py index b9c0b4c..0c1eda1 100644 --- a/tests/test_devices_tosa_data_collection.py +++ b/tests/test_devices_tosa_data_collection.py @@ -6,9 +6,9 @@ from unittest.mock import MagicMock import pytest +from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo from mlia.core.context import ExecutionContext from mlia.devices.tosa.data_collection import TOSAOperatorCompatibility -from mlia.devices.tosa.operators import TOSACompatibilityInfo def test_tosa_data_collection( diff --git a/tests/test_devices_tosa_operators.py b/tests/test_devices_tosa_operators.py deleted file mode 100644 index d4372aa..0000000 --- a/tests/test_devices_tosa_operators.py +++ /dev/null @@ -1,85 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Tests for TOSA compatibility.""" -from __future__ import annotations - -from pathlib import Path -from types import SimpleNamespace -from typing import Any -from unittest.mock import MagicMock - -import pytest - -from mlia.devices.tosa.operators import get_tosa_compatibility_info -from mlia.devices.tosa.operators import Operator -from mlia.devices.tosa.operators import TOSACompatibilityInfo - - -def replace_get_tosa_checker_with_mock( - monkeypatch: pytest.MonkeyPatch, mock: MagicMock | None -) -> None: - """Replace TOSA checker with mock.""" - monkeypatch.setattr( - "mlia.devices.tosa.operators.get_tosa_checker", MagicMock(return_value=mock) - ) - - -def test_compatibility_check_should_fail_if_checker_not_available( - monkeypatch: pytest.MonkeyPatch, test_tflite_model: Path -) -> None: - """Test that compatibility check should fail if TOSA checker is not available.""" - replace_get_tosa_checker_with_mock(monkeypatch, None) - - with pytest.raises(Exception, match="TOSA checker is not available"): - get_tosa_compatibility_info(test_tflite_model) - - -@pytest.mark.parametrize( - "is_tosa_compatible, operators, expected_result", - [ - [ - True, - [], - TOSACompatibilityInfo(True, []), - ], - [ - True, - [ - SimpleNamespace( - location="op_location", - name="op_name", - is_tosa_compatible=True, - ) - ], - TOSACompatibilityInfo(True, [Operator("op_location", "op_name", True)]), - ], - [ - False, - [ - SimpleNamespace( - location="op_location", - name="op_name", - is_tosa_compatible=False, - ) - ], - TOSACompatibilityInfo(False, [Operator("op_location", "op_name", False)]), - ], - ], -) -def test_get_tosa_compatibility_info( - monkeypatch: pytest.MonkeyPatch, - test_tflite_model: Path, - is_tosa_compatible: bool, - operators: Any, - expected_result: TOSACompatibilityInfo, -) -> None: - """Test getting TOSA compatibility information.""" - mock_checker = MagicMock() - mock_checker.is_tosa_compatible.return_value = is_tosa_compatible - mock_checker._get_tosa_compatibility_for_ops.return_value = ( # pylint: disable=protected-access - operators - ) - - replace_get_tosa_checker_with_mock(monkeypatch, mock_checker) - - assert get_tosa_compatibility_info(test_tflite_model) == expected_result -- cgit v1.2.1