aboutsummaryrefslogtreecommitdiff
path: root/tests/mlia/test_core_workflow.py
blob: 470e572ae2135abf7afc5e7ad73302aebae93d08 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for module workflow."""
from dataclasses import dataclass
from unittest.mock import call
from unittest.mock import MagicMock

from mlia.core.advice_generation import Advice
from mlia.core.advice_generation import AdviceEvent
from mlia.core.advice_generation import ContextAwareAdviceProducer
from mlia.core.context import ExecutionContext
from mlia.core.data_analysis import ContextAwareDataAnalyzer
from mlia.core.data_collection import ContextAwareDataCollector
from mlia.core.errors import FunctionalityNotSupportedError
from mlia.core.events import AdviceStageFinishedEvent
from mlia.core.events import AdviceStageStartedEvent
from mlia.core.events import AnalyzedDataEvent
from mlia.core.events import CollectedDataEvent
from mlia.core.events import DataAnalysisStageFinishedEvent
from mlia.core.events import DataAnalysisStageStartedEvent
from mlia.core.events import DataCollectionStageFinishedEvent
from mlia.core.events import DataCollectionStageStartedEvent
from mlia.core.events import DataCollectorSkippedEvent
from mlia.core.events import DefaultEventPublisher
from mlia.core.events import Event
from mlia.core.events import EventHandler
from mlia.core.events import ExecutionFailedEvent
from mlia.core.events import ExecutionFinishedEvent
from mlia.core.events import ExecutionStartedEvent
from mlia.core.workflow import DefaultWorkflowExecutor


@dataclass
class SampleEvent(Event):
    """Sample event.

    Custom event handed to the workflow executor as a "before start" event
    in test_workflow_executor; that test expects the handler to receive it
    right after ExecutionStartedEvent.
    """

    msg: str  # free-form text payload carried by the event


def test_workflow_executor(tmpdir: str) -> None:
    """Test workflow executor.

    Runs DefaultWorkflowExecutor with three data collectors (one returning a
    value, one returning None, one raising FunctionalityNotSupportedError),
    one data analyzer and two advice producers (one returning a single
    Advice, one returning a list of Advice). Checks that every component
    was invoked as expected and that the event handler received exactly the
    expected sequence of events, in order.
    """
    handler_mock = MagicMock(spec=EventHandler)
    data_collector_mock = MagicMock(spec=ContextAwareDataCollector)
    data_collector_mock.collect_data.return_value = 42

    # A collector that produces nothing: no CollectedDataEvent is expected for it.
    data_collector_mock_no_value = MagicMock(spec=ContextAwareDataCollector)
    data_collector_mock_no_value.collect_data.return_value = None

    # A collector for unsupported functionality: expected to be reported as
    # skipped rather than failing the whole workflow.
    data_collector_mock_skipped = MagicMock(spec=ContextAwareDataCollector)
    data_collector_mock_skipped.name.return_value = "skipped_collector"
    data_collector_mock_skipped.collect_data.side_effect = (
        FunctionalityNotSupportedError("Error!", "Error!")
    )

    data_analyzer_mock = MagicMock(spec=ContextAwareDataAnalyzer)
    data_analyzer_mock.get_analyzed_data.return_value = ["Really good number!"]

    advice_producer_mock1 = MagicMock(spec=ContextAwareAdviceProducer)
    advice_producer_mock1.get_advice.return_value = Advice(["All good!"])

    advice_producer_mock2 = MagicMock(spec=ContextAwareAdviceProducer)
    advice_producer_mock2.get_advice.return_value = [Advice(["Good advice!"])]

    context = ExecutionContext(
        working_dir=tmpdir,
        event_handlers=[handler_mock],
        event_publisher=DefaultEventPublisher(),
    )

    executor = DefaultWorkflowExecutor(
        context,
        [
            data_collector_mock,
            data_collector_mock_no_value,
            data_collector_mock_skipped,
        ],
        [data_analyzer_mock],
        [
            advice_producer_mock1,
            advice_producer_mock2,
        ],
        [SampleEvent("Hello from advisor!")],
    )

    executor.run()

    data_collector_mock.collect_data.assert_called_once()
    data_collector_mock_no_value.collect_data.assert_called_once()
    data_collector_mock_skipped.collect_data.assert_called_once()

    data_analyzer_mock.analyze_data.assert_called_once_with(42)

    advice_producer_mock1.produce_advice.assert_called_once_with("Really good number!")
    advice_producer_mock1.get_advice.assert_called_once()

    # Must be a real assertion: a bare ``called_once_with`` on a mock is a
    # silent no-op (or an AttributeError on Python 3.12+), not a check.
    advice_producer_mock2.produce_advice.assert_called_once_with("Really good number!")
    advice_producer_mock2.get_advice.assert_called_once()

    expected_mock_calls = [
        call(ExecutionStartedEvent()),
        call(SampleEvent("Hello from advisor!")),
        call(DataCollectionStageStartedEvent()),
        call(CollectedDataEvent(data_item=42)),
        call(DataCollectorSkippedEvent("skipped_collector", "Error!: Error!")),
        call(DataCollectionStageFinishedEvent()),
        call(DataAnalysisStageStartedEvent()),
        call(AnalyzedDataEvent(data_item="Really good number!")),
        call(DataAnalysisStageFinishedEvent()),
        call(AdviceStageStartedEvent()),
        call(AdviceEvent(advice=Advice(messages=["All good!"]))),
        call(AdviceEvent(advice=Advice(messages=["Good advice!"]))),
        call(AdviceStageFinishedEvent()),
        call(ExecutionFinishedEvent()),
    ]

    actual_mock_calls = handler_mock.handle_event.mock_calls
    # zip() stops at the shorter sequence, so compare lengths first;
    # otherwise missing or extra published events would go unnoticed.
    assert len(actual_mock_calls) == len(expected_mock_calls)

    for expected_call, actual_call in zip(expected_mock_calls, actual_mock_calls):
        expected_event = expected_call.args[0]
        actual_event = actual_call.args[0]

        assert actual_event.compare_without_id(expected_event)


def test_workflow_executor_failed(tmpdir: str) -> None:
    """Test scenario when one of the components raises exception.

    The only data collector raises a generic Exception. The handler is
    expected to receive ExecutionStartedEvent, DataCollectionStageStartedEvent
    and then an ExecutionFailedEvent carrying that exact exception object,
    and nothing else.
    """
    handler_mock = MagicMock(spec=EventHandler)

    context = ExecutionContext(
        working_dir=tmpdir,
        event_handlers=[handler_mock],
        event_publisher=DefaultEventPublisher(),
    )

    collection_exception = Exception("Collection failed")

    data_collector_mock = MagicMock(spec=ContextAwareDataCollector)
    data_collector_mock.collect_data.side_effect = collection_exception

    executor = DefaultWorkflowExecutor(context, [data_collector_mock], [], [])
    executor.run()

    expected_mock_calls = [
        call(ExecutionStartedEvent()),
        call(DataCollectionStageStartedEvent()),
        call(ExecutionFailedEvent(collection_exception)),
    ]

    actual_mock_calls = handler_mock.handle_event.mock_calls
    # zip() stops at the shorter sequence, so compare lengths first;
    # otherwise missing or extra published events would go unnoticed.
    assert len(actual_mock_calls) == len(expected_mock_calls)

    for expected_call, actual_call in zip(expected_mock_calls, actual_mock_calls):
        expected_event = expected_call.args[0]
        actual_event = actual_call.args[0]

        # Branch on the *expected* event so an out-of-order failure event
        # produces a clean assertion failure instead of an AttributeError.
        if isinstance(expected_event, ExecutionFailedEvent):
            # seems that dataclass comparison doesn't work well for the
            # exceptions, so check the carried exception directly. Exceptions
            # do not define value equality, so this is an identity check.
            assert isinstance(actual_event, ExecutionFailedEvent)
            assert actual_event.err is expected_event.err
            continue

        assert actual_event.compare_without_id(expected_event)