diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/mlia/backend/tosa_checker/compat.py | 69 | ||||
-rw-r--r-- | src/mlia/devices/tosa/data_analysis.py | 2 | ||||
-rw-r--r-- | src/mlia/devices/tosa/data_collection.py | 4 | ||||
-rw-r--r-- | src/mlia/devices/tosa/handlers.py | 2 | ||||
-rw-r--r-- | src/mlia/devices/tosa/operators.py | 66 | ||||
-rw-r--r-- | src/mlia/devices/tosa/reporters.py | 2 |
6 files changed, 74 insertions, 71 deletions
diff --git a/src/mlia/backend/tosa_checker/compat.py b/src/mlia/backend/tosa_checker/compat.py new file mode 100644 index 0000000..e1bcb24 --- /dev/null +++ b/src/mlia/backend/tosa_checker/compat.py @@ -0,0 +1,69 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""TOSA compatibility module.""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any +from typing import cast +from typing import Protocol + +from mlia.core.typing import PathOrFileLike + + +class TOSAChecker(Protocol): + """TOSA checker protocol.""" + + def is_tosa_compatible(self) -> bool: + """Return true if model is TOSA compatible.""" + + def _get_tosa_compatibility_for_ops(self) -> list[Any]: + """Return list of operators.""" + + +@dataclass +class Operator: + """Operator's TOSA compatibility info.""" + + location: str + name: str + is_tosa_compatible: bool + + +@dataclass +class TOSACompatibilityInfo: + """Models' TOSA compatibility information.""" + + tosa_compatible: bool + operators: list[Operator] + + +def get_tosa_compatibility_info( + tflite_model_path: PathOrFileLike, +) -> TOSACompatibilityInfo: + """Return list of the operators.""" + checker = get_tosa_checker(tflite_model_path) + + if checker is None: + raise Exception( + "TOSA checker is not available. " + "Please make sure that 'tosa-checker' backend is installed." + ) + + ops = [ + Operator(item.location, item.name, item.is_tosa_compatible) + for item in checker._get_tosa_compatibility_for_ops() # pylint: disable=protected-access + ] + + return TOSACompatibilityInfo(checker.is_tosa_compatible(), ops) + + +def get_tosa_checker(tflite_model_path: PathOrFileLike) -> TOSAChecker | None: + """Return instance of the TOSA checker.""" + try: + import tosa_checker as tc # pylint: disable=import-outside-toplevel + except ImportError: + return None + + checker = tc.TOSAChecker(str(tflite_model_path)) + return cast(TOSAChecker, checker) diff --git a/src/mlia/devices/tosa/data_analysis.py b/src/mlia/devices/tosa/data_analysis.py index c18ac02..7cbd61d 100644 --- a/src/mlia/devices/tosa/data_analysis.py +++ b/src/mlia/devices/tosa/data_analysis.py @@ -4,10 +4,10 @@ from dataclasses import dataclass from functools import singledispatchmethod +from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo from mlia.core.common import DataItem from mlia.core.data_analysis import Fact from mlia.core.data_analysis import FactExtractor -from mlia.devices.tosa.operators import TOSACompatibilityInfo @dataclass diff --git a/src/mlia/devices/tosa/data_collection.py b/src/mlia/devices/tosa/data_collection.py index 3809903..105c501 100644 --- a/src/mlia/devices/tosa/data_collection.py +++ b/src/mlia/devices/tosa/data_collection.py @@ -3,9 +3,9 @@ """TOSA data collection module.""" from pathlib import Path +from mlia.backend.tosa_checker.compat import get_tosa_compatibility_info +from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo from mlia.core.data_collection import ContextAwareDataCollector -from mlia.devices.tosa.operators import get_tosa_compatibility_info -from mlia.devices.tosa.operators import TOSACompatibilityInfo from mlia.nn.tensorflow.config import get_tflite_model from mlia.utils.logging import log_action diff --git a/src/mlia/devices/tosa/handlers.py b/src/mlia/devices/tosa/handlers.py index 5f015c4..fc82657 100644 --- a/src/mlia/devices/tosa/handlers.py +++ b/src/mlia/devices/tosa/handlers.py @@ -6,12 +6,12 @@ from __future__ import annotations import logging +from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo from mlia.core.events import CollectedDataEvent from mlia.core.handlers import WorkflowEventsHandler from mlia.core.typing import PathOrFileLike from mlia.devices.tosa.events import TOSAAdvisorEventHandler from mlia.devices.tosa.events import TOSAAdvisorStartedEvent -from mlia.devices.tosa.operators import TOSACompatibilityInfo from mlia.devices.tosa.reporters import tosa_formatters logger = logging.getLogger(__name__) diff --git a/src/mlia/devices/tosa/operators.py b/src/mlia/devices/tosa/operators.py index 1e4581a..b75ceb0 100644 --- a/src/mlia/devices/tosa/operators.py +++ b/src/mlia/devices/tosa/operators.py @@ -1,72 +1,6 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Operators module.""" -from __future__ import annotations - -from dataclasses import dataclass -from typing import Any -from typing import cast -from typing import Protocol - -from mlia.core.typing import PathOrFileLike - - -class TOSAChecker(Protocol): - """TOSA checker protocol.""" - - def is_tosa_compatible(self) -> bool: - """Return true if model is TOSA compatible.""" - - def _get_tosa_compatibility_for_ops(self) -> list[Any]: - """Return list of operators.""" - - -@dataclass -class Operator: - """Operator's TOSA compatibility info.""" - - location: str - name: str - is_tosa_compatible: bool - - -@dataclass -class TOSACompatibilityInfo: - """Models' TOSA compatibility information.""" - - tosa_compatible: bool - operators: list[Operator] - - -def get_tosa_compatibility_info( - tflite_model_path: PathOrFileLike, -) -> TOSACompatibilityInfo: - """Return list of the operators.""" - checker = get_tosa_checker(tflite_model_path) - - if checker is None: - raise Exception( - "TOSA checker is not available. " - "Please make sure that 'tosa-checker' backend is installed." - ) - - ops = [ - Operator(item.location, item.name, item.is_tosa_compatible) - for item in checker._get_tosa_compatibility_for_ops() # pylint: disable=protected-access - ] - - return TOSACompatibilityInfo(checker.is_tosa_compatible(), ops) - - -def get_tosa_checker(tflite_model_path: PathOrFileLike) -> TOSAChecker | None: - """Return instance of the TOSA checker.""" - try: - import tosa_checker as tc # pylint: disable=import-outside-toplevel - except ImportError: - return None - - checker = tc.TOSAChecker(str(tflite_model_path)) - return cast(TOSAChecker, checker) def report() -> None: diff --git a/src/mlia/devices/tosa/reporters.py b/src/mlia/devices/tosa/reporters.py index 26c93fd..e5559ee 100644 --- a/src/mlia/devices/tosa/reporters.py +++ b/src/mlia/devices/tosa/reporters.py @@ -6,6 +6,7 @@ from __future__ import annotations from typing import Any from typing import Callable +from mlia.backend.tosa_checker.compat import Operator from mlia.core.advice_generation import Advice from mlia.core.reporters import report_advice from mlia.core.reporting import Cell @@ -16,7 +17,6 @@ from mlia.core.reporting import Report from mlia.core.reporting import ReportItem from mlia.core.reporting import Table from mlia.devices.tosa.config import TOSAConfiguration -from mlia.devices.tosa.operators import Operator from mlia.utils.console import style_improvement from mlia.utils.types import is_list_of |