diff options
-rw-r--r-- | src/mlia/api.py | 2 | ||||
-rw-r--r-- | src/mlia/devices/tosa/__init__.py | 3 | ||||
-rw-r--r-- | src/mlia/devices/tosa/advice_generation.py | 40 | ||||
-rw-r--r-- | src/mlia/devices/tosa/advisor.py | 98 | ||||
-rw-r--r-- | src/mlia/devices/tosa/config.py | 19 | ||||
-rw-r--r-- | src/mlia/devices/tosa/data_analysis.py | 36 | ||||
-rw-r--r-- | src/mlia/devices/tosa/data_collection.py | 35 | ||||
-rw-r--r-- | src/mlia/devices/tosa/events.py | 24 | ||||
-rw-r--r-- | src/mlia/devices/tosa/handlers.py | 35 | ||||
-rw-r--r-- | src/mlia/devices/tosa/operators.py | 70 | ||||
-rw-r--r-- | src/mlia/devices/tosa/reporters.py | 94 | ||||
-rw-r--r-- | src/mlia/resources/profiles.json | 3 | ||||
-rw-r--r-- | tests/test_api.py | 4 | ||||
-rw-r--r-- | tests/test_devices_tosa_advice_generation.py | 56 | ||||
-rw-r--r-- | tests/test_devices_tosa_advisor.py | 29 | ||||
-rw-r--r-- | tests/test_devices_tosa_data_analysis.py | 33 | ||||
-rw-r--r-- | tests/test_devices_tosa_data_collection.py | 28 | ||||
-rw-r--r-- | tests/test_devices_tosa_operators.py | 84 | ||||
-rw-r--r-- | tests/test_utils_filesystem.py | 2 |
19 files changed, 695 insertions, 0 deletions
diff --git a/src/mlia/api.py b/src/mlia/api.py index 024bc98..c720b8d 100644 --- a/src/mlia/api.py +++ b/src/mlia/api.py @@ -15,6 +15,7 @@ from mlia.core.advisor import InferenceAdvisor from mlia.core.common import AdviceCategory from mlia.core.context import ExecutionContext from mlia.devices.ethosu.advisor import configure_and_get_ethosu_advisor +from mlia.devices.tosa.advisor import configure_and_get_tosa_advisor from mlia.utils.filesystem import get_target @@ -104,6 +105,7 @@ def get_advisor( target_factories = { "ethos-u55": configure_and_get_ethosu_advisor, "ethos-u65": configure_and_get_ethosu_advisor, + "tosa": configure_and_get_tosa_advisor, } try: diff --git a/src/mlia/devices/tosa/__init__.py b/src/mlia/devices/tosa/__init__.py new file mode 100644 index 0000000..762c831 --- /dev/null +++ b/src/mlia/devices/tosa/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""TOSA target module.""" diff --git a/src/mlia/devices/tosa/advice_generation.py b/src/mlia/devices/tosa/advice_generation.py new file mode 100644 index 0000000..7adfcb9 --- /dev/null +++ b/src/mlia/devices/tosa/advice_generation.py @@ -0,0 +1,40 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""TOSA advice generation.""" +from functools import singledispatchmethod + +from mlia.core.advice_generation import advice_category +from mlia.core.advice_generation import FactBasedAdviceProducer +from mlia.core.common import AdviceCategory +from mlia.core.common import DataItem +from mlia.devices.tosa.data_analysis import ModelIsNotTOSACompatible +from mlia.devices.tosa.data_analysis import ModelIsTOSACompatible + + +class TOSAAdviceProducer(FactBasedAdviceProducer): + """TOSA advice producer.""" + + @singledispatchmethod + def produce_advice(self, _data_item: DataItem) -> None: + """Produce advice.""" + + @produce_advice.register + @advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS) + def handle_model_is_tosa_compatible( + self, _data_item: ModelIsTOSACompatible + ) -> None: + """Advice for TOSA compatibility.""" + self.add_advice(["Model is fully TOSA compatible."]) + + @produce_advice.register + @advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS) + def handle_model_is_not_tosa_compatible( + self, _data_item: ModelIsNotTOSACompatible + ) -> None: + """Advice for TOSA compatibility.""" + self.add_advice( + [ + "Some operators in the model are not TOSA compatible. " + "Please, refer to the operators table for more information." + ] + ) diff --git a/src/mlia/devices/tosa/advisor.py b/src/mlia/devices/tosa/advisor.py new file mode 100644 index 0000000..6a32b94 --- /dev/null +++ b/src/mlia/devices/tosa/advisor.py @@ -0,0 +1,98 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""TOSA advisor.""" +from pathlib import Path +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Union + +from mlia.core._typing import PathOrFileLike +from mlia.core.advice_generation import AdviceCategory +from mlia.core.advice_generation import AdviceProducer +from mlia.core.advisor import DefaultInferenceAdvisor +from mlia.core.advisor import InferenceAdvisor +from mlia.core.context import Context +from mlia.core.context import ExecutionContext +from mlia.core.data_analysis import DataAnalyzer +from mlia.core.data_collection import DataCollector +from mlia.core.events import Event +from mlia.devices.tosa.advice_generation import TOSAAdviceProducer +from mlia.devices.tosa.config import TOSAConfiguration +from mlia.devices.tosa.data_analysis import TOSADataAnalyzer +from mlia.devices.tosa.data_collection import TOSAOperatorCompatibility +from mlia.devices.tosa.events import TOSAAdvisorStartedEvent +from mlia.devices.tosa.handlers import TOSAEventHandler + + +class TOSAInferenceAdvisor(DefaultInferenceAdvisor): + """TOSA inference advisor.""" + + @classmethod + def name(cls) -> str: + """Return name of the advisor.""" + return "tosa_inference_advisor" + + def get_collectors(self, context: Context) -> List[DataCollector]: + """Return list of the data collectors.""" + model = self.get_model(context) + + collectors: List[DataCollector] = [] + + if AdviceCategory.OPERATORS in context.advice_category: + collectors.append(TOSAOperatorCompatibility(model)) + + return collectors + + def get_analyzers(self, context: Context) -> List[DataAnalyzer]: + """Return list of the data analyzers.""" + return [ + TOSADataAnalyzer(), + ] + + def get_producers(self, context: Context) -> List[AdviceProducer]: + """Return list of the advice producers.""" + return [ + TOSAAdviceProducer(), + ] + + def get_events(self, context: Context) -> List[Event]: + """Return list of the startup events.""" + model = self.get_model(context) + target_profile = self.get_target_profile(context) + + return [ + TOSAAdvisorStartedEvent(model, TOSAConfiguration(target_profile)), + ] + + +def configure_and_get_tosa_advisor( + context: ExecutionContext, + target_profile: str, + model: Union[Path, str], + output: Optional[PathOrFileLike] = None, + **_extra_args: Any +) -> InferenceAdvisor: + """Create and configure TOSA advisor.""" + if context.event_handlers is None: + context.event_handlers = [TOSAEventHandler(output)] + + if context.config_parameters is None: + context.config_parameters = _get_config_parameters(model, target_profile) + + return TOSAInferenceAdvisor() + + +def _get_config_parameters( + model: Union[Path, str], target_profile: str +) -> Dict[str, Any]: + """Get configuration parameters for the advisor.""" + advisor_parameters: Dict[str, Any] = { + "tosa_inference_advisor": { + "model": str(model), + "target_profile": target_profile, + } + } + + return advisor_parameters diff --git a/src/mlia/devices/tosa/config.py b/src/mlia/devices/tosa/config.py new file mode 100644 index 0000000..c3879a7 --- /dev/null +++ b/src/mlia/devices/tosa/config.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""TOSA target configuration.""" +from mlia.devices.config import IPConfiguration +from mlia.utils.filesystem import get_profile + + +class TOSAConfiguration(IPConfiguration): # pylint: disable=too-few-public-methods + """TOSA configuration.""" + + def __init__(self, target_profile: str) -> None: + """Init configuration.""" + target_data = get_profile(target_profile) + target = target_data["target"] + + if target != "tosa": + raise Exception(f"Wrong target {target} for TOSA configuration") + + super().__init__(target) diff --git a/src/mlia/devices/tosa/data_analysis.py b/src/mlia/devices/tosa/data_analysis.py new file mode 100644 index 0000000..aa696a5 --- /dev/null +++ b/src/mlia/devices/tosa/data_analysis.py @@ -0,0 +1,36 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""TOSA data analysis module.""" +from dataclasses import dataclass +from functools import singledispatchmethod + +from mlia.core.common import DataItem +from mlia.core.data_analysis import Fact +from mlia.core.data_analysis import FactExtractor +from mlia.devices.tosa.operators import TOSACompatibilityInfo + + +@dataclass +class ModelIsTOSACompatible(Fact): + """Model is completely TOSA compatible.""" + + +@dataclass +class ModelIsNotTOSACompatible(Fact): + """Model is not TOSA compatible.""" + + +class TOSADataAnalyzer(FactExtractor): + """TOSA data analyzer.""" + + @singledispatchmethod + def analyze_data(self, data_item: DataItem) -> None: + """Analyse the data.""" + + @analyze_data.register + def analyze_tosa_compatibility(self, data_item: TOSACompatibilityInfo) -> None: + """Analyse TOSA compatibility information.""" + if data_item.tosa_compatible: + self.add_fact(ModelIsTOSACompatible()) + else: + self.add_fact(ModelIsNotTOSACompatible()) diff --git a/src/mlia/devices/tosa/data_collection.py b/src/mlia/devices/tosa/data_collection.py new file mode 100644 index 0000000..843d5ab --- /dev/null +++ b/src/mlia/devices/tosa/data_collection.py @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""TOSA data collection module.""" +import logging +from pathlib import Path + +from mlia.core.data_collection import ContextAwareDataCollector +from mlia.devices.tosa.operators import get_tosa_compatibility_info +from mlia.devices.tosa.operators import TOSACompatibilityInfo +from mlia.nn.tensorflow.config import get_tflite_model + +logger = logging.getLogger(__name__) + + +class TOSAOperatorCompatibility(ContextAwareDataCollector): + """Collect operator compatibility information.""" + + def __init__(self, model: Path) -> None: + """Init the data collector.""" + self.model = model + + def collect_data(self) -> TOSACompatibilityInfo: + """Collect TOSA compatibility information.""" + tflite_model = get_tflite_model(self.model, self.context) + + logger.info("Checking operator compatibility ...") + tosa_info = get_tosa_compatibility_info(tflite_model.model_path) + logger.info("Done\n") + + return tosa_info + + @classmethod + def name(cls) -> str: + """Return name of the collector.""" + return "tosa_operator_compatibility" diff --git a/src/mlia/devices/tosa/events.py b/src/mlia/devices/tosa/events.py new file mode 100644 index 0000000..ceaba57 --- /dev/null +++ b/src/mlia/devices/tosa/events.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""TOSA advisor events.""" +from dataclasses import dataclass +from pathlib import Path + +from mlia.core.events import Event +from mlia.core.events import EventDispatcher +from mlia.devices.tosa.config import TOSAConfiguration + + +@dataclass +class TOSAAdvisorStartedEvent(Event): + """Event with TOSA advisor parameters.""" + + model: Path + device: TOSAConfiguration + + +class TOSAAdvisorEventHandler(EventDispatcher): + """Event handler for the TOSA inference advisor.""" + + def on_tosa_advisor_started(self, event: TOSAAdvisorStartedEvent) -> None: + """Handle TOSAAdvisorStartedEvent event.""" diff --git a/src/mlia/devices/tosa/handlers.py b/src/mlia/devices/tosa/handlers.py new file mode 100644 index 0000000..00c18c5 --- /dev/null +++ b/src/mlia/devices/tosa/handlers.py @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""TOSA Advisor event handlers.""" +# pylint: disable=R0801 +import logging +from typing import Optional + +from mlia.core._typing import PathOrFileLike +from mlia.core.events import CollectedDataEvent +from mlia.core.handlers import WorkflowEventsHandler +from mlia.devices.tosa.events import TOSAAdvisorEventHandler +from mlia.devices.tosa.events import TOSAAdvisorStartedEvent +from mlia.devices.tosa.operators import TOSACompatibilityInfo +from mlia.devices.tosa.reporters import tosa_formatters + +logger = logging.getLogger(__name__) + + +class TOSAEventHandler(WorkflowEventsHandler, TOSAAdvisorEventHandler): + """Event handler for TOSA advisor.""" + + def __init__(self, output: Optional[PathOrFileLike] = None) -> None: + """Init event handler.""" + super().__init__(tosa_formatters, output) + + def on_tosa_advisor_started(self, event: TOSAAdvisorStartedEvent) -> None: + """Handle TOSAAdvisorStartedEvent event.""" + self.reporter.submit(event.device) + + def on_collected_data(self, event: CollectedDataEvent) -> None: + """Handle CollectedDataEvent event.""" + data_item = event.data_item + + if isinstance(data_item, TOSACompatibilityInfo): + self.reporter.submit(data_item.operators, delay_print=True) diff --git a/src/mlia/devices/tosa/operators.py b/src/mlia/devices/tosa/operators.py new file mode 100644 index 0000000..4f3df10 --- /dev/null +++ b/src/mlia/devices/tosa/operators.py @@ -0,0 +1,70 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Operators module.""" +from dataclasses import dataclass +from typing import Any +from typing import cast +from typing import List +from typing import Optional +from typing import Protocol + +from mlia.core._typing import PathOrFileLike + + +class TOSAChecker(Protocol): + """TOSA checker protocol.""" + + def is_tosa_compatible(self) -> bool: + """Return true if model is TOSA compatible.""" + + def _get_tosa_compatibility_for_ops(self) -> List[Any]: + """Return list of operators.""" + + +@dataclass +class Operator: + """Operator's TOSA compatibility info.""" + + location: str + name: str + is_tosa_compatible: bool + + +@dataclass +class TOSACompatibilityInfo: + """Models' TOSA compatibility information.""" + + tosa_compatible: bool + operators: List[Operator] + + +def get_tosa_compatibility_info( + tflite_model_path: PathOrFileLike, +) -> TOSACompatibilityInfo: + """Return list of the operators.""" + checker = get_tosa_checker(tflite_model_path) + + if checker is None: + raise Exception( + "TOSA checker is not available. " + "Please make sure that 'tosa_checker' package is installed: " + "pip install mlia[tosa]" + ) + + ops = [ + Operator(item.location, item.name, item.is_tosa_compatible) + for item in checker._get_tosa_compatibility_for_ops() # pylint: disable=protected-access + ] + + return TOSACompatibilityInfo(checker.is_tosa_compatible(), ops) + + +def get_tosa_checker(tflite_model_path: PathOrFileLike) -> Optional[TOSAChecker]: + """Return instance of the TOSA checker.""" + try: + import tosa_checker as tc # pylint: disable=import-outside-toplevel + except ImportError: + return None + + checker = tc.TOSAChecker(str(tflite_model_path)) + return cast(TOSAChecker, checker) diff --git a/src/mlia/devices/tosa/reporters.py b/src/mlia/devices/tosa/reporters.py new file mode 100644 index 0000000..8fba95c --- /dev/null +++ b/src/mlia/devices/tosa/reporters.py @@ -0,0 +1,94 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Reports module.""" +from typing import Any +from typing import Callable +from typing import List + +from mlia.core.advice_generation import Advice +from mlia.core.reporting import Cell +from mlia.core.reporting import Column +from mlia.core.reporting import Format +from mlia.core.reporting import NestedReport +from mlia.core.reporting import Report +from mlia.core.reporting import ReportItem +from mlia.core.reporting import Table +from mlia.devices.tosa.config import TOSAConfiguration +from mlia.devices.tosa.operators import Operator +from mlia.utils.console import style_improvement +from mlia.utils.types import is_list_of + + +def report_device(device: TOSAConfiguration) -> Report: + """Generate report for the device.""" + return NestedReport( + "Device information", + "device", + [ + ReportItem("Target", alias="target", value=device.target), + ], + ) + + +def report_advice(advice: List[Advice]) -> Report: + """Generate report for the advice.""" + return Table( + columns=[ + Column("#", only_for=["plain_text"]), + Column("Advice", alias="advice_message"), + ], + rows=[(i + 1, a.messages) for i, a in enumerate(advice)], + name="Advice", + alias="advice", + ) + + +def report_tosa_operators(ops: List[Operator]) -> Report: + """Generate report for the operators.""" + return Table( + [ + Column("#", only_for=["plain_text"]), + Column( + "Operator location", + alias="operator_location", + fmt=Format(wrap_width=30), + ), + Column("Operator name", alias="operator_name", fmt=Format(wrap_width=20)), + Column( + "TOSA compatibility", + alias="is_tosa_compatible", + fmt=Format(wrap_width=25), + ), + ], + [ + ( + index + 1, + op.location, + op.name, + Cell( + op.is_tosa_compatible, + Format( + style=style_improvement(op.is_tosa_compatible), + str_fmt=lambda v: "Compatible" if v else "Not compatible", + ), + ), + ) + for index, op in enumerate(ops) + ], + name="Operators", + alias="operators", + ) + + +def tosa_formatters(data: Any) -> Callable[[Any], Report]: + """Find appropriate formatter for the provided data.""" + if is_list_of(data, Advice): + return report_advice + + if isinstance(data, TOSAConfiguration): + return report_device + + if is_list_of(data, Operator): + return report_tosa_operators + + raise Exception(f"Unable to find appropriate formatter for {data}") diff --git a/src/mlia/resources/profiles.json b/src/mlia/resources/profiles.json index 4493d7b..b2a3351 100644 --- a/src/mlia/resources/profiles.json +++ b/src/mlia/resources/profiles.json @@ -16,5 +16,8 @@ "mac": 512, "system_config": "Ethos_U65_High_End", "memory_mode": "Dedicated_Sram" + }, + "tosa": { + "target": "tosa" } } diff --git a/tests/test_api.py b/tests/test_api.py index e8df7af..7b567bf 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -12,6 +12,7 @@ from mlia.core.common import AdviceCategory from mlia.core.context import Context from mlia.core.context import ExecutionContext from mlia.devices.ethosu.advisor import EthosUInferenceAdvisor +from mlia.devices.tosa.advisor import TOSAInferenceAdvisor def test_get_advice_no_target_provided(test_keras_model: Path) -> None: @@ -103,3 +104,6 @@ def test_get_advisor( ExecutionContext(), "ethos-u55-256", str(test_keras_model) ) assert isinstance(ethos_u55_advisor, EthosUInferenceAdvisor) + + tosa_advisor = get_advisor(ExecutionContext(), "tosa", str(test_keras_model)) + assert isinstance(tosa_advisor, TOSAInferenceAdvisor) diff --git a/tests/test_devices_tosa_advice_generation.py b/tests/test_devices_tosa_advice_generation.py new file mode 100644 index 0000000..018ba57 --- /dev/null +++ b/tests/test_devices_tosa_advice_generation.py @@ -0,0 +1,56 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for advice generation.""" +from typing import List + +import pytest + +from mlia.core.advice_generation import Advice +from mlia.core.common import AdviceCategory +from mlia.core.common import DataItem +from mlia.core.context import ExecutionContext +from mlia.devices.tosa.advice_generation import TOSAAdviceProducer +from mlia.devices.tosa.data_analysis import ModelIsNotTOSACompatible +from mlia.devices.tosa.data_analysis import ModelIsTOSACompatible + + +@pytest.mark.parametrize( + "input_data, advice_category, expected_advice", + [ + [ + ModelIsNotTOSACompatible(), + AdviceCategory.OPERATORS, + [ + Advice( + [ + "Some operators in the model are not TOSA compatible. " + "Please, refer to the operators table for more information." + ] + ) + ], + ], + [ + ModelIsTOSACompatible(), + AdviceCategory.OPERATORS, + [Advice(["Model is fully TOSA compatible."])], + ], + ], +) +def test_tosa_advice_producer( + tmpdir: str, + input_data: DataItem, + advice_category: AdviceCategory, + expected_advice: List[Advice], +) -> None: + """Test TOSA advice producer.""" + producer = TOSAAdviceProducer() + + context = ExecutionContext( + advice_category=advice_category, + working_dir=tmpdir, + ) + + producer.set_context(context) + producer.produce_advice(input_data) + + assert producer.get_advice() == expected_advice diff --git a/tests/test_devices_tosa_advisor.py b/tests/test_devices_tosa_advisor.py new file mode 100644 index 0000000..1c7a31a --- /dev/null +++ b/tests/test_devices_tosa_advisor.py @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for TOSA advisor.""" +from pathlib import Path + +from mlia.core.context import ExecutionContext +from mlia.core.workflow import DefaultWorkflowExecutor +from mlia.devices.tosa.advisor import configure_and_get_tosa_advisor +from mlia.devices.tosa.advisor import TOSAInferenceAdvisor + + +def test_configure_and_get_tosa_advisor(test_tflite_model: Path) -> None: + """Test TOSA advisor configuration.""" + ctx = ExecutionContext() + + advisor = configure_and_get_tosa_advisor(ctx, "tosa", test_tflite_model) + workflow = advisor.configure(ctx) + + assert isinstance(advisor, TOSAInferenceAdvisor) + + assert ctx.event_handlers is not None + assert ctx.config_parameters == { + "tosa_inference_advisor": { + "model": str(test_tflite_model), + "target_profile": "tosa", + } + } + + assert isinstance(workflow, DefaultWorkflowExecutor) diff --git a/tests/test_devices_tosa_data_analysis.py b/tests/test_devices_tosa_data_analysis.py new file mode 100644 index 0000000..60bcee8 --- /dev/null +++ b/tests/test_devices_tosa_data_analysis.py @@ -0,0 +1,33 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for TOSA data analysis module.""" +from typing import List + +import pytest + +from mlia.core.common import DataItem +from mlia.core.data_analysis import Fact +from mlia.devices.tosa.data_analysis import ModelIsNotTOSACompatible +from mlia.devices.tosa.data_analysis import ModelIsTOSACompatible +from mlia.devices.tosa.data_analysis import TOSADataAnalyzer +from mlia.devices.tosa.operators import TOSACompatibilityInfo + + +@pytest.mark.parametrize( + "input_data, expected_facts", + [ + [ + TOSACompatibilityInfo(True, []), + [ModelIsTOSACompatible()], + ], + [ + TOSACompatibilityInfo(False, []), + [ModelIsNotTOSACompatible()], + ], + ], +) +def test_tosa_data_analyzer(input_data: DataItem, expected_facts: List[Fact]) -> None: + """Test TOSA data analyzer.""" + analyzer = TOSADataAnalyzer() + analyzer.analyze_data(input_data) + assert analyzer.get_analyzed_data() == expected_facts diff --git a/tests/test_devices_tosa_data_collection.py b/tests/test_devices_tosa_data_collection.py new file mode 100644 index 0000000..b9c0b4c --- /dev/null +++ b/tests/test_devices_tosa_data_collection.py @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for TOSA data collection module.""" +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + +from mlia.core.context import ExecutionContext +from mlia.devices.tosa.data_collection import TOSAOperatorCompatibility +from mlia.devices.tosa.operators import TOSACompatibilityInfo + + +def test_tosa_data_collection( + monkeypatch: pytest.MonkeyPatch, test_tflite_model: Path, tmpdir: str +) -> None: + """Test TOSA data collection.""" + monkeypatch.setattr( + "mlia.devices.tosa.data_collection.get_tosa_compatibility_info", + MagicMock(return_value=TOSACompatibilityInfo(True, [])), + ) + context = ExecutionContext(working_dir=tmpdir) + collector = TOSAOperatorCompatibility(test_tflite_model) + collector.set_context(context) + + data_item = collector.collect_data() + + assert isinstance(data_item, TOSACompatibilityInfo) diff --git a/tests/test_devices_tosa_operators.py b/tests/test_devices_tosa_operators.py new file mode 100644 index 0000000..b7736d2 --- /dev/null +++ b/tests/test_devices_tosa_operators.py @@ -0,0 +1,84 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for TOSA compatibility.""" +from pathlib import Path +from types import SimpleNamespace +from typing import Any +from typing import Optional +from unittest.mock import MagicMock + +import pytest + +from mlia.devices.tosa.operators import get_tosa_compatibility_info +from mlia.devices.tosa.operators import Operator +from mlia.devices.tosa.operators import TOSACompatibilityInfo + + +def replace_get_tosa_checker_with_mock( + monkeypatch: pytest.MonkeyPatch, mock: Optional[MagicMock] +) -> None: + """Replace TOSA checker with mock.""" + monkeypatch.setattr( + "mlia.devices.tosa.operators.get_tosa_checker", MagicMock(return_value=mock) + ) + + +def test_compatibility_check_should_fail_if_checker_not_available( + monkeypatch: pytest.MonkeyPatch, test_tflite_model: Path +) -> None: + """Test that compatibility check should fail if TOSA checker is not available.""" + replace_get_tosa_checker_with_mock(monkeypatch, None) + + with pytest.raises(Exception, match="TOSA checker is not available"): + get_tosa_compatibility_info(test_tflite_model) + + +@pytest.mark.parametrize( + "is_tosa_compatible, operators, expected_result", + [ + [ + True, + [], + TOSACompatibilityInfo(True, []), + ], + [ + True, + [ + SimpleNamespace( + location="op_location", + name="op_name", + is_tosa_compatible=True, + ) + ], + TOSACompatibilityInfo(True, [Operator("op_location", "op_name", True)]), + ], + [ + False, + [ + SimpleNamespace( + location="op_location", + name="op_name", + is_tosa_compatible=False, + ) + ], + TOSACompatibilityInfo(False, [Operator("op_location", "op_name", False)]), + ], + ], +) +def test_get_tosa_compatibility_info( + monkeypatch: pytest.MonkeyPatch, + test_tflite_model: Path, + is_tosa_compatible: bool, + operators: Any, + expected_result: TOSACompatibilityInfo, +) -> None: + """Test getting TOSA compatibility information.""" + mock_checker = MagicMock() + mock_checker.is_tosa_compatible.return_value = is_tosa_compatible + mock_checker._get_tosa_compatibility_for_ops.return_value = ( # pylint: disable=protected-access + operators + ) + + replace_get_tosa_checker_with_mock(monkeypatch, mock_checker) + + assert get_tosa_compatibility_info(test_tflite_model) == expected_result diff --git a/tests/test_utils_filesystem.py b/tests/test_utils_filesystem.py index 7cf32e7..fb894db 100644 --- a/tests/test_utils_filesystem.py +++ b/tests/test_utils_filesystem.py @@ -46,6 +46,7 @@ def test_profiles_data() -> None: "ethos-u55-256", "ethos-u55-128", "ethos-u65-512", + "tosa", ] @@ -72,6 +73,7 @@ def test_get_supported_profile_names() -> None: "ethos-u55-256", "ethos-u55-128", "ethos-u65-512", + "tosa", ] |