aboutsummaryrefslogtreecommitdiff
path: root/tests/test_core_workflow.py
diff options
context:
space:
mode:
authorBenjamin Klimczak <benjamin.klimczak@arm.com>2022-07-11 12:33:42 +0100
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2022-07-26 14:08:21 +0100
commit5d81f37de09efe10f90512e50252be9c36925fcf (patch)
treeb4d7cdfd051da0a6e882bdfcf280fd7ca7b39e57 /tests/test_core_workflow.py
parent7899b908c1fe6d86b92a80f3827ddd0ac05b674b (diff)
downloadmlia-5d81f37de09efe10f90512e50252be9c36925fcf.tar.gz
MLIA-551 Rework remains of AIET architecture
Re-factoring the code base to further merge the old AIET code into MLIA. - Remove last traces of the backend type 'tool' - Controlled systems removed, including SSH protocol, controller, RunningCommand, locks etc. - Build command / build dir and deploy functionality removed from Applications and Systems - Moving working_dir() - Replace module 'output_parser' with new module 'output_consumer' and merge Base64 parsing into it - Change the output consumption to optionally remove (i.e. actually consume) lines - Use Base64 parsing in GenericInferenceOutputParser, replacing the regex-based parsing and remove the now unused regex parsing - Remove AIET reporting - Pre-install applications by moving them to src/mlia/resources/backends - Rename aiet-config.json to backend-config.json - Move tests from tests/mlia/ to tests/ - Adapt unit tests to code changes - Dependencies removed: paramiko, filelock, psutil - Fix bug in corstone.py: The wrong resource directory was used which broke the functionality to download backends. - Use f-string formatting. - Use logging instead of print. Change-Id: I768bc3bb6b2eda57d219ad01be4a8e0a74167d76
Diffstat (limited to 'tests/test_core_workflow.py')
-rw-r--r--tests/test_core_workflow.py164
1 files changed, 164 insertions, 0 deletions
diff --git a/tests/test_core_workflow.py b/tests/test_core_workflow.py
new file mode 100644
index 0000000..470e572
--- /dev/null
+++ b/tests/test_core_workflow.py
@@ -0,0 +1,164 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module workflow."""
+from dataclasses import dataclass
+from unittest.mock import call
+from unittest.mock import MagicMock
+
+from mlia.core.advice_generation import Advice
+from mlia.core.advice_generation import AdviceEvent
+from mlia.core.advice_generation import ContextAwareAdviceProducer
+from mlia.core.context import ExecutionContext
+from mlia.core.data_analysis import ContextAwareDataAnalyzer
+from mlia.core.data_collection import ContextAwareDataCollector
+from mlia.core.errors import FunctionalityNotSupportedError
+from mlia.core.events import AdviceStageFinishedEvent
+from mlia.core.events import AdviceStageStartedEvent
+from mlia.core.events import AnalyzedDataEvent
+from mlia.core.events import CollectedDataEvent
+from mlia.core.events import DataAnalysisStageFinishedEvent
+from mlia.core.events import DataAnalysisStageStartedEvent
+from mlia.core.events import DataCollectionStageFinishedEvent
+from mlia.core.events import DataCollectionStageStartedEvent
+from mlia.core.events import DataCollectorSkippedEvent
+from mlia.core.events import DefaultEventPublisher
+from mlia.core.events import Event
+from mlia.core.events import EventHandler
+from mlia.core.events import ExecutionFailedEvent
+from mlia.core.events import ExecutionFinishedEvent
+from mlia.core.events import ExecutionStartedEvent
+from mlia.core.workflow import DefaultWorkflowExecutor
+
+
+@dataclass
+class SampleEvent(Event):
+ """Sample event."""
+
+ msg: str
+
+
+def test_workflow_executor(tmpdir: str) -> None:
+ """Test workflow executor."""
+ handler_mock = MagicMock(spec=EventHandler)
+ data_collector_mock = MagicMock(spec=ContextAwareDataCollector)
+ data_collector_mock.collect_data.return_value = 42
+
+ data_collector_mock_no_value = MagicMock(spec=ContextAwareDataCollector)
+ data_collector_mock_no_value.collect_data.return_value = None
+
+ data_collector_mock_skipped = MagicMock(spec=ContextAwareDataCollector)
+ data_collector_mock_skipped.name.return_value = "skipped_collector"
+ data_collector_mock_skipped.collect_data.side_effect = (
+ FunctionalityNotSupportedError("Error!", "Error!")
+ )
+
+ data_analyzer_mock = MagicMock(spec=ContextAwareDataAnalyzer)
+ data_analyzer_mock.get_analyzed_data.return_value = ["Really good number!"]
+
+ advice_producer_mock1 = MagicMock(spec=ContextAwareAdviceProducer)
+ advice_producer_mock1.get_advice.return_value = Advice(["All good!"])
+
+ advice_producer_mock2 = MagicMock(spec=ContextAwareAdviceProducer)
+ advice_producer_mock2.get_advice.return_value = [Advice(["Good advice!"])]
+
+ context = ExecutionContext(
+ working_dir=tmpdir,
+ event_handlers=[handler_mock],
+ event_publisher=DefaultEventPublisher(),
+ )
+
+ executor = DefaultWorkflowExecutor(
+ context,
+ [
+ data_collector_mock,
+ data_collector_mock_no_value,
+ data_collector_mock_skipped,
+ ],
+ [data_analyzer_mock],
+ [
+ advice_producer_mock1,
+ advice_producer_mock2,
+ ],
+ [SampleEvent("Hello from advisor!")],
+ )
+
+ executor.run()
+
+ data_collector_mock.collect_data.assert_called_once()
+ data_collector_mock_no_value.collect_data.assert_called_once()
+ data_collector_mock_skipped.collect_data.assert_called_once()
+
+ data_analyzer_mock.analyze_data.assert_called_once_with(42)
+
+ advice_producer_mock1.produce_advice.assert_called_once_with("Really good number!")
+ advice_producer_mock1.get_advice.assert_called_once()
+
+ advice_producer_mock2.produce_advice.called_once_with("Really good number!")
+ advice_producer_mock2.get_advice.assert_called_once()
+
+ expected_mock_calls = [
+ call(ExecutionStartedEvent()),
+ call(SampleEvent("Hello from advisor!")),
+ call(DataCollectionStageStartedEvent()),
+ call(CollectedDataEvent(data_item=42)),
+ call(DataCollectorSkippedEvent("skipped_collector", "Error!: Error!")),
+ call(DataCollectionStageFinishedEvent()),
+ call(DataAnalysisStageStartedEvent()),
+ call(AnalyzedDataEvent(data_item="Really good number!")),
+ call(DataAnalysisStageFinishedEvent()),
+ call(AdviceStageStartedEvent()),
+ call(AdviceEvent(advice=Advice(messages=["All good!"]))),
+ call(AdviceEvent(advice=Advice(messages=["Good advice!"]))),
+ call(AdviceStageFinishedEvent()),
+ call(ExecutionFinishedEvent()),
+ ]
+
+ for expected_call, actual_call in zip(
+ expected_mock_calls, handler_mock.handle_event.mock_calls
+ ):
+ expected_event = expected_call.args[0]
+ actual_event = actual_call.args[0]
+
+ assert actual_event.compare_without_id(expected_event)
+
+
+def test_workflow_executor_failed(tmpdir: str) -> None:
+ """Test scenario when one of the components raises exception."""
+ handler_mock = MagicMock(spec=EventHandler)
+
+ context = ExecutionContext(
+ working_dir=tmpdir,
+ event_handlers=[handler_mock],
+ event_publisher=DefaultEventPublisher(),
+ )
+
+ collection_exception = Exception("Collection failed")
+
+ data_collector_mock = MagicMock(spec=ContextAwareDataCollector)
+ data_collector_mock.collect_data.side_effect = collection_exception
+
+ executor = DefaultWorkflowExecutor(context, [data_collector_mock], [], [])
+ executor.run()
+
+ expected_mock_calls = [
+ call(ExecutionStartedEvent()),
+ call(DataCollectionStageStartedEvent()),
+ call(ExecutionFailedEvent(collection_exception)),
+ ]
+
+ for expected_call, actual_call in zip(
+ expected_mock_calls, handler_mock.handle_event.mock_calls
+ ):
+ expected_event = expected_call.args[0]
+ actual_event = actual_call.args[0]
+
+ if isinstance(actual_event, ExecutionFailedEvent):
+ # seems that dataclass comparison doesn't work well
+ # for the exceptions
+ actual_exception = actual_event.err
+ expected_exception = expected_event.err
+
+ assert actual_exception == expected_exception
+ continue
+
+ assert actual_event.compare_without_id(expected_event)