aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/target/tosa
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/target/tosa')
-rw-r--r--src/mlia/target/tosa/__init__.py3
-rw-r--r--src/mlia/target/tosa/advice_generation.py40
-rw-r--r--src/mlia/target/tosa/advisor.py94
-rw-r--r--src/mlia/target/tosa/config.py19
-rw-r--r--src/mlia/target/tosa/data_analysis.py36
-rw-r--r--src/mlia/target/tosa/data_collection.py30
-rw-r--r--src/mlia/target/tosa/events.py24
-rw-r--r--src/mlia/target/tosa/handlers.py36
-rw-r--r--src/mlia/target/tosa/operators.py11
-rw-r--r--src/mlia/target/tosa/reporters.py83
10 files changed, 376 insertions, 0 deletions
diff --git a/src/mlia/target/tosa/__init__.py b/src/mlia/target/tosa/__init__.py
new file mode 100644
index 0000000..762c831
--- /dev/null
+++ b/src/mlia/target/tosa/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""TOSA target module."""
diff --git a/src/mlia/target/tosa/advice_generation.py b/src/mlia/target/tosa/advice_generation.py
new file mode 100644
index 0000000..f531b84
--- /dev/null
+++ b/src/mlia/target/tosa/advice_generation.py
@@ -0,0 +1,40 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""TOSA advice generation."""
+from functools import singledispatchmethod
+
+from mlia.core.advice_generation import advice_category
+from mlia.core.advice_generation import FactBasedAdviceProducer
+from mlia.core.common import AdviceCategory
+from mlia.core.common import DataItem
+from mlia.target.tosa.data_analysis import ModelIsNotTOSACompatible
+from mlia.target.tosa.data_analysis import ModelIsTOSACompatible
+
+
+class TOSAAdviceProducer(FactBasedAdviceProducer):
+ """TOSA advice producer."""
+
+ @singledispatchmethod
+ def produce_advice(self, _data_item: DataItem) -> None: # type: ignore
+ """Produce advice."""
+
+ @produce_advice.register
+ @advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS)
+ def handle_model_is_tosa_compatible(
+ self, _data_item: ModelIsTOSACompatible
+ ) -> None:
+ """Advice for TOSA compatibility."""
+ self.add_advice(["Model is fully TOSA compatible."])
+
+ @produce_advice.register
+ @advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS)
+ def handle_model_is_not_tosa_compatible(
+ self, _data_item: ModelIsNotTOSACompatible
+ ) -> None:
+ """Advice for TOSA compatibility."""
+ self.add_advice(
+ [
+ "Some operators in the model are not TOSA compatible. "
+ "Please, refer to the operators table for more information."
+ ]
+ )
diff --git a/src/mlia/target/tosa/advisor.py b/src/mlia/target/tosa/advisor.py
new file mode 100644
index 0000000..2739dfd
--- /dev/null
+++ b/src/mlia/target/tosa/advisor.py
@@ -0,0 +1,94 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""TOSA advisor."""
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any
+
+from mlia.core.advice_generation import AdviceCategory
+from mlia.core.advice_generation import AdviceProducer
+from mlia.core.advisor import DefaultInferenceAdvisor
+from mlia.core.advisor import InferenceAdvisor
+from mlia.core.context import Context
+from mlia.core.context import ExecutionContext
+from mlia.core.data_analysis import DataAnalyzer
+from mlia.core.data_collection import DataCollector
+from mlia.core.events import Event
+from mlia.core.typing import PathOrFileLike
+from mlia.target.tosa.advice_generation import TOSAAdviceProducer
+from mlia.target.tosa.config import TOSAConfiguration
+from mlia.target.tosa.data_analysis import TOSADataAnalyzer
+from mlia.target.tosa.data_collection import TOSAOperatorCompatibility
+from mlia.target.tosa.events import TOSAAdvisorStartedEvent
+from mlia.target.tosa.handlers import TOSAEventHandler
+
+
+class TOSAInferenceAdvisor(DefaultInferenceAdvisor):
+ """TOSA inference advisor."""
+
+ @classmethod
+ def name(cls) -> str:
+ """Return name of the advisor."""
+ return "tosa_inference_advisor"
+
+ def get_collectors(self, context: Context) -> list[DataCollector]:
+ """Return list of the data collectors."""
+ model = self.get_model(context)
+
+ collectors: list[DataCollector] = []
+
+ if AdviceCategory.OPERATORS in context.advice_category:
+ collectors.append(TOSAOperatorCompatibility(model))
+
+ return collectors
+
+ def get_analyzers(self, context: Context) -> list[DataAnalyzer]:
+ """Return list of the data analyzers."""
+ return [
+ TOSADataAnalyzer(),
+ ]
+
+ def get_producers(self, context: Context) -> list[AdviceProducer]:
+ """Return list of the advice producers."""
+ return [
+ TOSAAdviceProducer(),
+ ]
+
+ def get_events(self, context: Context) -> list[Event]:
+ """Return list of the startup events."""
+ model = self.get_model(context)
+ target_profile = self.get_target_profile(context)
+
+ return [
+ TOSAAdvisorStartedEvent(model, TOSAConfiguration(target_profile)),
+ ]
+
+
+def configure_and_get_tosa_advisor(
+ context: ExecutionContext,
+ target_profile: str,
+ model: str | Path,
+ output: PathOrFileLike | None = None,
+ **_extra_args: Any,
+) -> InferenceAdvisor:
+ """Create and configure TOSA advisor."""
+ if context.event_handlers is None:
+ context.event_handlers = [TOSAEventHandler(output)]
+
+ if context.config_parameters is None:
+ context.config_parameters = _get_config_parameters(model, target_profile)
+
+ return TOSAInferenceAdvisor()
+
+
+def _get_config_parameters(model: str | Path, target_profile: str) -> dict[str, Any]:
+ """Get configuration parameters for the advisor."""
+ advisor_parameters: dict[str, Any] = {
+ "tosa_inference_advisor": {
+ "model": str(model),
+ "target_profile": target_profile,
+ }
+ }
+
+ return advisor_parameters
diff --git a/src/mlia/target/tosa/config.py b/src/mlia/target/tosa/config.py
new file mode 100644
index 0000000..22805b7
--- /dev/null
+++ b/src/mlia/target/tosa/config.py
@@ -0,0 +1,19 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""TOSA target configuration."""
+from mlia.target.config import IPConfiguration
+from mlia.utils.filesystem import get_profile
+
+
+class TOSAConfiguration(IPConfiguration): # pylint: disable=too-few-public-methods
+ """TOSA configuration."""
+
+ def __init__(self, target_profile: str) -> None:
+ """Init configuration."""
+ target_data = get_profile(target_profile)
+ target = target_data["target"]
+
+ if target != "tosa":
+ raise Exception(f"Wrong target {target} for TOSA configuration")
+
+ super().__init__(target)
diff --git a/src/mlia/target/tosa/data_analysis.py b/src/mlia/target/tosa/data_analysis.py
new file mode 100644
index 0000000..7cbd61d
--- /dev/null
+++ b/src/mlia/target/tosa/data_analysis.py
@@ -0,0 +1,36 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""TOSA data analysis module."""
+from dataclasses import dataclass
+from functools import singledispatchmethod
+
+from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo
+from mlia.core.common import DataItem
+from mlia.core.data_analysis import Fact
+from mlia.core.data_analysis import FactExtractor
+
+
+@dataclass
+class ModelIsTOSACompatible(Fact):
+ """Model is completely TOSA compatible."""
+
+
+@dataclass
+class ModelIsNotTOSACompatible(Fact):
+ """Model is not TOSA compatible."""
+
+
+class TOSADataAnalyzer(FactExtractor):
+ """TOSA data analyzer."""
+
+ @singledispatchmethod
+ def analyze_data(self, data_item: DataItem) -> None: # type: ignore
+ """Analyse the data."""
+
+ @analyze_data.register
+ def analyze_tosa_compatibility(self, data_item: TOSACompatibilityInfo) -> None:
+ """Analyse TOSA compatibility information."""
+ if data_item.tosa_compatible:
+ self.add_fact(ModelIsTOSACompatible())
+ else:
+ self.add_fact(ModelIsNotTOSACompatible())
diff --git a/src/mlia/target/tosa/data_collection.py b/src/mlia/target/tosa/data_collection.py
new file mode 100644
index 0000000..105c501
--- /dev/null
+++ b/src/mlia/target/tosa/data_collection.py
@@ -0,0 +1,30 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""TOSA data collection module."""
+from pathlib import Path
+
+from mlia.backend.tosa_checker.compat import get_tosa_compatibility_info
+from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo
+from mlia.core.data_collection import ContextAwareDataCollector
+from mlia.nn.tensorflow.config import get_tflite_model
+from mlia.utils.logging import log_action
+
+
+class TOSAOperatorCompatibility(ContextAwareDataCollector):
+ """Collect operator compatibility information."""
+
+ def __init__(self, model: Path) -> None:
+ """Init the data collector."""
+ self.model = model
+
+ def collect_data(self) -> TOSACompatibilityInfo:
+ """Collect TOSA compatibility information."""
+ tflite_model = get_tflite_model(self.model, self.context)
+
+ with log_action("Checking operator compatibility ..."):
+ return get_tosa_compatibility_info(tflite_model.model_path)
+
+ @classmethod
+ def name(cls) -> str:
+ """Return name of the collector."""
+ return "tosa_operator_compatibility"
diff --git a/src/mlia/target/tosa/events.py b/src/mlia/target/tosa/events.py
new file mode 100644
index 0000000..67d499d
--- /dev/null
+++ b/src/mlia/target/tosa/events.py
@@ -0,0 +1,24 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""TOSA advisor events."""
+from dataclasses import dataclass
+from pathlib import Path
+
+from mlia.core.events import Event
+from mlia.core.events import EventDispatcher
+from mlia.target.tosa.config import TOSAConfiguration
+
+
+@dataclass
+class TOSAAdvisorStartedEvent(Event):
+ """Event with TOSA advisor parameters."""
+
+ model: Path
+ device: TOSAConfiguration
+
+
+class TOSAAdvisorEventHandler(EventDispatcher):
+ """Event handler for the TOSA inference advisor."""
+
+ def on_tosa_advisor_started(self, event: TOSAAdvisorStartedEvent) -> None:
+ """Handle TOSAAdvisorStartedEvent event."""
diff --git a/src/mlia/target/tosa/handlers.py b/src/mlia/target/tosa/handlers.py
new file mode 100644
index 0000000..863558c
--- /dev/null
+++ b/src/mlia/target/tosa/handlers.py
@@ -0,0 +1,36 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""TOSA Advisor event handlers."""
+# pylint: disable=R0801
+from __future__ import annotations
+
+import logging
+
+from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo
+from mlia.core.events import CollectedDataEvent
+from mlia.core.handlers import WorkflowEventsHandler
+from mlia.core.typing import PathOrFileLike
+from mlia.target.tosa.events import TOSAAdvisorEventHandler
+from mlia.target.tosa.events import TOSAAdvisorStartedEvent
+from mlia.target.tosa.reporters import tosa_formatters
+
+logger = logging.getLogger(__name__)
+
+
+class TOSAEventHandler(WorkflowEventsHandler, TOSAAdvisorEventHandler):
+ """Event handler for TOSA advisor."""
+
+ def __init__(self, output: PathOrFileLike | None = None) -> None:
+ """Init event handler."""
+ super().__init__(tosa_formatters, output)
+
+ def on_tosa_advisor_started(self, event: TOSAAdvisorStartedEvent) -> None:
+ """Handle TOSAAdvisorStartedEvent event."""
+ self.reporter.submit(event.device)
+
+ def on_collected_data(self, event: CollectedDataEvent) -> None:
+ """Handle CollectedDataEvent event."""
+ data_item = event.data_item
+
+ if isinstance(data_item, TOSACompatibilityInfo):
+ self.reporter.submit(data_item.operators, delay_print=True)
diff --git a/src/mlia/target/tosa/operators.py b/src/mlia/target/tosa/operators.py
new file mode 100644
index 0000000..b75ceb0
--- /dev/null
+++ b/src/mlia/target/tosa/operators.py
@@ -0,0 +1,11 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Operators module."""
+
+
+def report() -> None:
+ """Generate supported operators report."""
+ raise Exception(
+ "Generating a supported operators report is not "
+ "currently supported with TOSA target profile."
+ )
diff --git a/src/mlia/target/tosa/reporters.py b/src/mlia/target/tosa/reporters.py
new file mode 100644
index 0000000..01fbb97
--- /dev/null
+++ b/src/mlia/target/tosa/reporters.py
@@ -0,0 +1,83 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Reports module."""
+from __future__ import annotations
+
+from typing import Any
+from typing import Callable
+
+from mlia.backend.tosa_checker.compat import Operator
+from mlia.core.advice_generation import Advice
+from mlia.core.reporters import report_advice
+from mlia.core.reporting import Cell
+from mlia.core.reporting import Column
+from mlia.core.reporting import Format
+from mlia.core.reporting import NestedReport
+from mlia.core.reporting import Report
+from mlia.core.reporting import ReportItem
+from mlia.core.reporting import Table
+from mlia.target.tosa.config import TOSAConfiguration
+from mlia.utils.console import style_improvement
+from mlia.utils.types import is_list_of
+
+
+def report_device(device: TOSAConfiguration) -> Report:
+ """Generate report for the device."""
+ return NestedReport(
+ "Device information",
+ "device",
+ [
+ ReportItem("Target", alias="target", value=device.target),
+ ],
+ )
+
+
+def report_tosa_operators(ops: list[Operator]) -> Report:
+ """Generate report for the operators."""
+ return Table(
+ [
+ Column("#", only_for=["plain_text"]),
+ Column(
+ "Operator location",
+ alias="operator_location",
+ fmt=Format(wrap_width=30),
+ ),
+ Column("Operator name", alias="operator_name", fmt=Format(wrap_width=20)),
+ Column(
+ "TOSA compatibility",
+ alias="is_tosa_compatible",
+ fmt=Format(wrap_width=25),
+ ),
+ ],
+ [
+ (
+ index + 1,
+ op.location,
+ op.name,
+ Cell(
+ op.is_tosa_compatible,
+ Format(
+ style=style_improvement(op.is_tosa_compatible),
+ str_fmt=lambda v: "Compatible" if v else "Not compatible",
+ ),
+ ),
+ )
+ for index, op in enumerate(ops)
+ ],
+ name="Operators",
+ alias="operators",
+ )
+
+
+def tosa_formatters(data: Any) -> Callable[[Any], Report]:
+ """Find appropriate formatter for the provided data."""
+ if is_list_of(data, Advice):
+ return report_advice
+
+ if isinstance(data, TOSAConfiguration):
+ return report_device
+
+ if is_list_of(data, Operator):
+ return report_tosa_operators
+
+ raise Exception(f"Unable to find appropriate formatter for {data}")