aboutsummaryrefslogtreecommitdiff
path: root/tests/mlia/test_core_events.py
blob: faaab7c5f8ed7475675d4aea0229a44dd83a1c38 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the module events."""
from dataclasses import dataclass
from unittest.mock import call
from unittest.mock import MagicMock

import pytest

from mlia.core.events import action
from mlia.core.events import ActionFinishedEvent
from mlia.core.events import ActionStartedEvent
from mlia.core.events import DebugEventHandler
from mlia.core.events import DefaultEventPublisher
from mlia.core.events import Event
from mlia.core.events import EventDispatcher
from mlia.core.events import EventHandler
from mlia.core.events import ExecutionFinishedEvent
from mlia.core.events import ExecutionStartedEvent
from mlia.core.events import stage
from mlia.core.events import SystemEventsHandler


@dataclass
class SampleEvent(Event):
    """Minimal concrete event used as a payload throughout these tests.

    On top of whatever ``Event`` provides, it adds one text field so that
    individual events can be told apart in assertions and printed output.
    """

    msg: str  # free-form text identifying the event in test assertions


def test_event_publisher() -> None:
    """Check that a published event is delivered to every registered handler."""
    handlers = [MagicMock(spec=EventHandler) for _ in range(2)]

    publisher = DefaultEventPublisher()
    publisher.register_event_handlers(handlers)

    published = SampleEvent("hello, event!")
    publisher.publish_event(published)

    # Each handler must have seen exactly this one event, exactly once.
    for handler in handlers:
        handler.handle_event.assert_called_once_with(published)


def test_stage_context_manager() -> None:
    """Test stage context manager.

    ``stage`` receives a (before, after) pair of events. The first one must be
    published when the stage is entered and the second one when it is left,
    so the handler is inspected both inside the ``with`` body and after it.
    """
    publisher = DefaultEventPublisher()

    handler_mock = MagicMock(spec=EventHandler)
    publisher.register_event_handler(handler_mock)

    events = (SampleEvent("hello"), SampleEvent("goodbye"))
    with stage(publisher, events):
        # While the stage body runs, only the "before" event may be published;
        # this catches implementations that emit both events on exit.
        handler_mock.handle_event.assert_called_once_with(events[0])

    # After the stage, exactly both events must have been published, in order.
    assert handler_mock.handle_event.call_args_list == [
        call(event) for event in events
    ]


def test_action_context_manager() -> None:
    """Test action stage context manager.

    ``action`` must publish an ``ActionStartedEvent`` when the block is entered
    and an ``ActionFinishedEvent`` when it is left. The finished event must
    point back to the started one through ``parent_event_id``.
    """
    publisher = DefaultEventPublisher()

    handler_mock = MagicMock(spec=EventHandler)
    publisher.register_event_handler(handler_mock)

    with action(publisher, "Sample action"):
        # Only the start event may have been published while the body runs.
        handler_mock.handle_event.assert_called_once()
        (started_inside,) = handler_mock.handle_event.call_args.args
        assert isinstance(started_inside, ActionStartedEvent)

    assert handler_mock.handle_event.call_count == 2
    action_started, action_finished = (
        recorded.args[0] for recorded in handler_mock.handle_event.call_args_list
    )

    assert isinstance(action_started, ActionStartedEvent)
    assert isinstance(action_finished, ActionFinishedEvent)

    # The finished event is linked to the event that opened the action.
    assert action_finished.parent_event_id == action_started.event_id


def test_debug_event_handler(capsys: pytest.CaptureFixture) -> None:
    """Check DebugEventHandler output with and without stack traces.

    Every published message has to appear on stdout. The handler built with
    ``with_stacktrace=True`` must also leave a ``traceback.print_stack`` frame
    in the stderr output.
    """
    publisher = DefaultEventPublisher()

    debug_handlers = (DebugEventHandler(), DebugEventHandler(with_stacktrace=True))
    for debug_handler in debug_handlers:
        publisher.register_event_handler(debug_handler)

    texts = ["Sample event 1", "Sample event 2"]
    for text in texts:
        publisher.publish_event(SampleEvent(text))

    stdout, stderr = capsys.readouterr()
    missing = [text for text in texts if text not in stdout]
    assert not missing
    assert "traceback.print_stack" in stderr


def test_event_dispatcher(capsys: pytest.CaptureFixture) -> None:
    """Check that an ``on_*`` method of an EventDispatcher receives its event."""

    class PrintingDispatcher(EventDispatcher):
        """Dispatcher that reports each SampleEvent routed to it."""

        # The method name and the annotation on its first parameter are what
        # the dispatcher relies on, so both are kept as-is.
        def on_sample_event(  # pylint: disable=no-self-use
            self, _event: SampleEvent
        ) -> None:
            """Report that a SampleEvent arrived."""
            print("Got sample event")

    dispatcher = PrintingDispatcher()

    publisher = DefaultEventPublisher()
    publisher.register_event_handler(dispatcher)
    publisher.publish_event(SampleEvent("Sample event"))

    assert capsys.readouterr().out.strip() == "Got sample event"


def test_system_events_handler(capsys: pytest.CaptureFixture) -> None:
    """Check that a SystemEventsHandler reacts only to execution start/finish.

    A SampleEvent is published between the two system events. The handler
    defines no hook for it, so it must not contribute any output.
    """

    class ExecutionReporter(SystemEventsHandler):
        """System events handler that prints a line per lifecycle hook."""

        def on_execution_started(self, event: ExecutionStartedEvent) -> None:
            """Report the start of execution."""
            print("Execution started")

        def on_execution_finished(self, event: ExecutionFinishedEvent) -> None:
            """Report the end of execution."""
            print("Execution finished")

    publisher = DefaultEventPublisher()
    publisher.register_event_handler(ExecutionReporter())

    sequence = (
        ExecutionStartedEvent(),
        SampleEvent("Hello world!"),
        ExecutionFinishedEvent(),
    )
    for published in sequence:
        publisher.publish_event(published)

    expected = "\n".join(["Execution started", "Execution finished"])
    assert capsys.readouterr().out.strip() == expected


def test_compare_without_id() -> None:
    """Check that compare_without_id ignores event_id but rejects non-events."""
    first, second = SampleEvent("message"), SampleEvent("message")

    # The payloads are identical, so plain inequality can only come from the
    # per-event identifier that compare_without_id is meant to skip.
    assert first != second
    assert first.compare_without_id(second)

    # Something that is not an event never compares equal.
    assert not first.compare_without_id("message")  # type: ignore