aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/devices
diff options
context:
space:
mode:
authorDmitrii Agibov <dmitrii.agibov@arm.com>2022-07-21 14:06:50 +0100
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2022-08-19 10:23:23 +0100
commit664d8c55609253e68d153a91514c8fefa00557b1 (patch)
tree4b2a0ecaf30e9151d6b971a24fa6c6104884896f /src/mlia/devices
parenta8ee1aee3e674c78a77801d1bf2256881ab6b4b9 (diff)
downloadmlia-664d8c55609253e68d153a91514c8fefa00557b1.tar.gz
MLIA-549 Integrate TOSA checker into MLIA
- Add new module for TOSA - Add advisor workflow components - Use TOSA checker for getting operators compatibility information Change-Id: I769e5e2a84e15779658f0895b4a347384def63bf
Diffstat (limited to 'src/mlia/devices')
-rw-r--r--src/mlia/devices/tosa/__init__.py3
-rw-r--r--src/mlia/devices/tosa/advice_generation.py40
-rw-r--r--src/mlia/devices/tosa/advisor.py98
-rw-r--r--src/mlia/devices/tosa/config.py19
-rw-r--r--src/mlia/devices/tosa/data_analysis.py36
-rw-r--r--src/mlia/devices/tosa/data_collection.py35
-rw-r--r--src/mlia/devices/tosa/events.py24
-rw-r--r--src/mlia/devices/tosa/handlers.py35
-rw-r--r--src/mlia/devices/tosa/operators.py70
-rw-r--r--src/mlia/devices/tosa/reporters.py94
10 files changed, 454 insertions, 0 deletions
diff --git a/src/mlia/devices/tosa/__init__.py b/src/mlia/devices/tosa/__init__.py
new file mode 100644
index 0000000..762c831
--- /dev/null
+++ b/src/mlia/devices/tosa/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""TOSA target module."""
diff --git a/src/mlia/devices/tosa/advice_generation.py b/src/mlia/devices/tosa/advice_generation.py
new file mode 100644
index 0000000..7adfcb9
--- /dev/null
+++ b/src/mlia/devices/tosa/advice_generation.py
@@ -0,0 +1,40 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""TOSA advice generation."""
+from functools import singledispatchmethod
+
+from mlia.core.advice_generation import advice_category
+from mlia.core.advice_generation import FactBasedAdviceProducer
+from mlia.core.common import AdviceCategory
+from mlia.core.common import DataItem
+from mlia.devices.tosa.data_analysis import ModelIsNotTOSACompatible
+from mlia.devices.tosa.data_analysis import ModelIsTOSACompatible
+
+
+class TOSAAdviceProducer(FactBasedAdviceProducer):
+ """TOSA advice producer."""
+
+ @singledispatchmethod
+ def produce_advice(self, _data_item: DataItem) -> None:
+ """Produce advice."""
+
+ @produce_advice.register
+ @advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS)
+ def handle_model_is_tosa_compatible(
+ self, _data_item: ModelIsTOSACompatible
+ ) -> None:
+ """Advice for TOSA compatibility."""
+ self.add_advice(["Model is fully TOSA compatible."])
+
+ @produce_advice.register
+ @advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS)
+ def handle_model_is_not_tosa_compatible(
+ self, _data_item: ModelIsNotTOSACompatible
+ ) -> None:
+ """Advice for TOSA compatibility."""
+ self.add_advice(
+ [
+ "Some operators in the model are not TOSA compatible. "
+ "Please, refer to the operators table for more information."
+ ]
+ )
diff --git a/src/mlia/devices/tosa/advisor.py b/src/mlia/devices/tosa/advisor.py
new file mode 100644
index 0000000..6a32b94
--- /dev/null
+++ b/src/mlia/devices/tosa/advisor.py
@@ -0,0 +1,98 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""TOSA advisor."""
+from pathlib import Path
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Union
+
+from mlia.core._typing import PathOrFileLike
+from mlia.core.advice_generation import AdviceCategory
+from mlia.core.advice_generation import AdviceProducer
+from mlia.core.advisor import DefaultInferenceAdvisor
+from mlia.core.advisor import InferenceAdvisor
+from mlia.core.context import Context
+from mlia.core.context import ExecutionContext
+from mlia.core.data_analysis import DataAnalyzer
+from mlia.core.data_collection import DataCollector
+from mlia.core.events import Event
+from mlia.devices.tosa.advice_generation import TOSAAdviceProducer
+from mlia.devices.tosa.config import TOSAConfiguration
+from mlia.devices.tosa.data_analysis import TOSADataAnalyzer
+from mlia.devices.tosa.data_collection import TOSAOperatorCompatibility
+from mlia.devices.tosa.events import TOSAAdvisorStartedEvent
+from mlia.devices.tosa.handlers import TOSAEventHandler
+
+
+class TOSAInferenceAdvisor(DefaultInferenceAdvisor):
+ """TOSA inference advisor."""
+
+ @classmethod
+ def name(cls) -> str:
+ """Return name of the advisor."""
+ return "tosa_inference_advisor"
+
+ def get_collectors(self, context: Context) -> List[DataCollector]:
+ """Return list of the data collectors."""
+ model = self.get_model(context)
+
+ collectors: List[DataCollector] = []
+
+ if AdviceCategory.OPERATORS in context.advice_category:
+ collectors.append(TOSAOperatorCompatibility(model))
+
+ return collectors
+
+ def get_analyzers(self, context: Context) -> List[DataAnalyzer]:
+ """Return list of the data analyzers."""
+ return [
+ TOSADataAnalyzer(),
+ ]
+
+ def get_producers(self, context: Context) -> List[AdviceProducer]:
+ """Return list of the advice producers."""
+ return [
+ TOSAAdviceProducer(),
+ ]
+
+ def get_events(self, context: Context) -> List[Event]:
+ """Return list of the startup events."""
+ model = self.get_model(context)
+ target_profile = self.get_target_profile(context)
+
+ return [
+ TOSAAdvisorStartedEvent(model, TOSAConfiguration(target_profile)),
+ ]
+
+
+def configure_and_get_tosa_advisor(
+ context: ExecutionContext,
+ target_profile: str,
+ model: Union[Path, str],
+ output: Optional[PathOrFileLike] = None,
+ **_extra_args: Any
+) -> InferenceAdvisor:
+ """Create and configure TOSA advisor."""
+ if context.event_handlers is None:
+ context.event_handlers = [TOSAEventHandler(output)]
+
+ if context.config_parameters is None:
+ context.config_parameters = _get_config_parameters(model, target_profile)
+
+ return TOSAInferenceAdvisor()
+
+
+def _get_config_parameters(
+ model: Union[Path, str], target_profile: str
+) -> Dict[str, Any]:
+ """Get configuration parameters for the advisor."""
+ advisor_parameters: Dict[str, Any] = {
+ "tosa_inference_advisor": {
+ "model": str(model),
+ "target_profile": target_profile,
+ }
+ }
+
+ return advisor_parameters
diff --git a/src/mlia/devices/tosa/config.py b/src/mlia/devices/tosa/config.py
new file mode 100644
index 0000000..c3879a7
--- /dev/null
+++ b/src/mlia/devices/tosa/config.py
@@ -0,0 +1,19 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""TOSA target configuration."""
+from mlia.devices.config import IPConfiguration
+from mlia.utils.filesystem import get_profile
+
+
+class TOSAConfiguration(IPConfiguration): # pylint: disable=too-few-public-methods
+ """TOSA configuration."""
+
+ def __init__(self, target_profile: str) -> None:
+ """Init configuration."""
+ target_data = get_profile(target_profile)
+ target = target_data["target"]
+
+ if target != "tosa":
+ raise Exception(f"Wrong target {target} for TOSA configuration")
+
+ super().__init__(target)
diff --git a/src/mlia/devices/tosa/data_analysis.py b/src/mlia/devices/tosa/data_analysis.py
new file mode 100644
index 0000000..aa696a5
--- /dev/null
+++ b/src/mlia/devices/tosa/data_analysis.py
@@ -0,0 +1,36 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""TOSA data analysis module."""
+from dataclasses import dataclass
+from functools import singledispatchmethod
+
+from mlia.core.common import DataItem
+from mlia.core.data_analysis import Fact
+from mlia.core.data_analysis import FactExtractor
+from mlia.devices.tosa.operators import TOSACompatibilityInfo
+
+
+@dataclass
+class ModelIsTOSACompatible(Fact):
+ """Model is completely TOSA compatible."""
+
+
+@dataclass
+class ModelIsNotTOSACompatible(Fact):
+ """Model is not TOSA compatible."""
+
+
+class TOSADataAnalyzer(FactExtractor):
+ """TOSA data analyzer."""
+
+ @singledispatchmethod
+ def analyze_data(self, data_item: DataItem) -> None:
+ """Analyse the data."""
+
+ @analyze_data.register
+ def analyze_tosa_compatibility(self, data_item: TOSACompatibilityInfo) -> None:
+ """Analyse TOSA compatibility information."""
+ if data_item.tosa_compatible:
+ self.add_fact(ModelIsTOSACompatible())
+ else:
+ self.add_fact(ModelIsNotTOSACompatible())
diff --git a/src/mlia/devices/tosa/data_collection.py b/src/mlia/devices/tosa/data_collection.py
new file mode 100644
index 0000000..843d5ab
--- /dev/null
+++ b/src/mlia/devices/tosa/data_collection.py
@@ -0,0 +1,35 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""TOSA data collection module."""
+import logging
+from pathlib import Path
+
+from mlia.core.data_collection import ContextAwareDataCollector
+from mlia.devices.tosa.operators import get_tosa_compatibility_info
+from mlia.devices.tosa.operators import TOSACompatibilityInfo
+from mlia.nn.tensorflow.config import get_tflite_model
+
+logger = logging.getLogger(__name__)
+
+
+class TOSAOperatorCompatibility(ContextAwareDataCollector):
+ """Collect operator compatibility information."""
+
+ def __init__(self, model: Path) -> None:
+ """Init the data collector."""
+ self.model = model
+
+ def collect_data(self) -> TOSACompatibilityInfo:
+ """Collect TOSA compatibility information."""
+ tflite_model = get_tflite_model(self.model, self.context)
+
+ logger.info("Checking operator compatibility ...")
+ tosa_info = get_tosa_compatibility_info(tflite_model.model_path)
+ logger.info("Done\n")
+
+ return tosa_info
+
+ @classmethod
+ def name(cls) -> str:
+ """Return name of the collector."""
+ return "tosa_operator_compatibility"
diff --git a/src/mlia/devices/tosa/events.py b/src/mlia/devices/tosa/events.py
new file mode 100644
index 0000000..ceaba57
--- /dev/null
+++ b/src/mlia/devices/tosa/events.py
@@ -0,0 +1,24 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""TOSA advisor events."""
+from dataclasses import dataclass
+from pathlib import Path
+
+from mlia.core.events import Event
+from mlia.core.events import EventDispatcher
+from mlia.devices.tosa.config import TOSAConfiguration
+
+
+@dataclass
+class TOSAAdvisorStartedEvent(Event):
+ """Event with TOSA advisor parameters."""
+
+ model: Path
+ device: TOSAConfiguration
+
+
+class TOSAAdvisorEventHandler(EventDispatcher):
+ """Event handler for the TOSA inference advisor."""
+
+ def on_tosa_advisor_started(self, event: TOSAAdvisorStartedEvent) -> None:
+ """Handle TOSAAdvisorStartedEvent event."""
diff --git a/src/mlia/devices/tosa/handlers.py b/src/mlia/devices/tosa/handlers.py
new file mode 100644
index 0000000..00c18c5
--- /dev/null
+++ b/src/mlia/devices/tosa/handlers.py
@@ -0,0 +1,35 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""TOSA Advisor event handlers."""
+# pylint: disable=R0801
+import logging
+from typing import Optional
+
+from mlia.core._typing import PathOrFileLike
+from mlia.core.events import CollectedDataEvent
+from mlia.core.handlers import WorkflowEventsHandler
+from mlia.devices.tosa.events import TOSAAdvisorEventHandler
+from mlia.devices.tosa.events import TOSAAdvisorStartedEvent
+from mlia.devices.tosa.operators import TOSACompatibilityInfo
+from mlia.devices.tosa.reporters import tosa_formatters
+
+logger = logging.getLogger(__name__)
+
+
+class TOSAEventHandler(WorkflowEventsHandler, TOSAAdvisorEventHandler):
+ """Event handler for TOSA advisor."""
+
+ def __init__(self, output: Optional[PathOrFileLike] = None) -> None:
+ """Init event handler."""
+ super().__init__(tosa_formatters, output)
+
+ def on_tosa_advisor_started(self, event: TOSAAdvisorStartedEvent) -> None:
+ """Handle TOSAAdvisorStartedEvent event."""
+ self.reporter.submit(event.device)
+
+ def on_collected_data(self, event: CollectedDataEvent) -> None:
+ """Handle CollectedDataEvent event."""
+ data_item = event.data_item
+
+ if isinstance(data_item, TOSACompatibilityInfo):
+ self.reporter.submit(data_item.operators, delay_print=True)
diff --git a/src/mlia/devices/tosa/operators.py b/src/mlia/devices/tosa/operators.py
new file mode 100644
index 0000000..4f3df10
--- /dev/null
+++ b/src/mlia/devices/tosa/operators.py
@@ -0,0 +1,70 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Operators module."""
+from dataclasses import dataclass
+from typing import Any
+from typing import cast
+from typing import List
+from typing import Optional
+from typing import Protocol
+
+from mlia.core._typing import PathOrFileLike
+
+
+class TOSAChecker(Protocol):
+ """TOSA checker protocol."""
+
+ def is_tosa_compatible(self) -> bool:
+ """Return true if model is TOSA compatible."""
+
+ def _get_tosa_compatibility_for_ops(self) -> List[Any]:
+ """Return list of operators."""
+
+
+@dataclass
+class Operator:
+ """Operator's TOSA compatibility info."""
+
+ location: str
+ name: str
+ is_tosa_compatible: bool
+
+
+@dataclass
+class TOSACompatibilityInfo:
+ """Models' TOSA compatibility information."""
+
+ tosa_compatible: bool
+ operators: List[Operator]
+
+
+def get_tosa_compatibility_info(
+ tflite_model_path: PathOrFileLike,
+) -> TOSACompatibilityInfo:
+ """Return list of the operators."""
+ checker = get_tosa_checker(tflite_model_path)
+
+ if checker is None:
+ raise Exception(
+ "TOSA checker is not available. "
+ "Please make sure that 'tosa_checker' package is installed: "
+ "pip install mlia[tosa]"
+ )
+
+ ops = [
+ Operator(item.location, item.name, item.is_tosa_compatible)
+ for item in checker._get_tosa_compatibility_for_ops() # pylint: disable=protected-access
+ ]
+
+ return TOSACompatibilityInfo(checker.is_tosa_compatible(), ops)
+
+
+def get_tosa_checker(tflite_model_path: PathOrFileLike) -> Optional[TOSAChecker]:
+ """Return instance of the TOSA checker."""
+ try:
+ import tosa_checker as tc # pylint: disable=import-outside-toplevel
+ except ImportError:
+ return None
+
+ checker = tc.TOSAChecker(str(tflite_model_path))
+ return cast(TOSAChecker, checker)
diff --git a/src/mlia/devices/tosa/reporters.py b/src/mlia/devices/tosa/reporters.py
new file mode 100644
index 0000000..8fba95c
--- /dev/null
+++ b/src/mlia/devices/tosa/reporters.py
@@ -0,0 +1,94 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Reports module."""
+from typing import Any
+from typing import Callable
+from typing import List
+
+from mlia.core.advice_generation import Advice
+from mlia.core.reporting import Cell
+from mlia.core.reporting import Column
+from mlia.core.reporting import Format
+from mlia.core.reporting import NestedReport
+from mlia.core.reporting import Report
+from mlia.core.reporting import ReportItem
+from mlia.core.reporting import Table
+from mlia.devices.tosa.config import TOSAConfiguration
+from mlia.devices.tosa.operators import Operator
+from mlia.utils.console import style_improvement
+from mlia.utils.types import is_list_of
+
+
+def report_device(device: TOSAConfiguration) -> Report:
+ """Generate report for the device."""
+ return NestedReport(
+ "Device information",
+ "device",
+ [
+ ReportItem("Target", alias="target", value=device.target),
+ ],
+ )
+
+
+def report_advice(advice: List[Advice]) -> Report:
+ """Generate report for the advice."""
+ return Table(
+ columns=[
+ Column("#", only_for=["plain_text"]),
+ Column("Advice", alias="advice_message"),
+ ],
+ rows=[(i + 1, a.messages) for i, a in enumerate(advice)],
+ name="Advice",
+ alias="advice",
+ )
+
+
+def report_tosa_operators(ops: List[Operator]) -> Report:
+ """Generate report for the operators."""
+ return Table(
+ [
+ Column("#", only_for=["plain_text"]),
+ Column(
+ "Operator location",
+ alias="operator_location",
+ fmt=Format(wrap_width=30),
+ ),
+ Column("Operator name", alias="operator_name", fmt=Format(wrap_width=20)),
+ Column(
+ "TOSA compatibility",
+ alias="is_tosa_compatible",
+ fmt=Format(wrap_width=25),
+ ),
+ ],
+ [
+ (
+ index + 1,
+ op.location,
+ op.name,
+ Cell(
+ op.is_tosa_compatible,
+ Format(
+ style=style_improvement(op.is_tosa_compatible),
+ str_fmt=lambda v: "Compatible" if v else "Not compatible",
+ ),
+ ),
+ )
+ for index, op in enumerate(ops)
+ ],
+ name="Operators",
+ alias="operators",
+ )
+
+
+def tosa_formatters(data: Any) -> Callable[[Any], Report]:
+ """Find appropriate formatter for the provided data."""
+ if is_list_of(data, Advice):
+ return report_advice
+
+ if isinstance(data, TOSAConfiguration):
+ return report_device
+
+ if is_list_of(data, Operator):
+ return report_tosa_operators
+
+ raise Exception(f"Unable to find appropriate formatter for {data}")