aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/devices/tosa
diff options
context:
space:
mode:
authorDmitrii Agibov <dmitrii.agibov@arm.com>2022-11-18 17:21:09 +0000
committerDmitrii Agibov <dmitrii.agibov@arm.com>2022-11-29 14:44:13 +0000
commit6a88ee5315b4ce5b023370c1e55e48bf9f2b6f67 (patch)
tree88edabf90228724f4fe2944b0ab23859d824a880 /src/mlia/devices/tosa
parenta34163c9d9a5cc0416bcaea2ebf8383bda9d505c (diff)
downloadmlia-6a88ee5315b4ce5b023370c1e55e48bf9f2b6f67.tar.gz
Rename modules
- Rename module "mlia.devices" into "mlia.target" - Rename module "mlia.target.ethosu" into "mlia.target.ethos_u" - Rename module "mlia.target.cortexa" into "mlia.target.cortex_a" - Rename and update tests Change-Id: I6dca7c8646d881f739fb6b5914d1cc7e45e63dc2
Diffstat (limited to 'src/mlia/devices/tosa')
-rw-r--r--src/mlia/devices/tosa/__init__.py3
-rw-r--r--src/mlia/devices/tosa/advice_generation.py40
-rw-r--r--src/mlia/devices/tosa/advisor.py94
-rw-r--r--src/mlia/devices/tosa/config.py19
-rw-r--r--src/mlia/devices/tosa/data_analysis.py36
-rw-r--r--src/mlia/devices/tosa/data_collection.py30
-rw-r--r--src/mlia/devices/tosa/events.py24
-rw-r--r--src/mlia/devices/tosa/handlers.py36
-rw-r--r--src/mlia/devices/tosa/operators.py11
-rw-r--r--src/mlia/devices/tosa/reporters.py83
10 files changed, 0 insertions, 376 deletions
diff --git a/src/mlia/devices/tosa/__init__.py b/src/mlia/devices/tosa/__init__.py
deleted file mode 100644
index 762c831..0000000
--- a/src/mlia/devices/tosa/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""TOSA target module."""
diff --git a/src/mlia/devices/tosa/advice_generation.py b/src/mlia/devices/tosa/advice_generation.py
deleted file mode 100644
index a3d8011..0000000
--- a/src/mlia/devices/tosa/advice_generation.py
+++ /dev/null
@@ -1,40 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""TOSA advice generation."""
-from functools import singledispatchmethod
-
-from mlia.core.advice_generation import advice_category
-from mlia.core.advice_generation import FactBasedAdviceProducer
-from mlia.core.common import AdviceCategory
-from mlia.core.common import DataItem
-from mlia.devices.tosa.data_analysis import ModelIsNotTOSACompatible
-from mlia.devices.tosa.data_analysis import ModelIsTOSACompatible
-
-
-class TOSAAdviceProducer(FactBasedAdviceProducer):
- """TOSA advice producer."""
-
- @singledispatchmethod
- def produce_advice(self, _data_item: DataItem) -> None: # type: ignore
- """Produce advice."""
-
- @produce_advice.register
- @advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS)
- def handle_model_is_tosa_compatible(
- self, _data_item: ModelIsTOSACompatible
- ) -> None:
- """Advice for TOSA compatibility."""
- self.add_advice(["Model is fully TOSA compatible."])
-
- @produce_advice.register
- @advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS)
- def handle_model_is_not_tosa_compatible(
- self, _data_item: ModelIsNotTOSACompatible
- ) -> None:
- """Advice for TOSA compatibility."""
- self.add_advice(
- [
- "Some operators in the model are not TOSA compatible. "
- "Please, refer to the operators table for more information."
- ]
- )
diff --git a/src/mlia/devices/tosa/advisor.py b/src/mlia/devices/tosa/advisor.py
deleted file mode 100644
index 53dfa87..0000000
--- a/src/mlia/devices/tosa/advisor.py
+++ /dev/null
@@ -1,94 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""TOSA advisor."""
-from __future__ import annotations
-
-from pathlib import Path
-from typing import Any
-
-from mlia.core.advice_generation import AdviceCategory
-from mlia.core.advice_generation import AdviceProducer
-from mlia.core.advisor import DefaultInferenceAdvisor
-from mlia.core.advisor import InferenceAdvisor
-from mlia.core.context import Context
-from mlia.core.context import ExecutionContext
-from mlia.core.data_analysis import DataAnalyzer
-from mlia.core.data_collection import DataCollector
-from mlia.core.events import Event
-from mlia.core.typing import PathOrFileLike
-from mlia.devices.tosa.advice_generation import TOSAAdviceProducer
-from mlia.devices.tosa.config import TOSAConfiguration
-from mlia.devices.tosa.data_analysis import TOSADataAnalyzer
-from mlia.devices.tosa.data_collection import TOSAOperatorCompatibility
-from mlia.devices.tosa.events import TOSAAdvisorStartedEvent
-from mlia.devices.tosa.handlers import TOSAEventHandler
-
-
-class TOSAInferenceAdvisor(DefaultInferenceAdvisor):
- """TOSA inference advisor."""
-
- @classmethod
- def name(cls) -> str:
- """Return name of the advisor."""
- return "tosa_inference_advisor"
-
- def get_collectors(self, context: Context) -> list[DataCollector]:
- """Return list of the data collectors."""
- model = self.get_model(context)
-
- collectors: list[DataCollector] = []
-
- if AdviceCategory.OPERATORS in context.advice_category:
- collectors.append(TOSAOperatorCompatibility(model))
-
- return collectors
-
- def get_analyzers(self, context: Context) -> list[DataAnalyzer]:
- """Return list of the data analyzers."""
- return [
- TOSADataAnalyzer(),
- ]
-
- def get_producers(self, context: Context) -> list[AdviceProducer]:
- """Return list of the advice producers."""
- return [
- TOSAAdviceProducer(),
- ]
-
- def get_events(self, context: Context) -> list[Event]:
- """Return list of the startup events."""
- model = self.get_model(context)
- target_profile = self.get_target_profile(context)
-
- return [
- TOSAAdvisorStartedEvent(model, TOSAConfiguration(target_profile)),
- ]
-
-
-def configure_and_get_tosa_advisor(
- context: ExecutionContext,
- target_profile: str,
- model: str | Path,
- output: PathOrFileLike | None = None,
- **_extra_args: Any,
-) -> InferenceAdvisor:
- """Create and configure TOSA advisor."""
- if context.event_handlers is None:
- context.event_handlers = [TOSAEventHandler(output)]
-
- if context.config_parameters is None:
- context.config_parameters = _get_config_parameters(model, target_profile)
-
- return TOSAInferenceAdvisor()
-
-
-def _get_config_parameters(model: str | Path, target_profile: str) -> dict[str, Any]:
- """Get configuration parameters for the advisor."""
- advisor_parameters: dict[str, Any] = {
- "tosa_inference_advisor": {
- "model": str(model),
- "target_profile": target_profile,
- }
- }
-
- return advisor_parameters
diff --git a/src/mlia/devices/tosa/config.py b/src/mlia/devices/tosa/config.py
deleted file mode 100644
index c3879a7..0000000
--- a/src/mlia/devices/tosa/config.py
+++ /dev/null
@@ -1,19 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""TOSA target configuration."""
-from mlia.devices.config import IPConfiguration
-from mlia.utils.filesystem import get_profile
-
-
-class TOSAConfiguration(IPConfiguration): # pylint: disable=too-few-public-methods
- """TOSA configuration."""
-
- def __init__(self, target_profile: str) -> None:
- """Init configuration."""
- target_data = get_profile(target_profile)
- target = target_data["target"]
-
- if target != "tosa":
- raise Exception(f"Wrong target {target} for TOSA configuration")
-
- super().__init__(target)
diff --git a/src/mlia/devices/tosa/data_analysis.py b/src/mlia/devices/tosa/data_analysis.py
deleted file mode 100644
index 7cbd61d..0000000
--- a/src/mlia/devices/tosa/data_analysis.py
+++ /dev/null
@@ -1,36 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""TOSA data analysis module."""
-from dataclasses import dataclass
-from functools import singledispatchmethod
-
-from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo
-from mlia.core.common import DataItem
-from mlia.core.data_analysis import Fact
-from mlia.core.data_analysis import FactExtractor
-
-
-@dataclass
-class ModelIsTOSACompatible(Fact):
- """Model is completely TOSA compatible."""
-
-
-@dataclass
-class ModelIsNotTOSACompatible(Fact):
- """Model is not TOSA compatible."""
-
-
-class TOSADataAnalyzer(FactExtractor):
- """TOSA data analyzer."""
-
- @singledispatchmethod
- def analyze_data(self, data_item: DataItem) -> None: # type: ignore
- """Analyse the data."""
-
- @analyze_data.register
- def analyze_tosa_compatibility(self, data_item: TOSACompatibilityInfo) -> None:
- """Analyse TOSA compatibility information."""
- if data_item.tosa_compatible:
- self.add_fact(ModelIsTOSACompatible())
- else:
- self.add_fact(ModelIsNotTOSACompatible())
diff --git a/src/mlia/devices/tosa/data_collection.py b/src/mlia/devices/tosa/data_collection.py
deleted file mode 100644
index 105c501..0000000
--- a/src/mlia/devices/tosa/data_collection.py
+++ /dev/null
@@ -1,30 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""TOSA data collection module."""
-from pathlib import Path
-
-from mlia.backend.tosa_checker.compat import get_tosa_compatibility_info
-from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo
-from mlia.core.data_collection import ContextAwareDataCollector
-from mlia.nn.tensorflow.config import get_tflite_model
-from mlia.utils.logging import log_action
-
-
-class TOSAOperatorCompatibility(ContextAwareDataCollector):
- """Collect operator compatibility information."""
-
- def __init__(self, model: Path) -> None:
- """Init the data collector."""
- self.model = model
-
- def collect_data(self) -> TOSACompatibilityInfo:
- """Collect TOSA compatibility information."""
- tflite_model = get_tflite_model(self.model, self.context)
-
- with log_action("Checking operator compatibility ..."):
- return get_tosa_compatibility_info(tflite_model.model_path)
-
- @classmethod
- def name(cls) -> str:
- """Return name of the collector."""
- return "tosa_operator_compatibility"
diff --git a/src/mlia/devices/tosa/events.py b/src/mlia/devices/tosa/events.py
deleted file mode 100644
index ceaba57..0000000
--- a/src/mlia/devices/tosa/events.py
+++ /dev/null
@@ -1,24 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""TOSA advisor events."""
-from dataclasses import dataclass
-from pathlib import Path
-
-from mlia.core.events import Event
-from mlia.core.events import EventDispatcher
-from mlia.devices.tosa.config import TOSAConfiguration
-
-
-@dataclass
-class TOSAAdvisorStartedEvent(Event):
- """Event with TOSA advisor parameters."""
-
- model: Path
- device: TOSAConfiguration
-
-
-class TOSAAdvisorEventHandler(EventDispatcher):
- """Event handler for the TOSA inference advisor."""
-
- def on_tosa_advisor_started(self, event: TOSAAdvisorStartedEvent) -> None:
- """Handle TOSAAdvisorStartedEvent event."""
diff --git a/src/mlia/devices/tosa/handlers.py b/src/mlia/devices/tosa/handlers.py
deleted file mode 100644
index fc82657..0000000
--- a/src/mlia/devices/tosa/handlers.py
+++ /dev/null
@@ -1,36 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""TOSA Advisor event handlers."""
-# pylint: disable=R0801
-from __future__ import annotations
-
-import logging
-
-from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo
-from mlia.core.events import CollectedDataEvent
-from mlia.core.handlers import WorkflowEventsHandler
-from mlia.core.typing import PathOrFileLike
-from mlia.devices.tosa.events import TOSAAdvisorEventHandler
-from mlia.devices.tosa.events import TOSAAdvisorStartedEvent
-from mlia.devices.tosa.reporters import tosa_formatters
-
-logger = logging.getLogger(__name__)
-
-
-class TOSAEventHandler(WorkflowEventsHandler, TOSAAdvisorEventHandler):
- """Event handler for TOSA advisor."""
-
- def __init__(self, output: PathOrFileLike | None = None) -> None:
- """Init event handler."""
- super().__init__(tosa_formatters, output)
-
- def on_tosa_advisor_started(self, event: TOSAAdvisorStartedEvent) -> None:
- """Handle TOSAAdvisorStartedEvent event."""
- self.reporter.submit(event.device)
-
- def on_collected_data(self, event: CollectedDataEvent) -> None:
- """Handle CollectedDataEvent event."""
- data_item = event.data_item
-
- if isinstance(data_item, TOSACompatibilityInfo):
- self.reporter.submit(data_item.operators, delay_print=True)
diff --git a/src/mlia/devices/tosa/operators.py b/src/mlia/devices/tosa/operators.py
deleted file mode 100644
index b75ceb0..0000000
--- a/src/mlia/devices/tosa/operators.py
+++ /dev/null
@@ -1,11 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""Operators module."""
-
-
-def report() -> None:
- """Generate supported operators report."""
- raise Exception(
- "Generating a supported operators report is not "
- "currently supported with TOSA target profile."
- )
diff --git a/src/mlia/devices/tosa/reporters.py b/src/mlia/devices/tosa/reporters.py
deleted file mode 100644
index e5559ee..0000000
--- a/src/mlia/devices/tosa/reporters.py
+++ /dev/null
@@ -1,83 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""Reports module."""
-from __future__ import annotations
-
-from typing import Any
-from typing import Callable
-
-from mlia.backend.tosa_checker.compat import Operator
-from mlia.core.advice_generation import Advice
-from mlia.core.reporters import report_advice
-from mlia.core.reporting import Cell
-from mlia.core.reporting import Column
-from mlia.core.reporting import Format
-from mlia.core.reporting import NestedReport
-from mlia.core.reporting import Report
-from mlia.core.reporting import ReportItem
-from mlia.core.reporting import Table
-from mlia.devices.tosa.config import TOSAConfiguration
-from mlia.utils.console import style_improvement
-from mlia.utils.types import is_list_of
-
-
-def report_device(device: TOSAConfiguration) -> Report:
- """Generate report for the device."""
- return NestedReport(
- "Device information",
- "device",
- [
- ReportItem("Target", alias="target", value=device.target),
- ],
- )
-
-
-def report_tosa_operators(ops: list[Operator]) -> Report:
- """Generate report for the operators."""
- return Table(
- [
- Column("#", only_for=["plain_text"]),
- Column(
- "Operator location",
- alias="operator_location",
- fmt=Format(wrap_width=30),
- ),
- Column("Operator name", alias="operator_name", fmt=Format(wrap_width=20)),
- Column(
- "TOSA compatibility",
- alias="is_tosa_compatible",
- fmt=Format(wrap_width=25),
- ),
- ],
- [
- (
- index + 1,
- op.location,
- op.name,
- Cell(
- op.is_tosa_compatible,
- Format(
- style=style_improvement(op.is_tosa_compatible),
- str_fmt=lambda v: "Compatible" if v else "Not compatible",
- ),
- ),
- )
- for index, op in enumerate(ops)
- ],
- name="Operators",
- alias="operators",
- )
-
-
-def tosa_formatters(data: Any) -> Callable[[Any], Report]:
- """Find appropriate formatter for the provided data."""
- if is_list_of(data, Advice):
- return report_advice
-
- if isinstance(data, TOSAConfiguration):
- return report_device
-
- if is_list_of(data, Operator):
- return report_tosa_operators
-
- raise Exception(f"Unable to find appropriate formatter for {data}")