aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorDmitrii Agibov <dmitrii.agibov@arm.com>2022-07-21 14:06:50 +0100
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2022-08-19 10:23:23 +0100
commit664d8c55609253e68d153a91514c8fefa00557b1 (patch)
tree4b2a0ecaf30e9151d6b971a24fa6c6104884896f /tests
parenta8ee1aee3e674c78a77801d1bf2256881ab6b4b9 (diff)
downloadmlia-664d8c55609253e68d153a91514c8fefa00557b1.tar.gz
MLIA-549 Integrate TOSA checker into MLIA
- Add new module for TOSA - Add advisor workflow components - Use TOSA checker for getting operators compatibility information Change-Id: I769e5e2a84e15779658f0895b4a347384def63bf
Diffstat (limited to 'tests')
-rw-r--r--tests/test_api.py4
-rw-r--r--tests/test_devices_tosa_advice_generation.py56
-rw-r--r--tests/test_devices_tosa_advisor.py29
-rw-r--r--tests/test_devices_tosa_data_analysis.py33
-rw-r--r--tests/test_devices_tosa_data_collection.py28
-rw-r--r--tests/test_devices_tosa_operators.py84
-rw-r--r--tests/test_utils_filesystem.py2
7 files changed, 236 insertions, 0 deletions
diff --git a/tests/test_api.py b/tests/test_api.py
index e8df7af..7b567bf 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -12,6 +12,7 @@ from mlia.core.common import AdviceCategory
from mlia.core.context import Context
from mlia.core.context import ExecutionContext
from mlia.devices.ethosu.advisor import EthosUInferenceAdvisor
+from mlia.devices.tosa.advisor import TOSAInferenceAdvisor
def test_get_advice_no_target_provided(test_keras_model: Path) -> None:
@@ -103,3 +104,6 @@ def test_get_advisor(
ExecutionContext(), "ethos-u55-256", str(test_keras_model)
)
assert isinstance(ethos_u55_advisor, EthosUInferenceAdvisor)
+
+ tosa_advisor = get_advisor(ExecutionContext(), "tosa", str(test_keras_model))
+ assert isinstance(tosa_advisor, TOSAInferenceAdvisor)
diff --git a/tests/test_devices_tosa_advice_generation.py b/tests/test_devices_tosa_advice_generation.py
new file mode 100644
index 0000000..018ba57
--- /dev/null
+++ b/tests/test_devices_tosa_advice_generation.py
@@ -0,0 +1,56 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for advice generation."""
+from typing import List
+
+import pytest
+
+from mlia.core.advice_generation import Advice
+from mlia.core.common import AdviceCategory
+from mlia.core.common import DataItem
+from mlia.core.context import ExecutionContext
+from mlia.devices.tosa.advice_generation import TOSAAdviceProducer
+from mlia.devices.tosa.data_analysis import ModelIsNotTOSACompatible
+from mlia.devices.tosa.data_analysis import ModelIsTOSACompatible
+
+
+@pytest.mark.parametrize(
+ "input_data, advice_category, expected_advice",
+ [
+ [
+ ModelIsNotTOSACompatible(),
+ AdviceCategory.OPERATORS,
+ [
+ Advice(
+ [
+ "Some operators in the model are not TOSA compatible. "
+ "Please, refer to the operators table for more information."
+ ]
+ )
+ ],
+ ],
+ [
+ ModelIsTOSACompatible(),
+ AdviceCategory.OPERATORS,
+ [Advice(["Model is fully TOSA compatible."])],
+ ],
+ ],
+)
+def test_tosa_advice_producer(
+ tmpdir: str,
+ input_data: DataItem,
+ advice_category: AdviceCategory,
+ expected_advice: List[Advice],
+) -> None:
+ """Test TOSA advice producer."""
+ producer = TOSAAdviceProducer()
+
+ context = ExecutionContext(
+ advice_category=advice_category,
+ working_dir=tmpdir,
+ )
+
+ producer.set_context(context)
+ producer.produce_advice(input_data)
+
+ assert producer.get_advice() == expected_advice
diff --git a/tests/test_devices_tosa_advisor.py b/tests/test_devices_tosa_advisor.py
new file mode 100644
index 0000000..1c7a31a
--- /dev/null
+++ b/tests/test_devices_tosa_advisor.py
@@ -0,0 +1,29 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for TOSA advisor."""
+from pathlib import Path
+
+from mlia.core.context import ExecutionContext
+from mlia.core.workflow import DefaultWorkflowExecutor
+from mlia.devices.tosa.advisor import configure_and_get_tosa_advisor
+from mlia.devices.tosa.advisor import TOSAInferenceAdvisor
+
+
+def test_configure_and_get_tosa_advisor(test_tflite_model: Path) -> None:
+ """Test TOSA advisor configuration."""
+ ctx = ExecutionContext()
+
+ advisor = configure_and_get_tosa_advisor(ctx, "tosa", test_tflite_model)
+ workflow = advisor.configure(ctx)
+
+ assert isinstance(advisor, TOSAInferenceAdvisor)
+
+ assert ctx.event_handlers is not None
+ assert ctx.config_parameters == {
+ "tosa_inference_advisor": {
+ "model": str(test_tflite_model),
+ "target_profile": "tosa",
+ }
+ }
+
+ assert isinstance(workflow, DefaultWorkflowExecutor)
diff --git a/tests/test_devices_tosa_data_analysis.py b/tests/test_devices_tosa_data_analysis.py
new file mode 100644
index 0000000..60bcee8
--- /dev/null
+++ b/tests/test_devices_tosa_data_analysis.py
@@ -0,0 +1,33 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for TOSA data analysis module."""
+from typing import List
+
+import pytest
+
+from mlia.core.common import DataItem
+from mlia.core.data_analysis import Fact
+from mlia.devices.tosa.data_analysis import ModelIsNotTOSACompatible
+from mlia.devices.tosa.data_analysis import ModelIsTOSACompatible
+from mlia.devices.tosa.data_analysis import TOSADataAnalyzer
+from mlia.devices.tosa.operators import TOSACompatibilityInfo
+
+
+@pytest.mark.parametrize(
+ "input_data, expected_facts",
+ [
+ [
+ TOSACompatibilityInfo(True, []),
+ [ModelIsTOSACompatible()],
+ ],
+ [
+ TOSACompatibilityInfo(False, []),
+ [ModelIsNotTOSACompatible()],
+ ],
+ ],
+)
+def test_tosa_data_analyzer(input_data: DataItem, expected_facts: List[Fact]) -> None:
+ """Test TOSA data analyzer."""
+ analyzer = TOSADataAnalyzer()
+ analyzer.analyze_data(input_data)
+ assert analyzer.get_analyzed_data() == expected_facts
diff --git a/tests/test_devices_tosa_data_collection.py b/tests/test_devices_tosa_data_collection.py
new file mode 100644
index 0000000..b9c0b4c
--- /dev/null
+++ b/tests/test_devices_tosa_data_collection.py
@@ -0,0 +1,28 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for TOSA data collection module."""
+from pathlib import Path
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.core.context import ExecutionContext
+from mlia.devices.tosa.data_collection import TOSAOperatorCompatibility
+from mlia.devices.tosa.operators import TOSACompatibilityInfo
+
+
+def test_tosa_data_collection(
+ monkeypatch: pytest.MonkeyPatch, test_tflite_model: Path, tmpdir: str
+) -> None:
+ """Test TOSA data collection."""
+ monkeypatch.setattr(
+ "mlia.devices.tosa.data_collection.get_tosa_compatibility_info",
+ MagicMock(return_value=TOSACompatibilityInfo(True, [])),
+ )
+ context = ExecutionContext(working_dir=tmpdir)
+ collector = TOSAOperatorCompatibility(test_tflite_model)
+ collector.set_context(context)
+
+ data_item = collector.collect_data()
+
+ assert isinstance(data_item, TOSACompatibilityInfo)
diff --git a/tests/test_devices_tosa_operators.py b/tests/test_devices_tosa_operators.py
new file mode 100644
index 0000000..b7736d2
--- /dev/null
+++ b/tests/test_devices_tosa_operators.py
@@ -0,0 +1,84 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for TOSA compatibility."""
+from pathlib import Path
+from types import SimpleNamespace
+from typing import Any
+from typing import Optional
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.devices.tosa.operators import get_tosa_compatibility_info
+from mlia.devices.tosa.operators import Operator
+from mlia.devices.tosa.operators import TOSACompatibilityInfo
+
+
+def replace_get_tosa_checker_with_mock(
+ monkeypatch: pytest.MonkeyPatch, mock: Optional[MagicMock]
+) -> None:
+ """Replace TOSA checker with mock."""
+ monkeypatch.setattr(
+ "mlia.devices.tosa.operators.get_tosa_checker", MagicMock(return_value=mock)
+ )
+
+
+def test_compatibility_check_should_fail_if_checker_not_available(
+ monkeypatch: pytest.MonkeyPatch, test_tflite_model: Path
+) -> None:
+ """Test that compatibility check should fail if TOSA checker is not available."""
+ replace_get_tosa_checker_with_mock(monkeypatch, None)
+
+ with pytest.raises(Exception, match="TOSA checker is not available"):
+ get_tosa_compatibility_info(test_tflite_model)
+
+
+@pytest.mark.parametrize(
+ "is_tosa_compatible, operators, expected_result",
+ [
+ [
+ True,
+ [],
+ TOSACompatibilityInfo(True, []),
+ ],
+ [
+ True,
+ [
+ SimpleNamespace(
+ location="op_location",
+ name="op_name",
+ is_tosa_compatible=True,
+ )
+ ],
+ TOSACompatibilityInfo(True, [Operator("op_location", "op_name", True)]),
+ ],
+ [
+ False,
+ [
+ SimpleNamespace(
+ location="op_location",
+ name="op_name",
+ is_tosa_compatible=False,
+ )
+ ],
+ TOSACompatibilityInfo(False, [Operator("op_location", "op_name", False)]),
+ ],
+ ],
+)
+def test_get_tosa_compatibility_info(
+ monkeypatch: pytest.MonkeyPatch,
+ test_tflite_model: Path,
+ is_tosa_compatible: bool,
+ operators: Any,
+ expected_result: TOSACompatibilityInfo,
+) -> None:
+ """Test getting TOSA compatibility information."""
+ mock_checker = MagicMock()
+ mock_checker.is_tosa_compatible.return_value = is_tosa_compatible
+ mock_checker._get_tosa_compatibility_for_ops.return_value = ( # pylint: disable=protected-access
+ operators
+ )
+
+ replace_get_tosa_checker_with_mock(monkeypatch, mock_checker)
+
+ assert get_tosa_compatibility_info(test_tflite_model) == expected_result
diff --git a/tests/test_utils_filesystem.py b/tests/test_utils_filesystem.py
index 7cf32e7..fb894db 100644
--- a/tests/test_utils_filesystem.py
+++ b/tests/test_utils_filesystem.py
@@ -46,6 +46,7 @@ def test_profiles_data() -> None:
"ethos-u55-256",
"ethos-u55-128",
"ethos-u65-512",
+ "tosa",
]
@@ -72,6 +73,7 @@ def test_get_supported_profile_names() -> None:
"ethos-u55-256",
"ethos-u55-128",
"ethos-u65-512",
+ "tosa",
]