diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/test_api.py | 4 | ||||
-rw-r--r-- | tests/test_devices_tosa_advice_generation.py | 56 | ||||
-rw-r--r-- | tests/test_devices_tosa_advisor.py | 29 | ||||
-rw-r--r-- | tests/test_devices_tosa_data_analysis.py | 33 | ||||
-rw-r--r-- | tests/test_devices_tosa_data_collection.py | 28 | ||||
-rw-r--r-- | tests/test_devices_tosa_operators.py | 84 | ||||
-rw-r--r-- | tests/test_utils_filesystem.py | 2 |
7 files changed, 236 insertions, 0 deletions
diff --git a/tests/test_api.py b/tests/test_api.py index e8df7af..7b567bf 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -12,6 +12,7 @@ from mlia.core.common import AdviceCategory from mlia.core.context import Context from mlia.core.context import ExecutionContext from mlia.devices.ethosu.advisor import EthosUInferenceAdvisor +from mlia.devices.tosa.advisor import TOSAInferenceAdvisor def test_get_advice_no_target_provided(test_keras_model: Path) -> None: @@ -103,3 +104,6 @@ def test_get_advisor( ExecutionContext(), "ethos-u55-256", str(test_keras_model) ) assert isinstance(ethos_u55_advisor, EthosUInferenceAdvisor) + + tosa_advisor = get_advisor(ExecutionContext(), "tosa", str(test_keras_model)) + assert isinstance(tosa_advisor, TOSAInferenceAdvisor) diff --git a/tests/test_devices_tosa_advice_generation.py b/tests/test_devices_tosa_advice_generation.py new file mode 100644 index 0000000..018ba57 --- /dev/null +++ b/tests/test_devices_tosa_advice_generation.py @@ -0,0 +1,56 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for advice generation.""" +from typing import List + +import pytest + +from mlia.core.advice_generation import Advice +from mlia.core.common import AdviceCategory +from mlia.core.common import DataItem +from mlia.core.context import ExecutionContext +from mlia.devices.tosa.advice_generation import TOSAAdviceProducer +from mlia.devices.tosa.data_analysis import ModelIsNotTOSACompatible +from mlia.devices.tosa.data_analysis import ModelIsTOSACompatible + + +@pytest.mark.parametrize( + "input_data, advice_category, expected_advice", + [ + [ + ModelIsNotTOSACompatible(), + AdviceCategory.OPERATORS, + [ + Advice( + [ + "Some operators in the model are not TOSA compatible. " + "Please, refer to the operators table for more information." + ] + ) + ], + ], + [ + ModelIsTOSACompatible(), + AdviceCategory.OPERATORS, + [Advice(["Model is fully TOSA compatible."])], + ], + ], +) +def test_tosa_advice_producer( + tmpdir: str, + input_data: DataItem, + advice_category: AdviceCategory, + expected_advice: List[Advice], +) -> None: + """Test TOSA advice producer.""" + producer = TOSAAdviceProducer() + + context = ExecutionContext( + advice_category=advice_category, + working_dir=tmpdir, + ) + + producer.set_context(context) + producer.produce_advice(input_data) + + assert producer.get_advice() == expected_advice diff --git a/tests/test_devices_tosa_advisor.py b/tests/test_devices_tosa_advisor.py new file mode 100644 index 0000000..1c7a31a --- /dev/null +++ b/tests/test_devices_tosa_advisor.py @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for TOSA advisor.""" +from pathlib import Path + +from mlia.core.context import ExecutionContext +from mlia.core.workflow import DefaultWorkflowExecutor +from mlia.devices.tosa.advisor import configure_and_get_tosa_advisor +from mlia.devices.tosa.advisor import TOSAInferenceAdvisor + + +def test_configure_and_get_tosa_advisor(test_tflite_model: Path) -> None: + """Test TOSA advisor configuration.""" + ctx = ExecutionContext() + + advisor = configure_and_get_tosa_advisor(ctx, "tosa", test_tflite_model) + workflow = advisor.configure(ctx) + + assert isinstance(advisor, TOSAInferenceAdvisor) + + assert ctx.event_handlers is not None + assert ctx.config_parameters == { + "tosa_inference_advisor": { + "model": str(test_tflite_model), + "target_profile": "tosa", + } + } + + assert isinstance(workflow, DefaultWorkflowExecutor) diff --git a/tests/test_devices_tosa_data_analysis.py b/tests/test_devices_tosa_data_analysis.py new file mode 100644 index 0000000..60bcee8 --- /dev/null +++ b/tests/test_devices_tosa_data_analysis.py @@ -0,0 +1,33 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for TOSA data analysis module.""" +from typing import List + +import pytest + +from mlia.core.common import DataItem +from mlia.core.data_analysis import Fact +from mlia.devices.tosa.data_analysis import ModelIsNotTOSACompatible +from mlia.devices.tosa.data_analysis import ModelIsTOSACompatible +from mlia.devices.tosa.data_analysis import TOSADataAnalyzer +from mlia.devices.tosa.operators import TOSACompatibilityInfo + + +@pytest.mark.parametrize( + "input_data, expected_facts", + [ + [ + TOSACompatibilityInfo(True, []), + [ModelIsTOSACompatible()], + ], + [ + TOSACompatibilityInfo(False, []), + [ModelIsNotTOSACompatible()], + ], + ], +) +def test_tosa_data_analyzer(input_data: DataItem, expected_facts: List[Fact]) -> None: + """Test TOSA data analyzer.""" + analyzer = TOSADataAnalyzer() + analyzer.analyze_data(input_data) + assert analyzer.get_analyzed_data() == expected_facts diff --git a/tests/test_devices_tosa_data_collection.py b/tests/test_devices_tosa_data_collection.py new file mode 100644 index 0000000..b9c0b4c --- /dev/null +++ b/tests/test_devices_tosa_data_collection.py @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for TOSA data collection module.""" +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + +from mlia.core.context import ExecutionContext +from mlia.devices.tosa.data_collection import TOSAOperatorCompatibility +from mlia.devices.tosa.operators import TOSACompatibilityInfo + + +def test_tosa_data_collection( + monkeypatch: pytest.MonkeyPatch, test_tflite_model: Path, tmpdir: str +) -> None: + """Test TOSA data collection.""" + monkeypatch.setattr( + "mlia.devices.tosa.data_collection.get_tosa_compatibility_info", + MagicMock(return_value=TOSACompatibilityInfo(True, [])), + ) + context = ExecutionContext(working_dir=tmpdir) + collector = TOSAOperatorCompatibility(test_tflite_model) + collector.set_context(context) + + data_item = collector.collect_data() + + assert isinstance(data_item, TOSACompatibilityInfo) diff --git a/tests/test_devices_tosa_operators.py b/tests/test_devices_tosa_operators.py new file mode 100644 index 0000000..b7736d2 --- /dev/null +++ b/tests/test_devices_tosa_operators.py @@ -0,0 +1,84 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for TOSA compatibility.""" +from pathlib import Path +from types import SimpleNamespace +from typing import Any +from typing import Optional +from unittest.mock import MagicMock + +import pytest + +from mlia.devices.tosa.operators import get_tosa_compatibility_info +from mlia.devices.tosa.operators import Operator +from mlia.devices.tosa.operators import TOSACompatibilityInfo + + +def replace_get_tosa_checker_with_mock( + monkeypatch: pytest.MonkeyPatch, mock: Optional[MagicMock] +) -> None: + """Replace TOSA checker with mock.""" + monkeypatch.setattr( + "mlia.devices.tosa.operators.get_tosa_checker", MagicMock(return_value=mock) + ) + + +def test_compatibility_check_should_fail_if_checker_not_available( + monkeypatch: pytest.MonkeyPatch, test_tflite_model: Path +) -> None: + """Test that compatibility check should fail if TOSA checker is not available.""" + replace_get_tosa_checker_with_mock(monkeypatch, None) + + with pytest.raises(Exception, match="TOSA checker is not available"): + get_tosa_compatibility_info(test_tflite_model) + + +@pytest.mark.parametrize( + "is_tosa_compatible, operators, expected_result", + [ + [ + True, + [], + TOSACompatibilityInfo(True, []), + ], + [ + True, + [ + SimpleNamespace( + location="op_location", + name="op_name", + is_tosa_compatible=True, + ) + ], + TOSACompatibilityInfo(True, [Operator("op_location", "op_name", True)]), + ], + [ + False, + [ + SimpleNamespace( + location="op_location", + name="op_name", + is_tosa_compatible=False, + ) + ], + TOSACompatibilityInfo(False, [Operator("op_location", "op_name", False)]), + ], + ], +) +def test_get_tosa_compatibility_info( + monkeypatch: pytest.MonkeyPatch, + test_tflite_model: Path, + is_tosa_compatible: bool, + operators: Any, + expected_result: TOSACompatibilityInfo, +) -> None: + """Test getting TOSA compatibility information.""" + mock_checker = MagicMock() + mock_checker.is_tosa_compatible.return_value = is_tosa_compatible + mock_checker._get_tosa_compatibility_for_ops.return_value = ( # pylint: disable=protected-access + operators + ) + + replace_get_tosa_checker_with_mock(monkeypatch, mock_checker) + + assert get_tosa_compatibility_info(test_tflite_model) == expected_result diff --git a/tests/test_utils_filesystem.py b/tests/test_utils_filesystem.py index 7cf32e7..fb894db 100644 --- a/tests/test_utils_filesystem.py +++ b/tests/test_utils_filesystem.py @@ -46,6 +46,7 @@ def test_profiles_data() -> None: "ethos-u55-256", "ethos-u55-128", "ethos-u65-512", + "tosa", ] @@ -72,6 +73,7 @@ def test_get_supported_profile_names() -> None: "ethos-u55-256", "ethos-u55-128", "ethos-u65-512", + "tosa", ] |