diff options
author | Dmitrii Agibov <dmitrii.agibov@arm.com> | 2022-11-18 17:21:09 +0000 |
---|---|---|
committer | Dmitrii Agibov <dmitrii.agibov@arm.com> | 2022-11-29 14:44:13 +0000 |
commit | 6a88ee5315b4ce5b023370c1e55e48bf9f2b6f67 (patch) | |
tree | 88edabf90228724f4fe2944b0ab23859d824a880 /src/mlia/devices/tosa | |
parent | a34163c9d9a5cc0416bcaea2ebf8383bda9d505c (diff) | |
download | mlia-6a88ee5315b4ce5b023370c1e55e48bf9f2b6f67.tar.gz |
Rename modules
- Rename module "mlia.devices" into "mlia.target"
- Rename module "mlia.target.ethosu" into "mlia.target.ethos_u"
- Rename module "mlia.target.cortexa" into "mlia.target.cortex_a"
- Rename and update tests
Change-Id: I6dca7c8646d881f739fb6b5914d1cc7e45e63dc2
Diffstat (limited to 'src/mlia/devices/tosa')
-rw-r--r-- | src/mlia/devices/tosa/__init__.py | 3 | ||||
-rw-r--r-- | src/mlia/devices/tosa/advice_generation.py | 40 | ||||
-rw-r--r-- | src/mlia/devices/tosa/advisor.py | 94 | ||||
-rw-r--r-- | src/mlia/devices/tosa/config.py | 19 | ||||
-rw-r--r-- | src/mlia/devices/tosa/data_analysis.py | 36 | ||||
-rw-r--r-- | src/mlia/devices/tosa/data_collection.py | 30 | ||||
-rw-r--r-- | src/mlia/devices/tosa/events.py | 24 | ||||
-rw-r--r-- | src/mlia/devices/tosa/handlers.py | 36 | ||||
-rw-r--r-- | src/mlia/devices/tosa/operators.py | 11 | ||||
-rw-r--r-- | src/mlia/devices/tosa/reporters.py | 83 |
10 files changed, 0 insertions, 376 deletions
diff --git a/src/mlia/devices/tosa/__init__.py b/src/mlia/devices/tosa/__init__.py deleted file mode 100644 index 762c831..0000000 --- a/src/mlia/devices/tosa/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""TOSA target module.""" diff --git a/src/mlia/devices/tosa/advice_generation.py b/src/mlia/devices/tosa/advice_generation.py deleted file mode 100644 index a3d8011..0000000 --- a/src/mlia/devices/tosa/advice_generation.py +++ /dev/null @@ -1,40 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""TOSA advice generation.""" -from functools import singledispatchmethod - -from mlia.core.advice_generation import advice_category -from mlia.core.advice_generation import FactBasedAdviceProducer -from mlia.core.common import AdviceCategory -from mlia.core.common import DataItem -from mlia.devices.tosa.data_analysis import ModelIsNotTOSACompatible -from mlia.devices.tosa.data_analysis import ModelIsTOSACompatible - - -class TOSAAdviceProducer(FactBasedAdviceProducer): - """TOSA advice producer.""" - - @singledispatchmethod - def produce_advice(self, _data_item: DataItem) -> None: # type: ignore - """Produce advice.""" - - @produce_advice.register - @advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS) - def handle_model_is_tosa_compatible( - self, _data_item: ModelIsTOSACompatible - ) -> None: - """Advice for TOSA compatibility.""" - self.add_advice(["Model is fully TOSA compatible."]) - - @produce_advice.register - @advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS) - def handle_model_is_not_tosa_compatible( - self, _data_item: ModelIsNotTOSACompatible - ) -> None: - """Advice for TOSA compatibility.""" - self.add_advice( - [ - "Some operators in the model are not TOSA compatible. " - "Please, refer to the operators table for more information." - ] - ) diff --git a/src/mlia/devices/tosa/advisor.py b/src/mlia/devices/tosa/advisor.py deleted file mode 100644 index 53dfa87..0000000 --- a/src/mlia/devices/tosa/advisor.py +++ /dev/null @@ -1,94 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""TOSA advisor.""" -from __future__ import annotations - -from pathlib import Path -from typing import Any - -from mlia.core.advice_generation import AdviceCategory -from mlia.core.advice_generation import AdviceProducer -from mlia.core.advisor import DefaultInferenceAdvisor -from mlia.core.advisor import InferenceAdvisor -from mlia.core.context import Context -from mlia.core.context import ExecutionContext -from mlia.core.data_analysis import DataAnalyzer -from mlia.core.data_collection import DataCollector -from mlia.core.events import Event -from mlia.core.typing import PathOrFileLike -from mlia.devices.tosa.advice_generation import TOSAAdviceProducer -from mlia.devices.tosa.config import TOSAConfiguration -from mlia.devices.tosa.data_analysis import TOSADataAnalyzer -from mlia.devices.tosa.data_collection import TOSAOperatorCompatibility -from mlia.devices.tosa.events import TOSAAdvisorStartedEvent -from mlia.devices.tosa.handlers import TOSAEventHandler - - -class TOSAInferenceAdvisor(DefaultInferenceAdvisor): - """TOSA inference advisor.""" - - @classmethod - def name(cls) -> str: - """Return name of the advisor.""" - return "tosa_inference_advisor" - - def get_collectors(self, context: Context) -> list[DataCollector]: - """Return list of the data collectors.""" - model = self.get_model(context) - - collectors: list[DataCollector] = [] - - if AdviceCategory.OPERATORS in context.advice_category: - collectors.append(TOSAOperatorCompatibility(model)) - - return collectors - - def get_analyzers(self, context: Context) -> list[DataAnalyzer]: - """Return list of the data analyzers.""" - return [ - TOSADataAnalyzer(), - ] - - def get_producers(self, context: Context) -> list[AdviceProducer]: - """Return list of the advice producers.""" - return [ - TOSAAdviceProducer(), - ] - - def get_events(self, context: Context) -> list[Event]: - """Return list of the startup events.""" - model = self.get_model(context) - target_profile = self.get_target_profile(context) - - return [ - TOSAAdvisorStartedEvent(model, TOSAConfiguration(target_profile)), - ] - - -def configure_and_get_tosa_advisor( - context: ExecutionContext, - target_profile: str, - model: str | Path, - output: PathOrFileLike | None = None, - **_extra_args: Any, -) -> InferenceAdvisor: - """Create and configure TOSA advisor.""" - if context.event_handlers is None: - context.event_handlers = [TOSAEventHandler(output)] - - if context.config_parameters is None: - context.config_parameters = _get_config_parameters(model, target_profile) - - return TOSAInferenceAdvisor() - - -def _get_config_parameters(model: str | Path, target_profile: str) -> dict[str, Any]: - """Get configuration parameters for the advisor.""" - advisor_parameters: dict[str, Any] = { - "tosa_inference_advisor": { - "model": str(model), - "target_profile": target_profile, - } - } - - return advisor_parameters diff --git a/src/mlia/devices/tosa/config.py b/src/mlia/devices/tosa/config.py deleted file mode 100644 index c3879a7..0000000 --- a/src/mlia/devices/tosa/config.py +++ /dev/null @@ -1,19 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""TOSA target configuration.""" -from mlia.devices.config import IPConfiguration -from mlia.utils.filesystem import get_profile - - -class TOSAConfiguration(IPConfiguration): # pylint: disable=too-few-public-methods - """TOSA configuration.""" - - def __init__(self, target_profile: str) -> None: - """Init configuration.""" - target_data = get_profile(target_profile) - target = target_data["target"] - - if target != "tosa": - raise Exception(f"Wrong target {target} for TOSA configuration") - - super().__init__(target) diff --git a/src/mlia/devices/tosa/data_analysis.py b/src/mlia/devices/tosa/data_analysis.py deleted file mode 100644 index 7cbd61d..0000000 --- a/src/mlia/devices/tosa/data_analysis.py +++ /dev/null @@ -1,36 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""TOSA data analysis module.""" -from dataclasses import dataclass -from functools import singledispatchmethod - -from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo -from mlia.core.common import DataItem -from mlia.core.data_analysis import Fact -from mlia.core.data_analysis import FactExtractor - - -@dataclass -class ModelIsTOSACompatible(Fact): - """Model is completely TOSA compatible.""" - - -@dataclass -class ModelIsNotTOSACompatible(Fact): - """Model is not TOSA compatible.""" - - -class TOSADataAnalyzer(FactExtractor): - """TOSA data analyzer.""" - - @singledispatchmethod - def analyze_data(self, data_item: DataItem) -> None: # type: ignore - """Analyse the data.""" - - @analyze_data.register - def analyze_tosa_compatibility(self, data_item: TOSACompatibilityInfo) -> None: - """Analyse TOSA compatibility information.""" - if data_item.tosa_compatible: - self.add_fact(ModelIsTOSACompatible()) - else: - self.add_fact(ModelIsNotTOSACompatible()) diff --git a/src/mlia/devices/tosa/data_collection.py b/src/mlia/devices/tosa/data_collection.py deleted file mode 100644 index 105c501..0000000 --- a/src/mlia/devices/tosa/data_collection.py +++ /dev/null @@ -1,30 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""TOSA data collection module.""" -from pathlib import Path - -from mlia.backend.tosa_checker.compat import get_tosa_compatibility_info -from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo -from mlia.core.data_collection import ContextAwareDataCollector -from mlia.nn.tensorflow.config import get_tflite_model -from mlia.utils.logging import log_action - - -class TOSAOperatorCompatibility(ContextAwareDataCollector): - """Collect operator compatibility information.""" - - def __init__(self, model: Path) -> None: - """Init the data collector.""" - self.model = model - - def collect_data(self) -> TOSACompatibilityInfo: - """Collect TOSA compatibility information.""" - tflite_model = get_tflite_model(self.model, self.context) - - with log_action("Checking operator compatibility ..."): - return get_tosa_compatibility_info(tflite_model.model_path) - - @classmethod - def name(cls) -> str: - """Return name of the collector.""" - return "tosa_operator_compatibility" diff --git a/src/mlia/devices/tosa/events.py b/src/mlia/devices/tosa/events.py deleted file mode 100644 index ceaba57..0000000 --- a/src/mlia/devices/tosa/events.py +++ /dev/null @@ -1,24 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""TOSA advisor events.""" -from dataclasses import dataclass -from pathlib import Path - -from mlia.core.events import Event -from mlia.core.events import EventDispatcher -from mlia.devices.tosa.config import TOSAConfiguration - - -@dataclass -class TOSAAdvisorStartedEvent(Event): - """Event with TOSA advisor parameters.""" - - model: Path - device: TOSAConfiguration - - -class TOSAAdvisorEventHandler(EventDispatcher): - """Event handler for the TOSA inference advisor.""" - - def on_tosa_advisor_started(self, event: TOSAAdvisorStartedEvent) -> None: - """Handle TOSAAdvisorStartedEvent event.""" diff --git a/src/mlia/devices/tosa/handlers.py b/src/mlia/devices/tosa/handlers.py deleted file mode 100644 index fc82657..0000000 --- a/src/mlia/devices/tosa/handlers.py +++ /dev/null @@ -1,36 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""TOSA Advisor event handlers.""" -# pylint: disable=R0801 -from __future__ import annotations - -import logging - -from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo -from mlia.core.events import CollectedDataEvent -from mlia.core.handlers import WorkflowEventsHandler -from mlia.core.typing import PathOrFileLike -from mlia.devices.tosa.events import TOSAAdvisorEventHandler -from mlia.devices.tosa.events import TOSAAdvisorStartedEvent -from mlia.devices.tosa.reporters import tosa_formatters - -logger = logging.getLogger(__name__) - - -class TOSAEventHandler(WorkflowEventsHandler, TOSAAdvisorEventHandler): - """Event handler for TOSA advisor.""" - - def __init__(self, output: PathOrFileLike | None = None) -> None: - """Init event handler.""" - super().__init__(tosa_formatters, output) - - def on_tosa_advisor_started(self, event: TOSAAdvisorStartedEvent) -> None: - """Handle TOSAAdvisorStartedEvent event.""" - self.reporter.submit(event.device) - - def on_collected_data(self, event: CollectedDataEvent) -> None: - """Handle CollectedDataEvent event.""" - data_item = event.data_item - - if isinstance(data_item, TOSACompatibilityInfo): - self.reporter.submit(data_item.operators, delay_print=True) diff --git a/src/mlia/devices/tosa/operators.py b/src/mlia/devices/tosa/operators.py deleted file mode 100644 index b75ceb0..0000000 --- a/src/mlia/devices/tosa/operators.py +++ /dev/null @@ -1,11 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Operators module.""" - - -def report() -> None: - """Generate supported operators report.""" - raise Exception( - "Generating a supported operators report is not " - "currently supported with TOSA target profile." - ) diff --git a/src/mlia/devices/tosa/reporters.py b/src/mlia/devices/tosa/reporters.py deleted file mode 100644 index e5559ee..0000000 --- a/src/mlia/devices/tosa/reporters.py +++ /dev/null @@ -1,83 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Reports module.""" -from __future__ import annotations - -from typing import Any -from typing import Callable - -from mlia.backend.tosa_checker.compat import Operator -from mlia.core.advice_generation import Advice -from mlia.core.reporters import report_advice -from mlia.core.reporting import Cell -from mlia.core.reporting import Column -from mlia.core.reporting import Format -from mlia.core.reporting import NestedReport -from mlia.core.reporting import Report -from mlia.core.reporting import ReportItem -from mlia.core.reporting import Table -from mlia.devices.tosa.config import TOSAConfiguration -from mlia.utils.console import style_improvement -from mlia.utils.types import is_list_of - - -def report_device(device: TOSAConfiguration) -> Report: - """Generate report for the device.""" - return NestedReport( - "Device information", - "device", - [ - ReportItem("Target", alias="target", value=device.target), - ], - ) - - -def report_tosa_operators(ops: list[Operator]) -> Report: - """Generate report for the operators.""" - return Table( - [ - Column("#", only_for=["plain_text"]), - Column( - "Operator location", - alias="operator_location", - fmt=Format(wrap_width=30), - ), - Column("Operator name", alias="operator_name", fmt=Format(wrap_width=20)), - Column( - "TOSA compatibility", - alias="is_tosa_compatible", - fmt=Format(wrap_width=25), - ), - ], - [ - ( - index + 1, - op.location, - op.name, - Cell( - op.is_tosa_compatible, - Format( - style=style_improvement(op.is_tosa_compatible), - str_fmt=lambda v: "Compatible" if v else "Not compatible", - ), - ), - ) - for index, op in enumerate(ops) - ], - name="Operators", - alias="operators", - ) - - -def tosa_formatters(data: Any) -> Callable[[Any], Report]: - """Find appropriate formatter for the provided data.""" - if is_list_of(data, Advice): - return report_advice - - if isinstance(data, TOSAConfiguration): - return report_device - - if is_list_of(data, Operator): - return report_tosa_operators - - raise Exception(f"Unable to find appropriate formatter for {data}") |