aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDmitrii Agibov <dmitrii.agibov@arm.com>2022-07-21 14:06:50 +0100
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2022-08-19 10:23:23 +0100
commit664d8c55609253e68d153a91514c8fefa00557b1 (patch)
tree4b2a0ecaf30e9151d6b971a24fa6c6104884896f
parenta8ee1aee3e674c78a77801d1bf2256881ab6b4b9 (diff)
downloadmlia-664d8c55609253e68d153a91514c8fefa00557b1.tar.gz
MLIA-549 Integrate TOSA checker into MLIA
- Add new module for TOSA - Add advisor workflow components - Use TOSA checker for getting operators compatibility information Change-Id: I769e5e2a84e15779658f0895b4a347384def63bf
-rw-r--r--src/mlia/api.py2
-rw-r--r--src/mlia/devices/tosa/__init__.py3
-rw-r--r--src/mlia/devices/tosa/advice_generation.py40
-rw-r--r--src/mlia/devices/tosa/advisor.py98
-rw-r--r--src/mlia/devices/tosa/config.py19
-rw-r--r--src/mlia/devices/tosa/data_analysis.py36
-rw-r--r--src/mlia/devices/tosa/data_collection.py35
-rw-r--r--src/mlia/devices/tosa/events.py24
-rw-r--r--src/mlia/devices/tosa/handlers.py35
-rw-r--r--src/mlia/devices/tosa/operators.py70
-rw-r--r--src/mlia/devices/tosa/reporters.py94
-rw-r--r--src/mlia/resources/profiles.json3
-rw-r--r--tests/test_api.py4
-rw-r--r--tests/test_devices_tosa_advice_generation.py56
-rw-r--r--tests/test_devices_tosa_advisor.py29
-rw-r--r--tests/test_devices_tosa_data_analysis.py33
-rw-r--r--tests/test_devices_tosa_data_collection.py28
-rw-r--r--tests/test_devices_tosa_operators.py84
-rw-r--r--tests/test_utils_filesystem.py2
19 files changed, 695 insertions, 0 deletions
diff --git a/src/mlia/api.py b/src/mlia/api.py
index 024bc98..c720b8d 100644
--- a/src/mlia/api.py
+++ b/src/mlia/api.py
@@ -15,6 +15,7 @@ from mlia.core.advisor import InferenceAdvisor
from mlia.core.common import AdviceCategory
from mlia.core.context import ExecutionContext
from mlia.devices.ethosu.advisor import configure_and_get_ethosu_advisor
+from mlia.devices.tosa.advisor import configure_and_get_tosa_advisor
from mlia.utils.filesystem import get_target
@@ -104,6 +105,7 @@ def get_advisor(
target_factories = {
"ethos-u55": configure_and_get_ethosu_advisor,
"ethos-u65": configure_and_get_ethosu_advisor,
+ "tosa": configure_and_get_tosa_advisor,
}
try:
diff --git a/src/mlia/devices/tosa/__init__.py b/src/mlia/devices/tosa/__init__.py
new file mode 100644
index 0000000..762c831
--- /dev/null
+++ b/src/mlia/devices/tosa/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""TOSA target module."""
diff --git a/src/mlia/devices/tosa/advice_generation.py b/src/mlia/devices/tosa/advice_generation.py
new file mode 100644
index 0000000..7adfcb9
--- /dev/null
+++ b/src/mlia/devices/tosa/advice_generation.py
@@ -0,0 +1,40 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""TOSA advice generation."""
+from functools import singledispatchmethod
+
+from mlia.core.advice_generation import advice_category
+from mlia.core.advice_generation import FactBasedAdviceProducer
+from mlia.core.common import AdviceCategory
+from mlia.core.common import DataItem
+from mlia.devices.tosa.data_analysis import ModelIsNotTOSACompatible
+from mlia.devices.tosa.data_analysis import ModelIsTOSACompatible
+
+
+class TOSAAdviceProducer(FactBasedAdviceProducer):
+ """TOSA advice producer."""
+
+ @singledispatchmethod
+ def produce_advice(self, _data_item: DataItem) -> None:
+ """Produce advice."""
+
+ @produce_advice.register
+ @advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS)
+ def handle_model_is_tosa_compatible(
+ self, _data_item: ModelIsTOSACompatible
+ ) -> None:
+ """Advice for TOSA compatibility."""
+ self.add_advice(["Model is fully TOSA compatible."])
+
+ @produce_advice.register
+ @advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS)
+ def handle_model_is_not_tosa_compatible(
+ self, _data_item: ModelIsNotTOSACompatible
+ ) -> None:
+ """Advice for TOSA compatibility."""
+ self.add_advice(
+ [
+ "Some operators in the model are not TOSA compatible. "
+ "Please, refer to the operators table for more information."
+ ]
+ )
diff --git a/src/mlia/devices/tosa/advisor.py b/src/mlia/devices/tosa/advisor.py
new file mode 100644
index 0000000..6a32b94
--- /dev/null
+++ b/src/mlia/devices/tosa/advisor.py
@@ -0,0 +1,98 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""TOSA advisor."""
+from pathlib import Path
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Union
+
+from mlia.core._typing import PathOrFileLike
+from mlia.core.advice_generation import AdviceCategory
+from mlia.core.advice_generation import AdviceProducer
+from mlia.core.advisor import DefaultInferenceAdvisor
+from mlia.core.advisor import InferenceAdvisor
+from mlia.core.context import Context
+from mlia.core.context import ExecutionContext
+from mlia.core.data_analysis import DataAnalyzer
+from mlia.core.data_collection import DataCollector
+from mlia.core.events import Event
+from mlia.devices.tosa.advice_generation import TOSAAdviceProducer
+from mlia.devices.tosa.config import TOSAConfiguration
+from mlia.devices.tosa.data_analysis import TOSADataAnalyzer
+from mlia.devices.tosa.data_collection import TOSAOperatorCompatibility
+from mlia.devices.tosa.events import TOSAAdvisorStartedEvent
+from mlia.devices.tosa.handlers import TOSAEventHandler
+
+
+class TOSAInferenceAdvisor(DefaultInferenceAdvisor):
+ """TOSA inference advisor."""
+
+ @classmethod
+ def name(cls) -> str:
+ """Return name of the advisor."""
+ return "tosa_inference_advisor"
+
+ def get_collectors(self, context: Context) -> List[DataCollector]:
+ """Return list of the data collectors."""
+ model = self.get_model(context)
+
+ collectors: List[DataCollector] = []
+
+ if AdviceCategory.OPERATORS in context.advice_category:
+ collectors.append(TOSAOperatorCompatibility(model))
+
+ return collectors
+
+ def get_analyzers(self, context: Context) -> List[DataAnalyzer]:
+ """Return list of the data analyzers."""
+ return [
+ TOSADataAnalyzer(),
+ ]
+
+ def get_producers(self, context: Context) -> List[AdviceProducer]:
+ """Return list of the advice producers."""
+ return [
+ TOSAAdviceProducer(),
+ ]
+
+ def get_events(self, context: Context) -> List[Event]:
+ """Return list of the startup events."""
+ model = self.get_model(context)
+ target_profile = self.get_target_profile(context)
+
+ return [
+ TOSAAdvisorStartedEvent(model, TOSAConfiguration(target_profile)),
+ ]
+
+
+def configure_and_get_tosa_advisor(
+ context: ExecutionContext,
+ target_profile: str,
+ model: Union[Path, str],
+ output: Optional[PathOrFileLike] = None,
+ **_extra_args: Any
+) -> InferenceAdvisor:
+ """Create and configure TOSA advisor."""
+ if context.event_handlers is None:
+ context.event_handlers = [TOSAEventHandler(output)]
+
+ if context.config_parameters is None:
+ context.config_parameters = _get_config_parameters(model, target_profile)
+
+ return TOSAInferenceAdvisor()
+
+
+def _get_config_parameters(
+ model: Union[Path, str], target_profile: str
+) -> Dict[str, Any]:
+ """Get configuration parameters for the advisor."""
+ advisor_parameters: Dict[str, Any] = {
+ "tosa_inference_advisor": {
+ "model": str(model),
+ "target_profile": target_profile,
+ }
+ }
+
+ return advisor_parameters
diff --git a/src/mlia/devices/tosa/config.py b/src/mlia/devices/tosa/config.py
new file mode 100644
index 0000000..c3879a7
--- /dev/null
+++ b/src/mlia/devices/tosa/config.py
@@ -0,0 +1,19 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""TOSA target configuration."""
+from mlia.devices.config import IPConfiguration
+from mlia.utils.filesystem import get_profile
+
+
+class TOSAConfiguration(IPConfiguration): # pylint: disable=too-few-public-methods
+ """TOSA configuration."""
+
+ def __init__(self, target_profile: str) -> None:
+ """Init configuration."""
+ target_data = get_profile(target_profile)
+ target = target_data["target"]
+
+ if target != "tosa":
+ raise Exception(f"Wrong target {target} for TOSA configuration")
+
+ super().__init__(target)
diff --git a/src/mlia/devices/tosa/data_analysis.py b/src/mlia/devices/tosa/data_analysis.py
new file mode 100644
index 0000000..aa696a5
--- /dev/null
+++ b/src/mlia/devices/tosa/data_analysis.py
@@ -0,0 +1,36 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""TOSA data analysis module."""
+from dataclasses import dataclass
+from functools import singledispatchmethod
+
+from mlia.core.common import DataItem
+from mlia.core.data_analysis import Fact
+from mlia.core.data_analysis import FactExtractor
+from mlia.devices.tosa.operators import TOSACompatibilityInfo
+
+
+@dataclass
+class ModelIsTOSACompatible(Fact):
+ """Model is completely TOSA compatible."""
+
+
+@dataclass
+class ModelIsNotTOSACompatible(Fact):
+ """Model is not TOSA compatible."""
+
+
+class TOSADataAnalyzer(FactExtractor):
+ """TOSA data analyzer."""
+
+ @singledispatchmethod
+ def analyze_data(self, data_item: DataItem) -> None:
+ """Analyse the data."""
+
+ @analyze_data.register
+ def analyze_tosa_compatibility(self, data_item: TOSACompatibilityInfo) -> None:
+ """Analyse TOSA compatibility information."""
+ if data_item.tosa_compatible:
+ self.add_fact(ModelIsTOSACompatible())
+ else:
+ self.add_fact(ModelIsNotTOSACompatible())
diff --git a/src/mlia/devices/tosa/data_collection.py b/src/mlia/devices/tosa/data_collection.py
new file mode 100644
index 0000000..843d5ab
--- /dev/null
+++ b/src/mlia/devices/tosa/data_collection.py
@@ -0,0 +1,35 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""TOSA data collection module."""
+import logging
+from pathlib import Path
+
+from mlia.core.data_collection import ContextAwareDataCollector
+from mlia.devices.tosa.operators import get_tosa_compatibility_info
+from mlia.devices.tosa.operators import TOSACompatibilityInfo
+from mlia.nn.tensorflow.config import get_tflite_model
+
+logger = logging.getLogger(__name__)
+
+
+class TOSAOperatorCompatibility(ContextAwareDataCollector):
+ """Collect operator compatibility information."""
+
+ def __init__(self, model: Path) -> None:
+ """Init the data collector."""
+ self.model = model
+
+ def collect_data(self) -> TOSACompatibilityInfo:
+ """Collect TOSA compatibility information."""
+ tflite_model = get_tflite_model(self.model, self.context)
+
+ logger.info("Checking operator compatibility ...")
+ tosa_info = get_tosa_compatibility_info(tflite_model.model_path)
+ logger.info("Done\n")
+
+ return tosa_info
+
+ @classmethod
+ def name(cls) -> str:
+ """Return name of the collector."""
+ return "tosa_operator_compatibility"
diff --git a/src/mlia/devices/tosa/events.py b/src/mlia/devices/tosa/events.py
new file mode 100644
index 0000000..ceaba57
--- /dev/null
+++ b/src/mlia/devices/tosa/events.py
@@ -0,0 +1,24 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""TOSA advisor events."""
+from dataclasses import dataclass
+from pathlib import Path
+
+from mlia.core.events import Event
+from mlia.core.events import EventDispatcher
+from mlia.devices.tosa.config import TOSAConfiguration
+
+
+@dataclass
+class TOSAAdvisorStartedEvent(Event):
+ """Event with TOSA advisor parameters."""
+
+ model: Path
+ device: TOSAConfiguration
+
+
+class TOSAAdvisorEventHandler(EventDispatcher):
+ """Event handler for the TOSA inference advisor."""
+
+ def on_tosa_advisor_started(self, event: TOSAAdvisorStartedEvent) -> None:
+ """Handle TOSAAdvisorStartedEvent event."""
diff --git a/src/mlia/devices/tosa/handlers.py b/src/mlia/devices/tosa/handlers.py
new file mode 100644
index 0000000..00c18c5
--- /dev/null
+++ b/src/mlia/devices/tosa/handlers.py
@@ -0,0 +1,35 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""TOSA Advisor event handlers."""
+# pylint: disable=R0801
+import logging
+from typing import Optional
+
+from mlia.core._typing import PathOrFileLike
+from mlia.core.events import CollectedDataEvent
+from mlia.core.handlers import WorkflowEventsHandler
+from mlia.devices.tosa.events import TOSAAdvisorEventHandler
+from mlia.devices.tosa.events import TOSAAdvisorStartedEvent
+from mlia.devices.tosa.operators import TOSACompatibilityInfo
+from mlia.devices.tosa.reporters import tosa_formatters
+
+logger = logging.getLogger(__name__)
+
+
+class TOSAEventHandler(WorkflowEventsHandler, TOSAAdvisorEventHandler):
+ """Event handler for TOSA advisor."""
+
+ def __init__(self, output: Optional[PathOrFileLike] = None) -> None:
+ """Init event handler."""
+ super().__init__(tosa_formatters, output)
+
+ def on_tosa_advisor_started(self, event: TOSAAdvisorStartedEvent) -> None:
+ """Handle TOSAAdvisorStartedEvent event."""
+ self.reporter.submit(event.device)
+
+ def on_collected_data(self, event: CollectedDataEvent) -> None:
+ """Handle CollectedDataEvent event."""
+ data_item = event.data_item
+
+ if isinstance(data_item, TOSACompatibilityInfo):
+ self.reporter.submit(data_item.operators, delay_print=True)
diff --git a/src/mlia/devices/tosa/operators.py b/src/mlia/devices/tosa/operators.py
new file mode 100644
index 0000000..4f3df10
--- /dev/null
+++ b/src/mlia/devices/tosa/operators.py
@@ -0,0 +1,70 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Operators module."""
+from dataclasses import dataclass
+from typing import Any
+from typing import cast
+from typing import List
+from typing import Optional
+from typing import Protocol
+
+from mlia.core._typing import PathOrFileLike
+
+
+class TOSAChecker(Protocol):
+ """TOSA checker protocol."""
+
+ def is_tosa_compatible(self) -> bool:
+ """Return true if model is TOSA compatible."""
+
+ def _get_tosa_compatibility_for_ops(self) -> List[Any]:
+ """Return list of operators."""
+
+
+@dataclass
+class Operator:
+ """Operator's TOSA compatibility info."""
+
+ location: str
+ name: str
+ is_tosa_compatible: bool
+
+
+@dataclass
+class TOSACompatibilityInfo:
+ """Models' TOSA compatibility information."""
+
+ tosa_compatible: bool
+ operators: List[Operator]
+
+
+def get_tosa_compatibility_info(
+ tflite_model_path: PathOrFileLike,
+) -> TOSACompatibilityInfo:
+ """Return list of the operators."""
+ checker = get_tosa_checker(tflite_model_path)
+
+ if checker is None:
+ raise Exception(
+ "TOSA checker is not available. "
+ "Please make sure that 'tosa_checker' package is installed: "
+ "pip install mlia[tosa]"
+ )
+
+ ops = [
+ Operator(item.location, item.name, item.is_tosa_compatible)
+ for item in checker._get_tosa_compatibility_for_ops() # pylint: disable=protected-access
+ ]
+
+ return TOSACompatibilityInfo(checker.is_tosa_compatible(), ops)
+
+
+def get_tosa_checker(tflite_model_path: PathOrFileLike) -> Optional[TOSAChecker]:
+ """Return instance of the TOSA checker."""
+ try:
+ import tosa_checker as tc # pylint: disable=import-outside-toplevel
+ except ImportError:
+ return None
+
+ checker = tc.TOSAChecker(str(tflite_model_path))
+ return cast(TOSAChecker, checker)
diff --git a/src/mlia/devices/tosa/reporters.py b/src/mlia/devices/tosa/reporters.py
new file mode 100644
index 0000000..8fba95c
--- /dev/null
+++ b/src/mlia/devices/tosa/reporters.py
@@ -0,0 +1,94 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Reports module."""
+from typing import Any
+from typing import Callable
+from typing import List
+
+from mlia.core.advice_generation import Advice
+from mlia.core.reporting import Cell
+from mlia.core.reporting import Column
+from mlia.core.reporting import Format
+from mlia.core.reporting import NestedReport
+from mlia.core.reporting import Report
+from mlia.core.reporting import ReportItem
+from mlia.core.reporting import Table
+from mlia.devices.tosa.config import TOSAConfiguration
+from mlia.devices.tosa.operators import Operator
+from mlia.utils.console import style_improvement
+from mlia.utils.types import is_list_of
+
+
+def report_device(device: TOSAConfiguration) -> Report:
+ """Generate report for the device."""
+ return NestedReport(
+ "Device information",
+ "device",
+ [
+ ReportItem("Target", alias="target", value=device.target),
+ ],
+ )
+
+
+def report_advice(advice: List[Advice]) -> Report:
+ """Generate report for the advice."""
+ return Table(
+ columns=[
+ Column("#", only_for=["plain_text"]),
+ Column("Advice", alias="advice_message"),
+ ],
+ rows=[(i + 1, a.messages) for i, a in enumerate(advice)],
+ name="Advice",
+ alias="advice",
+ )
+
+
+def report_tosa_operators(ops: List[Operator]) -> Report:
+ """Generate report for the operators."""
+ return Table(
+ [
+ Column("#", only_for=["plain_text"]),
+ Column(
+ "Operator location",
+ alias="operator_location",
+ fmt=Format(wrap_width=30),
+ ),
+ Column("Operator name", alias="operator_name", fmt=Format(wrap_width=20)),
+ Column(
+ "TOSA compatibility",
+ alias="is_tosa_compatible",
+ fmt=Format(wrap_width=25),
+ ),
+ ],
+ [
+ (
+ index + 1,
+ op.location,
+ op.name,
+ Cell(
+ op.is_tosa_compatible,
+ Format(
+ style=style_improvement(op.is_tosa_compatible),
+ str_fmt=lambda v: "Compatible" if v else "Not compatible",
+ ),
+ ),
+ )
+ for index, op in enumerate(ops)
+ ],
+ name="Operators",
+ alias="operators",
+ )
+
+
+def tosa_formatters(data: Any) -> Callable[[Any], Report]:
+ """Find appropriate formatter for the provided data."""
+ if is_list_of(data, Advice):
+ return report_advice
+
+ if isinstance(data, TOSAConfiguration):
+ return report_device
+
+ if is_list_of(data, Operator):
+ return report_tosa_operators
+
+ raise Exception(f"Unable to find appropriate formatter for {data}")
diff --git a/src/mlia/resources/profiles.json b/src/mlia/resources/profiles.json
index 4493d7b..b2a3351 100644
--- a/src/mlia/resources/profiles.json
+++ b/src/mlia/resources/profiles.json
@@ -16,5 +16,8 @@
"mac": 512,
"system_config": "Ethos_U65_High_End",
"memory_mode": "Dedicated_Sram"
+ },
+ "tosa": {
+ "target": "tosa"
}
}
diff --git a/tests/test_api.py b/tests/test_api.py
index e8df7af..7b567bf 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -12,6 +12,7 @@ from mlia.core.common import AdviceCategory
from mlia.core.context import Context
from mlia.core.context import ExecutionContext
from mlia.devices.ethosu.advisor import EthosUInferenceAdvisor
+from mlia.devices.tosa.advisor import TOSAInferenceAdvisor
def test_get_advice_no_target_provided(test_keras_model: Path) -> None:
@@ -103,3 +104,6 @@ def test_get_advisor(
ExecutionContext(), "ethos-u55-256", str(test_keras_model)
)
assert isinstance(ethos_u55_advisor, EthosUInferenceAdvisor)
+
+ tosa_advisor = get_advisor(ExecutionContext(), "tosa", str(test_keras_model))
+ assert isinstance(tosa_advisor, TOSAInferenceAdvisor)
diff --git a/tests/test_devices_tosa_advice_generation.py b/tests/test_devices_tosa_advice_generation.py
new file mode 100644
index 0000000..018ba57
--- /dev/null
+++ b/tests/test_devices_tosa_advice_generation.py
@@ -0,0 +1,56 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for advice generation."""
+from typing import List
+
+import pytest
+
+from mlia.core.advice_generation import Advice
+from mlia.core.common import AdviceCategory
+from mlia.core.common import DataItem
+from mlia.core.context import ExecutionContext
+from mlia.devices.tosa.advice_generation import TOSAAdviceProducer
+from mlia.devices.tosa.data_analysis import ModelIsNotTOSACompatible
+from mlia.devices.tosa.data_analysis import ModelIsTOSACompatible
+
+
+@pytest.mark.parametrize(
+ "input_data, advice_category, expected_advice",
+ [
+ [
+ ModelIsNotTOSACompatible(),
+ AdviceCategory.OPERATORS,
+ [
+ Advice(
+ [
+ "Some operators in the model are not TOSA compatible. "
+ "Please, refer to the operators table for more information."
+ ]
+ )
+ ],
+ ],
+ [
+ ModelIsTOSACompatible(),
+ AdviceCategory.OPERATORS,
+ [Advice(["Model is fully TOSA compatible."])],
+ ],
+ ],
+)
+def test_tosa_advice_producer(
+ tmpdir: str,
+ input_data: DataItem,
+ advice_category: AdviceCategory,
+ expected_advice: List[Advice],
+) -> None:
+ """Test TOSA advice producer."""
+ producer = TOSAAdviceProducer()
+
+ context = ExecutionContext(
+ advice_category=advice_category,
+ working_dir=tmpdir,
+ )
+
+ producer.set_context(context)
+ producer.produce_advice(input_data)
+
+ assert producer.get_advice() == expected_advice
diff --git a/tests/test_devices_tosa_advisor.py b/tests/test_devices_tosa_advisor.py
new file mode 100644
index 0000000..1c7a31a
--- /dev/null
+++ b/tests/test_devices_tosa_advisor.py
@@ -0,0 +1,29 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for TOSA advisor."""
+from pathlib import Path
+
+from mlia.core.context import ExecutionContext
+from mlia.core.workflow import DefaultWorkflowExecutor
+from mlia.devices.tosa.advisor import configure_and_get_tosa_advisor
+from mlia.devices.tosa.advisor import TOSAInferenceAdvisor
+
+
+def test_configure_and_get_tosa_advisor(test_tflite_model: Path) -> None:
+ """Test TOSA advisor configuration."""
+ ctx = ExecutionContext()
+
+ advisor = configure_and_get_tosa_advisor(ctx, "tosa", test_tflite_model)
+ workflow = advisor.configure(ctx)
+
+ assert isinstance(advisor, TOSAInferenceAdvisor)
+
+ assert ctx.event_handlers is not None
+ assert ctx.config_parameters == {
+ "tosa_inference_advisor": {
+ "model": str(test_tflite_model),
+ "target_profile": "tosa",
+ }
+ }
+
+ assert isinstance(workflow, DefaultWorkflowExecutor)
diff --git a/tests/test_devices_tosa_data_analysis.py b/tests/test_devices_tosa_data_analysis.py
new file mode 100644
index 0000000..60bcee8
--- /dev/null
+++ b/tests/test_devices_tosa_data_analysis.py
@@ -0,0 +1,33 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for TOSA data analysis module."""
+from typing import List
+
+import pytest
+
+from mlia.core.common import DataItem
+from mlia.core.data_analysis import Fact
+from mlia.devices.tosa.data_analysis import ModelIsNotTOSACompatible
+from mlia.devices.tosa.data_analysis import ModelIsTOSACompatible
+from mlia.devices.tosa.data_analysis import TOSADataAnalyzer
+from mlia.devices.tosa.operators import TOSACompatibilityInfo
+
+
+@pytest.mark.parametrize(
+ "input_data, expected_facts",
+ [
+ [
+ TOSACompatibilityInfo(True, []),
+ [ModelIsTOSACompatible()],
+ ],
+ [
+ TOSACompatibilityInfo(False, []),
+ [ModelIsNotTOSACompatible()],
+ ],
+ ],
+)
+def test_tosa_data_analyzer(input_data: DataItem, expected_facts: List[Fact]) -> None:
+ """Test TOSA data analyzer."""
+ analyzer = TOSADataAnalyzer()
+ analyzer.analyze_data(input_data)
+ assert analyzer.get_analyzed_data() == expected_facts
diff --git a/tests/test_devices_tosa_data_collection.py b/tests/test_devices_tosa_data_collection.py
new file mode 100644
index 0000000..b9c0b4c
--- /dev/null
+++ b/tests/test_devices_tosa_data_collection.py
@@ -0,0 +1,28 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for TOSA data collection module."""
+from pathlib import Path
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.core.context import ExecutionContext
+from mlia.devices.tosa.data_collection import TOSAOperatorCompatibility
+from mlia.devices.tosa.operators import TOSACompatibilityInfo
+
+
+def test_tosa_data_collection(
+ monkeypatch: pytest.MonkeyPatch, test_tflite_model: Path, tmpdir: str
+) -> None:
+ """Test TOSA data collection."""
+ monkeypatch.setattr(
+ "mlia.devices.tosa.data_collection.get_tosa_compatibility_info",
+ MagicMock(return_value=TOSACompatibilityInfo(True, [])),
+ )
+ context = ExecutionContext(working_dir=tmpdir)
+ collector = TOSAOperatorCompatibility(test_tflite_model)
+ collector.set_context(context)
+
+ data_item = collector.collect_data()
+
+ assert isinstance(data_item, TOSACompatibilityInfo)
diff --git a/tests/test_devices_tosa_operators.py b/tests/test_devices_tosa_operators.py
new file mode 100644
index 0000000..b7736d2
--- /dev/null
+++ b/tests/test_devices_tosa_operators.py
@@ -0,0 +1,84 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for TOSA compatibility."""
+from pathlib import Path
+from types import SimpleNamespace
+from typing import Any
+from typing import Optional
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.devices.tosa.operators import get_tosa_compatibility_info
+from mlia.devices.tosa.operators import Operator
+from mlia.devices.tosa.operators import TOSACompatibilityInfo
+
+
+def replace_get_tosa_checker_with_mock(
+ monkeypatch: pytest.MonkeyPatch, mock: Optional[MagicMock]
+) -> None:
+ """Replace TOSA checker with mock."""
+ monkeypatch.setattr(
+ "mlia.devices.tosa.operators.get_tosa_checker", MagicMock(return_value=mock)
+ )
+
+
+def test_compatibility_check_should_fail_if_checker_not_available(
+ monkeypatch: pytest.MonkeyPatch, test_tflite_model: Path
+) -> None:
+ """Test that compatibility check should fail if TOSA checker is not available."""
+ replace_get_tosa_checker_with_mock(monkeypatch, None)
+
+ with pytest.raises(Exception, match="TOSA checker is not available"):
+ get_tosa_compatibility_info(test_tflite_model)
+
+
+@pytest.mark.parametrize(
+ "is_tosa_compatible, operators, expected_result",
+ [
+ [
+ True,
+ [],
+ TOSACompatibilityInfo(True, []),
+ ],
+ [
+ True,
+ [
+ SimpleNamespace(
+ location="op_location",
+ name="op_name",
+ is_tosa_compatible=True,
+ )
+ ],
+ TOSACompatibilityInfo(True, [Operator("op_location", "op_name", True)]),
+ ],
+ [
+ False,
+ [
+ SimpleNamespace(
+ location="op_location",
+ name="op_name",
+ is_tosa_compatible=False,
+ )
+ ],
+ TOSACompatibilityInfo(False, [Operator("op_location", "op_name", False)]),
+ ],
+ ],
+)
+def test_get_tosa_compatibility_info(
+ monkeypatch: pytest.MonkeyPatch,
+ test_tflite_model: Path,
+ is_tosa_compatible: bool,
+ operators: Any,
+ expected_result: TOSACompatibilityInfo,
+) -> None:
+ """Test getting TOSA compatibility information."""
+ mock_checker = MagicMock()
+ mock_checker.is_tosa_compatible.return_value = is_tosa_compatible
+ mock_checker._get_tosa_compatibility_for_ops.return_value = ( # pylint: disable=protected-access
+ operators
+ )
+
+ replace_get_tosa_checker_with_mock(monkeypatch, mock_checker)
+
+ assert get_tosa_compatibility_info(test_tflite_model) == expected_result
diff --git a/tests/test_utils_filesystem.py b/tests/test_utils_filesystem.py
index 7cf32e7..fb894db 100644
--- a/tests/test_utils_filesystem.py
+++ b/tests/test_utils_filesystem.py
@@ -46,6 +46,7 @@ def test_profiles_data() -> None:
"ethos-u55-256",
"ethos-u55-128",
"ethos-u65-512",
+ "tosa",
]
@@ -72,6 +73,7 @@ def test_get_supported_profile_names() -> None:
"ethos-u55-256",
"ethos-u55-128",
"ethos-u65-512",
+ "tosa",
]