From a34163c9d9a5cc0416bcaea2ebf8383bda9d505c Mon Sep 17 00:00:00 2001 From: Dmitrii Agibov Date: Thu, 24 Nov 2022 08:34:38 +0000 Subject: Move TOSA checker functions into separate module - Create module "compat" for tosa_checker backend - Move TOSA checker functions into new module - Update tests Change-Id: Ia07034515fe43b2061b8892535067d21315cc721 --- src/mlia/devices/tosa/data_analysis.py | 2 +- src/mlia/devices/tosa/data_collection.py | 4 +- src/mlia/devices/tosa/handlers.py | 2 +- src/mlia/devices/tosa/operators.py | 66 -------------------------------- src/mlia/devices/tosa/reporters.py | 2 +- 5 files changed, 5 insertions(+), 71 deletions(-) (limited to 'src/mlia/devices') diff --git a/src/mlia/devices/tosa/data_analysis.py b/src/mlia/devices/tosa/data_analysis.py index c18ac02..7cbd61d 100644 --- a/src/mlia/devices/tosa/data_analysis.py +++ b/src/mlia/devices/tosa/data_analysis.py @@ -4,10 +4,10 @@ from dataclasses import dataclass from functools import singledispatchmethod +from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo from mlia.core.common import DataItem from mlia.core.data_analysis import Fact from mlia.core.data_analysis import FactExtractor -from mlia.devices.tosa.operators import TOSACompatibilityInfo @dataclass diff --git a/src/mlia/devices/tosa/data_collection.py b/src/mlia/devices/tosa/data_collection.py index 3809903..105c501 100644 --- a/src/mlia/devices/tosa/data_collection.py +++ b/src/mlia/devices/tosa/data_collection.py @@ -3,9 +3,9 @@ """TOSA data collection module.""" from pathlib import Path +from mlia.backend.tosa_checker.compat import get_tosa_compatibility_info +from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo from mlia.core.data_collection import ContextAwareDataCollector -from mlia.devices.tosa.operators import get_tosa_compatibility_info -from mlia.devices.tosa.operators import TOSACompatibilityInfo from mlia.nn.tensorflow.config import get_tflite_model from mlia.utils.logging import log_action diff --git a/src/mlia/devices/tosa/handlers.py b/src/mlia/devices/tosa/handlers.py index 5f015c4..fc82657 100644 --- a/src/mlia/devices/tosa/handlers.py +++ b/src/mlia/devices/tosa/handlers.py @@ -6,12 +6,12 @@ from __future__ import annotations import logging +from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo from mlia.core.events import CollectedDataEvent from mlia.core.handlers import WorkflowEventsHandler from mlia.core.typing import PathOrFileLike from mlia.devices.tosa.events import TOSAAdvisorEventHandler from mlia.devices.tosa.events import TOSAAdvisorStartedEvent -from mlia.devices.tosa.operators import TOSACompatibilityInfo from mlia.devices.tosa.reporters import tosa_formatters logger = logging.getLogger(__name__) diff --git a/src/mlia/devices/tosa/operators.py b/src/mlia/devices/tosa/operators.py index 1e4581a..b75ceb0 100644 --- a/src/mlia/devices/tosa/operators.py +++ b/src/mlia/devices/tosa/operators.py @@ -1,72 +1,6 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Operators module.""" -from __future__ import annotations - -from dataclasses import dataclass -from typing import Any -from typing import cast -from typing import Protocol - -from mlia.core.typing import PathOrFileLike - - -class TOSAChecker(Protocol): - """TOSA checker protocol.""" - - def is_tosa_compatible(self) -> bool: - """Return true if model is TOSA compatible.""" - - def _get_tosa_compatibility_for_ops(self) -> list[Any]: - """Return list of operators.""" - - -@dataclass -class Operator: - """Operator's TOSA compatibility info.""" - - location: str - name: str - is_tosa_compatible: bool - - -@dataclass -class TOSACompatibilityInfo: - """Models' TOSA compatibility information.""" - - tosa_compatible: bool - operators: list[Operator] - - -def get_tosa_compatibility_info( - tflite_model_path: PathOrFileLike, -) -> TOSACompatibilityInfo: - """Return list of the operators.""" - checker = get_tosa_checker(tflite_model_path) - - if checker is None: - raise Exception( - "TOSA checker is not available. " - "Please make sure that 'tosa-checker' backend is installed." - ) - - ops = [ - Operator(item.location, item.name, item.is_tosa_compatible) - for item in checker._get_tosa_compatibility_for_ops() # pylint: disable=protected-access - ] - - return TOSACompatibilityInfo(checker.is_tosa_compatible(), ops) - - -def get_tosa_checker(tflite_model_path: PathOrFileLike) -> TOSAChecker | None: - """Return instance of the TOSA checker.""" - try: - import tosa_checker as tc # pylint: disable=import-outside-toplevel - except ImportError: - return None - - checker = tc.TOSAChecker(str(tflite_model_path)) - return cast(TOSAChecker, checker) def report() -> None: diff --git a/src/mlia/devices/tosa/reporters.py b/src/mlia/devices/tosa/reporters.py index 26c93fd..e5559ee 100644 --- a/src/mlia/devices/tosa/reporters.py +++ b/src/mlia/devices/tosa/reporters.py @@ -6,6 +6,7 @@ from __future__ import annotations from typing import Any from typing import Callable +from mlia.backend.tosa_checker.compat import Operator from mlia.core.advice_generation import Advice from mlia.core.reporters import report_advice from mlia.core.reporting import Cell @@ -16,7 +17,6 @@ from mlia.core.reporting import Report from mlia.core.reporting import ReportItem from mlia.core.reporting import Table from mlia.devices.tosa.config import TOSAConfiguration -from mlia.devices.tosa.operators import Operator from mlia.utils.console import style_improvement from mlia.utils.types import is_list_of -- cgit v1.2.1