aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/target
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/target')
-rw-r--r--src/mlia/target/tosa/advisor.py14
-rw-r--r--src/mlia/target/tosa/events.py6
-rw-r--r--src/mlia/target/tosa/handlers.py3
-rw-r--r--src/mlia/target/tosa/metadata.py8
-rw-r--r--src/mlia/target/tosa/reporters.py93
5 files changed, 120 insertions, 4 deletions
diff --git a/src/mlia/target/tosa/advisor.py b/src/mlia/target/tosa/advisor.py
index 4851113..0da44db 100644
--- a/src/mlia/target/tosa/advisor.py
+++ b/src/mlia/target/tosa/advisor.py
@@ -16,12 +16,16 @@ from mlia.core.context import ExecutionContext
from mlia.core.data_analysis import DataAnalyzer
from mlia.core.data_collection import DataCollector
from mlia.core.events import Event
+from mlia.core.metadata import MLIAMetadata
+from mlia.core.metadata import ModelMetadata
from mlia.target.tosa.advice_generation import TOSAAdviceProducer
from mlia.target.tosa.config import TOSAConfiguration
from mlia.target.tosa.data_analysis import TOSADataAnalyzer
from mlia.target.tosa.data_collection import TOSAOperatorCompatibility
from mlia.target.tosa.events import TOSAAdvisorStartedEvent
from mlia.target.tosa.handlers import TOSAEventHandler
+from mlia.target.tosa.metadata import TOSAMetadata
+from mlia.target.tosa.reporters import MetadataDisplay
class TOSAInferenceAdvisor(DefaultInferenceAdvisor):
@@ -61,7 +65,15 @@ class TOSAInferenceAdvisor(DefaultInferenceAdvisor):
target_profile = self.get_target_profile(context)
return [
- TOSAAdvisorStartedEvent(model, TOSAConfiguration(target_profile)),
+ TOSAAdvisorStartedEvent(
+ model,
+ TOSAConfiguration(target_profile),
+ MetadataDisplay(
+ TOSAMetadata("tosa-checker"),
+ MLIAMetadata("mlia"),
+ ModelMetadata(model),
+ ),
+ )
]
diff --git a/src/mlia/target/tosa/events.py b/src/mlia/target/tosa/events.py
index 67d499d..cbfd199 100644
--- a/src/mlia/target/tosa/events.py
+++ b/src/mlia/target/tosa/events.py
@@ -1,12 +1,15 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""TOSA advisor events."""
+from __future__ import annotations
+
from dataclasses import dataclass
from pathlib import Path
from mlia.core.events import Event
from mlia.core.events import EventDispatcher
from mlia.target.tosa.config import TOSAConfiguration
+from mlia.target.tosa.reporters import MetadataDisplay
@dataclass
@@ -15,6 +18,7 @@ class TOSAAdvisorStartedEvent(Event):
model: Path
device: TOSAConfiguration
+ tosa_metadata: MetadataDisplay | None
class TOSAAdvisorEventHandler(EventDispatcher):
diff --git a/src/mlia/target/tosa/handlers.py b/src/mlia/target/tosa/handlers.py
index 1037ba1..7f80f77 100644
--- a/src/mlia/target/tosa/handlers.py
+++ b/src/mlia/target/tosa/handlers.py
@@ -27,10 +27,11 @@ class TOSAEventHandler(WorkflowEventsHandler, TOSAAdvisorEventHandler):
def on_tosa_advisor_started(self, event: TOSAAdvisorStartedEvent) -> None:
"""Handle TOSAAdvisorStartedEvent event."""
self.reporter.submit(event.device)
+ self.reporter.submit(event.tosa_metadata)
def on_collected_data(self, event: CollectedDataEvent) -> None:
"""Handle CollectedDataEvent event."""
data_item = event.data_item
if isinstance(data_item, TOSACompatibilityInfo):
- self.reporter.submit(data_item.operators, delay_print=True)
+ self.reporter.submit(data_item, delay_print=True)
diff --git a/src/mlia/target/tosa/metadata.py b/src/mlia/target/tosa/metadata.py
new file mode 100644
index 0000000..5575207
--- /dev/null
+++ b/src/mlia/target/tosa/metadata.py
@@ -0,0 +1,8 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""TOSA package metadata."""
+from mlia.core.metadata import Metadata
+
+
+class TOSAMetadata(Metadata): # pylint: disable=too-few-public-methods
+ """TOSA metadata."""
diff --git a/src/mlia/target/tosa/reporters.py b/src/mlia/target/tosa/reporters.py
index 01fbb97..283f61f 100644
--- a/src/mlia/target/tosa/reporters.py
+++ b/src/mlia/target/tosa/reporters.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Reports module."""
from __future__ import annotations
@@ -7,20 +7,41 @@ from typing import Any
from typing import Callable
from mlia.backend.tosa_checker.compat import Operator
+from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo
from mlia.core.advice_generation import Advice
+from mlia.core.metadata import MLIAMetadata
+from mlia.core.metadata import ModelMetadata
from mlia.core.reporters import report_advice
from mlia.core.reporting import Cell
from mlia.core.reporting import Column
+from mlia.core.reporting import CompoundReport
from mlia.core.reporting import Format
from mlia.core.reporting import NestedReport
from mlia.core.reporting import Report
from mlia.core.reporting import ReportItem
from mlia.core.reporting import Table
from mlia.target.tosa.config import TOSAConfiguration
+from mlia.target.tosa.metadata import TOSAMetadata
from mlia.utils.console import style_improvement
from mlia.utils.types import is_list_of
+class MetadataDisplay: # pylint: disable=too-few-public-methods
+ """TOSA metadata."""
+
+ def __init__(
+ self,
+ tosa_meta: TOSAMetadata,
+ mlia_meta: MLIAMetadata,
+ model_meta: ModelMetadata,
+ ) -> None:
+ """Init TOSAMetadata."""
+ self.tosa_version = tosa_meta.version
+ self.mlia_version = mlia_meta.version
+ self.model_check_sum = model_meta.checksum
+ self.model_name = model_meta.model_name
+
+
def report_device(device: TOSAConfiguration) -> Report:
"""Generate report for the device."""
return NestedReport(
@@ -32,6 +53,34 @@ def report_device(device: TOSAConfiguration) -> Report:
)
+def report_metadata(data: MetadataDisplay) -> Report:
+ """Generate report for the package version."""
+ return NestedReport(
+ "Metadata",
+ "metadata",
+ [
+ ReportItem(
+ "TOSA checker",
+ alias="tosa-checker",
+ nested_items=[ReportItem("version", "version", data.tosa_version)],
+ ),
+ ReportItem(
+ "MLIA",
+ "MLIA",
+ nested_items=[ReportItem("version", "version", data.mlia_version)],
+ ),
+ ReportItem(
+ "Model",
+ "Model",
+ nested_items=[
+ ReportItem("name", "name", data.model_name),
+ ReportItem("checksum", "checksum", data.model_check_sum),
+ ],
+ ),
+ ],
+ )
+
+
def report_tosa_operators(ops: list[Operator]) -> Report:
"""Generate report for the operators."""
return Table(
@@ -69,6 +118,42 @@ def report_tosa_operators(ops: list[Operator]) -> Report:
)
+def report_tosa_exception(exc: Exception | None) -> Report:
+ """Generate report for exception thrown by tosa."""
+ return NestedReport(
+ "TOSA exception",
+ "exception",
+ [
+ ReportItem("TOSA exception", alias="exception", value=repr(exc)),
+ ],
+ )
+
+
+def report_tosa_errors(err: list[str] | None) -> Report:
+ """Generate report for errors thrown by tosa."""
+ message = "".join(err) if err else None
+ return NestedReport(
+ "TOSA stderr",
+ "stderr",
+ [
+ ReportItem(
+ "TOSA stderr",
+ alias="stderr",
+ value=message,
+ ),
+ ],
+ )
+
+
+def report_tosa_compatibility(compat_info: TOSACompatibilityInfo) -> Report:
+ """Generate combined report for all compatibility info."""
+ report_ops = report_tosa_operators(compat_info.operators)
+ report_exception = report_tosa_exception(compat_info.exception)
+
+ report_errors = report_tosa_errors(compat_info.errors)
+ return CompoundReport([report_ops, report_exception, report_errors])
+
+
def tosa_formatters(data: Any) -> Callable[[Any], Report]:
"""Find appropriate formatter for the provided data."""
if is_list_of(data, Advice):
@@ -77,7 +162,13 @@ def tosa_formatters(data: Any) -> Callable[[Any], Report]:
if isinstance(data, TOSAConfiguration):
return report_device
+ if isinstance(data, MetadataDisplay):
+ return report_metadata
+
if is_list_of(data, Operator):
return report_tosa_operators
+ if isinstance(data, TOSACompatibilityInfo):
+ return report_tosa_compatibility
+
raise Exception(f"Unable to find appropriate formatter for {data}")