diff options
author | Raul Farkas <raul.farkas@arm.com> | 2022-10-12 10:48:35 +0100 |
---|---|---|
committer | Raul Farkas <raul.farkas@arm.com> | 2022-10-18 14:29:23 +0100 |
commit | 4fa21325ec498adbf467876c2413c054d0e85c5b (patch) | |
tree | 892d68467bc6666e9daf74ab4d141810dcec1ac6 | |
parent | 89c1f4bafb51dbbed705b6960810d90825318b13 (diff) | |
download | mlia-4fa21325ec498adbf467876c2413c054d0e85c5b.tar.gz |
MLIA-409 Create new Cortex-A device skeleton
* Add Cortex-A device skeleton
* Add unit tests for the Cortex-A device skeleton
* Update profiles.json by adding the new "cortex-a" profile
* Add new cortex-a factory to the get_advisor method in api.py
* Disable performance and optimization commands for the cortex-a
profile.
* Update trademarks section in README.md
* Update pyproject.toml to not run similarity check in imports
Change-Id: I2e228aaada1e2d3c5cc329d70572b51962ff517f
-rw-r--r-- | README.md | 2 | ||||
-rw-r--r-- | pyproject.toml | 3 | ||||
-rw-r--r-- | src/mlia/api.py | 2 | ||||
-rw-r--r-- | src/mlia/cli/main.py | 4 | ||||
-rw-r--r-- | src/mlia/devices/cortexa/__init__.py | 3 | ||||
-rw-r--r-- | src/mlia/devices/cortexa/advice_generation.py | 40 | ||||
-rw-r--r-- | src/mlia/devices/cortexa/advisor.py | 94 | ||||
-rw-r--r-- | src/mlia/devices/cortexa/config.py | 20 | ||||
-rw-r--r-- | src/mlia/devices/cortexa/data_analysis.py | 38 | ||||
-rw-r--r-- | src/mlia/devices/cortexa/data_collection.py | 36 | ||||
-rw-r--r-- | src/mlia/devices/cortexa/events.py | 24 | ||||
-rw-r--r-- | src/mlia/devices/cortexa/handlers.py | 35 | ||||
-rw-r--r-- | src/mlia/devices/cortexa/operators.py | 29 | ||||
-rw-r--r-- | src/mlia/devices/cortexa/reporters.py | 42 | ||||
-rw-r--r-- | src/mlia/resources/profiles.json | 3 | ||||
-rw-r--r-- | tests/test_devices_cortex_a_advice_generation.py | 56 | ||||
-rw-r--r-- | tests/test_devices_cortex_a_data_analysis.py | 35 | ||||
-rw-r--r-- | tests/test_devices_cortex_a_data_collection.py | 28 | ||||
-rw-r--r-- | tests/test_devices_cortexa_advisor.py | 34 | ||||
-rw-r--r-- | tests/test_utils_filesystem.py | 2 |
20 files changed, 527 insertions, 3 deletions
@@ -405,7 +405,7 @@ ML Inference Advisor is licensed under [Apache License 2.0](LICENSE.txt). ## Trademarks and copyrights -* Arm®, Ethos™-U, Cortex®-M, Corstone™ are registered trademarks or +* Arm®, Ethos™-U, Cortex®-A, Cortex®-M, Corstone™ are registered trademarks or trademarks of Arm® Limited (or its subsidiaries) in the U.S. and/or elsewhere. * TensorFlow™ is a trademark of Google® LLC. diff --git a/pyproject.toml b/pyproject.toml index 1dcbf21..52b7f41 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,9 @@ enable = [ "line-too-long" # C0301 ] +[tool.pylint.similarities] +ignore-imports = true + [tool.mypy] # Suppresses error messages about imports that cannot be resolved ignore_missing_imports = true diff --git a/src/mlia/api.py b/src/mlia/api.py index 878e316..fc61af0 100644 --- a/src/mlia/api.py +++ b/src/mlia/api.py @@ -12,6 +12,7 @@ from mlia.core.advisor import InferenceAdvisor from mlia.core.common import AdviceCategory from mlia.core.context import ExecutionContext from mlia.core.typing import PathOrFileLike +from mlia.devices.cortexa.advisor import configure_and_get_cortexa_advisor from mlia.devices.ethosu.advisor import configure_and_get_ethosu_advisor from mlia.devices.tosa.advisor import configure_and_get_tosa_advisor from mlia.utils.filesystem import get_target @@ -104,6 +105,7 @@ def get_advisor( "ethos-u55": configure_and_get_ethosu_advisor, "ethos-u65": configure_and_get_ethosu_advisor, "tosa": configure_and_get_tosa_advisor, + "cortex-a": configure_and_get_cortexa_advisor, } try: diff --git a/src/mlia/cli/main.py b/src/mlia/cli/main.py index 0ece289..bafe434 100644 --- a/src/mlia/cli/main.py +++ b/src/mlia/cli/main.py @@ -79,7 +79,7 @@ def get_commands() -> list[CommandInfo]: performance, ["perf"], [ - partial(add_target_options, profiles_to_skip=["tosa"]), + partial(add_target_options, profiles_to_skip=["tosa", "cortex-a"]), add_tflite_model_options, add_output_options, add_debug_options, @@ -90,7 +90,7 @@ def get_commands() -> list[CommandInfo]: optimization, ["opt"], [ - partial(add_target_options, profiles_to_skip=["tosa"]), + partial(add_target_options, profiles_to_skip=["tosa", "cortex-a"]), add_keras_model_options, add_multi_optimization_options, add_output_options, diff --git a/src/mlia/devices/cortexa/__init__.py b/src/mlia/devices/cortexa/__init__.py new file mode 100644 index 0000000..3a987e7 --- /dev/null +++ b/src/mlia/devices/cortexa/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Cortex-A devices module.""" diff --git a/src/mlia/devices/cortexa/advice_generation.py b/src/mlia/devices/cortexa/advice_generation.py new file mode 100644 index 0000000..33d5a5f --- /dev/null +++ b/src/mlia/devices/cortexa/advice_generation.py @@ -0,0 +1,40 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Cortex-A advice generation.""" +from functools import singledispatchmethod + +from mlia.core.advice_generation import advice_category +from mlia.core.advice_generation import FactBasedAdviceProducer +from mlia.core.common import AdviceCategory +from mlia.core.common import DataItem +from mlia.devices.cortexa.data_analysis import ModelIsCortexACompatible +from mlia.devices.cortexa.data_analysis import ModelIsNotCortexACompatible + + +class CortexAAdviceProducer(FactBasedAdviceProducer): + """Cortex-A advice producer.""" + + @singledispatchmethod + def produce_advice(self, _data_item: DataItem) -> None: + """Produce advice.""" + + @produce_advice.register + @advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS) + def handle_model_is_cortex_a_compatible( + self, _data_item: ModelIsCortexACompatible + ) -> None: + """Advice for Cortex-A compatibility.""" + self.add_advice(["Model is fully compatible with Cortex-A."]) + + @produce_advice.register + @advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS) + def handle_model_is_not_cortex_a_compatible( + self, _data_item: ModelIsNotCortexACompatible + ) -> None: + """Advice for Cortex-A compatibility.""" + self.add_advice( + [ + "Some operators in the model are not compatible with Cortex-A. " + "Please, refer to the operators table for more information." + ] + ) diff --git a/src/mlia/devices/cortexa/advisor.py b/src/mlia/devices/cortexa/advisor.py new file mode 100644 index 0000000..98c155b --- /dev/null +++ b/src/mlia/devices/cortexa/advisor.py @@ -0,0 +1,94 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Cortex-A MLIA module.""" +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from mlia.core.advice_generation import AdviceProducer +from mlia.core.advisor import DefaultInferenceAdvisor +from mlia.core.advisor import InferenceAdvisor +from mlia.core.common import AdviceCategory +from mlia.core.context import Context +from mlia.core.context import ExecutionContext +from mlia.core.data_analysis import DataAnalyzer +from mlia.core.data_collection import DataCollector +from mlia.core.events import Event +from mlia.core.typing import PathOrFileLike +from mlia.devices.cortexa.advice_generation import CortexAAdviceProducer +from mlia.devices.cortexa.config import CortexAConfiguration +from mlia.devices.cortexa.data_analysis import CortexADataAnalyzer +from mlia.devices.cortexa.data_collection import CortexAOperatorCompatibility +from mlia.devices.cortexa.events import CortexAAdvisorStartedEvent +from mlia.devices.cortexa.handlers import CortexAEventHandler + + +class CortexAInferenceAdvisor(DefaultInferenceAdvisor): + """Cortex-A Inference Advisor.""" + + @classmethod + def name(cls) -> str: + """Return name of the advisor.""" + return "cortex_a_inference_advisor" + + def get_collectors(self, context: Context) -> list[DataCollector]: + """Return list of the data collectors.""" + model = self.get_model(context) + + collectors: list[DataCollector] = [] + + if AdviceCategory.OPERATORS in context.advice_category: + collectors.append(CortexAOperatorCompatibility(model)) + + return collectors + + def get_analyzers(self, context: Context) -> list[DataAnalyzer]: + """Return list of the data analyzers.""" + return [ + CortexADataAnalyzer(), + ] + + def get_producers(self, context: Context) -> list[AdviceProducer]: + """Return list of the advice producers.""" + return [CortexAAdviceProducer()] + + def get_events(self, context: Context) -> list[Event]: + """Return list of the startup events.""" + model = self.get_model(context) + target_profile = self.get_target_profile(context) + + return [ + CortexAAdvisorStartedEvent(model, CortexAConfiguration(target_profile)), + ] + + +def configure_and_get_cortexa_advisor( + context: ExecutionContext, + target_profile: str, + model: str | Path, + output: PathOrFileLike | None = None, + **extra_args: Any, +) -> InferenceAdvisor: + """Create and configure Cortex-A advisor.""" + if context.event_handlers is None: + context.event_handlers = [CortexAEventHandler(output)] + + if context.config_parameters is None: + context.config_parameters = _get_config_parameters( + model, target_profile, **extra_args + ) + + return CortexAInferenceAdvisor() + + +def _get_config_parameters(model: str | Path, target_profile: str) -> dict[str, Any]: + """Get configuration parameters for the advisor.""" + advisor_parameters: dict[str, Any] = { + "cortex_a_inference_advisor": { + "model": str(model), + "target_profile": target_profile, + }, + } + + return advisor_parameters diff --git a/src/mlia/devices/cortexa/config.py b/src/mlia/devices/cortexa/config.py new file mode 100644 index 0000000..ec0cf0a --- /dev/null +++ b/src/mlia/devices/cortexa/config.py @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Cortex-A configuration.""" +from __future__ import annotations + +from mlia.devices.config import IPConfiguration +from mlia.utils.filesystem import get_profile + + +class CortexAConfiguration(IPConfiguration): # pylint: disable=too-few-public-methods + """Cortex-A configuration.""" + + def __init__(self, target_profile: str) -> None: + """Init Cortex-A target configuration.""" + target_data = get_profile(target_profile) + + target = target_data["target"] + if target != "cortex-a": + raise Exception(f"Wrong target {target} for Cortex-A configuration") + super().__init__(target) diff --git a/src/mlia/devices/cortexa/data_analysis.py b/src/mlia/devices/cortexa/data_analysis.py new file mode 100644 index 0000000..dff95ce --- /dev/null +++ b/src/mlia/devices/cortexa/data_analysis.py @@ -0,0 +1,38 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Cortex-A data analysis module.""" +from dataclasses import dataclass +from functools import singledispatchmethod + +from mlia.core.common import DataItem +from mlia.core.data_analysis import Fact +from mlia.core.data_analysis import FactExtractor +from mlia.devices.cortexa.operators import CortexACompatibilityInfo + + +class CortexADataAnalyzer(FactExtractor): + """Cortex-A data analyzer.""" + + @singledispatchmethod + def analyze_data(self, data_item: DataItem) -> None: + """Analyse the data.""" + + @analyze_data.register + def analyze_operator_compatibility( + self, data_item: CortexACompatibilityInfo + ) -> None: + """Analyse operator compatibility information.""" + if data_item.cortex_a_compatible: + self.add_fact(ModelIsCortexACompatible()) + else: + self.add_fact(ModelIsNotCortexACompatible()) + + +@dataclass +class ModelIsCortexACompatible(Fact): + """Model is completely compatible with Cortex-A.""" + + +@dataclass +class ModelIsNotCortexACompatible(Fact): + """Model is not compatible with Cortex-A.""" diff --git a/src/mlia/devices/cortexa/data_collection.py b/src/mlia/devices/cortexa/data_collection.py new file mode 100644 index 0000000..00c95e6 --- /dev/null +++ b/src/mlia/devices/cortexa/data_collection.py @@ -0,0 +1,36 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Data collection module for Cortex-A.""" +from __future__ import annotations + +import logging +from pathlib import Path + +from mlia.core.data_collection import ContextAwareDataCollector +from mlia.devices.cortexa.operators import CortexACompatibilityInfo +from mlia.devices.cortexa.operators import get_cortex_a_compatibility_info +from mlia.nn.tensorflow.config import get_tflite_model + +logger = logging.getLogger(__name__) + + +class CortexAOperatorCompatibility(ContextAwareDataCollector): + """Collect operator compatibility information.""" + + def __init__(self, model: Path) -> None: + """Init operator compatibility data collector.""" + self.model = model + + def collect_data(self) -> CortexACompatibilityInfo: + """Collect operator compatibility information.""" + tflite_model = get_tflite_model(self.model, self.context) + + logger.info("Checking operator compatibility ...") + ops = get_cortex_a_compatibility_info(Path(tflite_model.model_path)) + logger.info("Done\n") + return ops + + @classmethod + def name(cls) -> str: + """Return name of the collector.""" + return "cortex_a_operator_compatibility" diff --git a/src/mlia/devices/cortexa/events.py b/src/mlia/devices/cortexa/events.py new file mode 100644 index 0000000..dece4c7 --- /dev/null +++ b/src/mlia/devices/cortexa/events.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Cortex-A MLIA module events.""" +from dataclasses import dataclass +from pathlib import Path + +from mlia.core.events import Event +from mlia.core.events import EventDispatcher +from mlia.devices.cortexa.config import CortexAConfiguration + + +@dataclass +class CortexAAdvisorStartedEvent(Event): + """Event with Cortex-A advisor parameters.""" + + model: Path + device: CortexAConfiguration + + +class CortexAAdvisorEventHandler(EventDispatcher): + """Event handler for the Cortex-A inference advisor.""" + + def on_cortex_a_advisor_started(self, event: CortexAAdvisorStartedEvent) -> None: + """Handle CortexAAdvisorStarted event.""" diff --git a/src/mlia/devices/cortexa/handlers.py b/src/mlia/devices/cortexa/handlers.py new file mode 100644 index 0000000..f54ceff --- /dev/null +++ b/src/mlia/devices/cortexa/handlers.py @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Event handler.""" +from __future__ import annotations + +import logging + +from mlia.core.events import CollectedDataEvent +from mlia.core.handlers import WorkflowEventsHandler +from mlia.core.typing import PathOrFileLike +from mlia.devices.cortexa.events import CortexAAdvisorEventHandler +from mlia.devices.cortexa.events import CortexAAdvisorStartedEvent +from mlia.devices.cortexa.operators import CortexACompatibilityInfo +from mlia.devices.cortexa.reporters import cortex_a_formatters + +logger = logging.getLogger(__name__) + + +class CortexAEventHandler(WorkflowEventsHandler, CortexAAdvisorEventHandler): + """CLI event handler.""" + + def __init__(self, output: PathOrFileLike | None = None) -> None: + """Init event handler.""" + super().__init__(cortex_a_formatters, output) + + def on_collected_data(self, event: CollectedDataEvent) -> None: + """Handle CollectedDataEvent event.""" + data_item = event.data_item + + if isinstance(data_item, CortexACompatibilityInfo): + self.reporter.submit(data_item.operators, delay_print=True) + + def on_cortex_a_advisor_started(self, event: CortexAAdvisorStartedEvent) -> None: + """Handle CortexAAdvisorStarted event.""" + self.reporter.submit(event.device) diff --git a/src/mlia/devices/cortexa/operators.py b/src/mlia/devices/cortexa/operators.py new file mode 100644 index 0000000..6a314b7 --- /dev/null +++ b/src/mlia/devices/cortexa/operators.py @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Cortex-A tools module.""" +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path + + +@dataclass +class Operator: + """Cortex-A compatibility information of the operator.""" + + name: str + location: str + is_cortex_a_compatible: bool + + +@dataclass +class CortexACompatibilityInfo: + """Model's operators.""" + + cortex_a_compatible: bool + operators: list[Operator] + + +def get_cortex_a_compatibility_info(model_path: Path) -> CortexACompatibilityInfo: + """Return list of model's operators.""" + raise NotImplementedError() diff --git a/src/mlia/devices/cortexa/reporters.py b/src/mlia/devices/cortexa/reporters.py new file mode 100644 index 0000000..076b9ca --- /dev/null +++ b/src/mlia/devices/cortexa/reporters.py @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Reports module.""" +from __future__ import annotations + +from typing import Any +from typing import Callable + +from mlia.core.advice_generation import Advice +from mlia.core.reporting import Report +from mlia.devices.cortexa.config import CortexAConfiguration +from mlia.devices.cortexa.operators import Operator +from mlia.utils.types import is_list_of + + +def report_device(device: CortexAConfiguration) -> Report: + """Generate report for the device.""" + raise NotImplementedError() + + +def report_advice(advice: list[Advice]) -> Report: + """Generate report for the advice.""" + raise NotImplementedError() + + +def report_cortex_a_operators(operators: list[Operator]) -> Report: + """Generate report for the operators.""" + raise NotImplementedError() + + +def cortex_a_formatters(data: Any) -> Callable[[Any], Report]: + """Find appropriate formatter for the provided data.""" + if is_list_of(data, Advice): + return report_advice + + if isinstance(data, CortexAConfiguration): + return report_device + + if is_list_of(data, Operator): + return report_cortex_a_operators + + raise Exception(f"Unable to find appropriate formatter for {data}") diff --git a/src/mlia/resources/profiles.json b/src/mlia/resources/profiles.json index 500c1ab..990cdb3 100644 --- a/src/mlia/resources/profiles.json +++ b/src/mlia/resources/profiles.json @@ -25,5 +25,8 @@ }, "tosa": { "target": "tosa" + }, + "cortex-a": { + "target": "cortex-a" } } diff --git a/tests/test_devices_cortex_a_advice_generation.py b/tests/test_devices_cortex_a_advice_generation.py new file mode 100644 index 0000000..69529d4 --- /dev/null +++ b/tests/test_devices_cortex_a_advice_generation.py @@ -0,0 +1,56 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for advice generation.""" +from __future__ import annotations + +import pytest + +from mlia.core.advice_generation import Advice +from mlia.core.common import AdviceCategory +from mlia.core.common import DataItem +from mlia.core.context import ExecutionContext +from mlia.devices.cortexa.advice_generation import CortexAAdviceProducer +from mlia.devices.cortexa.data_analysis import ModelIsCortexACompatible +from mlia.devices.cortexa.data_analysis import ModelIsNotCortexACompatible + + +@pytest.mark.parametrize( + "input_data, advice_category, expected_advice", + [ + [ + ModelIsNotCortexACompatible(), + AdviceCategory.OPERATORS, + [ + Advice( + [ + "Some operators in the model are not compatible with Cortex-A. " + "Please, refer to the operators table for more information." + ] + ) + ], + ], + [ + ModelIsCortexACompatible(), + AdviceCategory.OPERATORS, + [Advice(["Model is fully compatible with Cortex-A."])], + ], + ], +) +def test_cortex_a_advice_producer( + tmpdir: str, + input_data: DataItem, + advice_category: AdviceCategory, + expected_advice: list[Advice], +) -> None: + """Test Cortex-A advice producer.""" + producer = CortexAAdviceProducer() + + context = ExecutionContext( + advice_category=advice_category, + working_dir=tmpdir, + ) + + producer.set_context(context) + producer.produce_advice(input_data) + + assert producer.get_advice() == expected_advice diff --git a/tests/test_devices_cortex_a_data_analysis.py b/tests/test_devices_cortex_a_data_analysis.py new file mode 100644 index 0000000..4724c81 --- /dev/null +++ b/tests/test_devices_cortex_a_data_analysis.py @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for Cortex-A data analysis module.""" +from __future__ import annotations + +import pytest + +from mlia.core.common import DataItem +from mlia.core.data_analysis import Fact +from mlia.devices.cortexa.data_analysis import CortexADataAnalyzer +from mlia.devices.cortexa.data_analysis import ModelIsCortexACompatible +from mlia.devices.cortexa.data_analysis import ModelIsNotCortexACompatible +from mlia.devices.cortexa.operators import CortexACompatibilityInfo + + +@pytest.mark.parametrize( + "input_data, expected_facts", + [ + [ + CortexACompatibilityInfo(True, []), + [ModelIsCortexACompatible()], + ], + [ + CortexACompatibilityInfo(False, []), + [ModelIsNotCortexACompatible()], + ], + ], +) +def test_cortex_a_data_analyzer( + input_data: DataItem, expected_facts: list[Fact] +) -> None: + """Test Cortex-A data analyzer.""" + analyzer = CortexADataAnalyzer() + analyzer.analyze_data(input_data) + assert analyzer.get_analyzed_data() == expected_facts diff --git a/tests/test_devices_cortex_a_data_collection.py b/tests/test_devices_cortex_a_data_collection.py new file mode 100644 index 0000000..7ea3e52 --- /dev/null +++ b/tests/test_devices_cortex_a_data_collection.py @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for Cortex-A data collection module.""" +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + +from mlia.core.context import ExecutionContext +from mlia.devices.cortexa.data_collection import CortexAOperatorCompatibility +from mlia.devices.cortexa.operators import CortexACompatibilityInfo + + +def test_cortex_a_data_collection( + monkeypatch: pytest.MonkeyPatch, test_tflite_model: Path, tmpdir: str +) -> None: + """Test Cortex-A data collection.""" + monkeypatch.setattr( + "mlia.devices.cortexa.data_collection.get_cortex_a_compatibility_info", + MagicMock(return_value=CortexACompatibilityInfo(True, [])), + ) + context = ExecutionContext(working_dir=tmpdir) + collector = CortexAOperatorCompatibility(test_tflite_model) + collector.set_context(context) + + data_item = collector.collect_data() + + assert isinstance(data_item, CortexACompatibilityInfo) diff --git a/tests/test_devices_cortexa_advisor.py b/tests/test_devices_cortexa_advisor.py new file mode 100644 index 0000000..8cd60d6 --- /dev/null +++ b/tests/test_devices_cortexa_advisor.py @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for Cortex-A MLIA module.""" +from pathlib import Path + +from mlia.core.context import ExecutionContext +from mlia.core.workflow import DefaultWorkflowExecutor +from mlia.devices.cortexa.advisor import configure_and_get_cortexa_advisor +from mlia.devices.cortexa.advisor import CortexAInferenceAdvisor + + +def test_advisor_metadata() -> None: + """Test advisor metadata.""" + assert CortexAInferenceAdvisor.name() == "cortex_a_inference_advisor" + + +def test_configure_and_get_cortex_a_advisor(test_tflite_model: Path) -> None: + """Test Cortex-A advisor configuration.""" + ctx = ExecutionContext() + + advisor = configure_and_get_cortexa_advisor(ctx, "cortex-a", test_tflite_model) + workflow = advisor.configure(ctx) + + assert isinstance(advisor, CortexAInferenceAdvisor) + + assert ctx.event_handlers is not None + assert ctx.config_parameters == { + "cortex_a_inference_advisor": { + "model": str(test_tflite_model), + "target_profile": "cortex-a", + } + } + + assert isinstance(workflow, DefaultWorkflowExecutor) diff --git a/tests/test_utils_filesystem.py b/tests/test_utils_filesystem.py index b31b4ff..9dd51e1 100644 --- a/tests/test_utils_filesystem.py +++ b/tests/test_utils_filesystem.py @@ -48,6 +48,7 @@ def test_profiles_data() -> None: "ethos-u65-512", "ethos-u65-256", "tosa", + "cortex-a", ] @@ -76,6 +77,7 @@ def test_get_supported_profile_names() -> None: "ethos-u65-512", "ethos-u65-256", "tosa", + "cortex-a", ] |