aboutsummaryrefslogtreecommitdiff
path: root/tests/mlia/test_core_events.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/mlia/test_core_events.py')
-rw-r--r--tests/mlia/test_core_events.py155
1 files changed, 155 insertions, 0 deletions
diff --git a/tests/mlia/test_core_events.py b/tests/mlia/test_core_events.py
new file mode 100644
index 0000000..faaab7c
--- /dev/null
+++ b/tests/mlia/test_core_events.py
@@ -0,0 +1,155 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the module events."""
+from dataclasses import dataclass
+from unittest.mock import call
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.core.events import action
+from mlia.core.events import ActionFinishedEvent
+from mlia.core.events import ActionStartedEvent
+from mlia.core.events import DebugEventHandler
+from mlia.core.events import DefaultEventPublisher
+from mlia.core.events import Event
+from mlia.core.events import EventDispatcher
+from mlia.core.events import EventHandler
+from mlia.core.events import ExecutionFinishedEvent
+from mlia.core.events import ExecutionStartedEvent
+from mlia.core.events import stage
+from mlia.core.events import SystemEventsHandler
+
+
+@dataclass
+class SampleEvent(Event):
+ """Sample event."""
+
+ msg: str
+
+
+def test_event_publisher() -> None:
+ """Test event publishing."""
+ publisher = DefaultEventPublisher()
+ handler_mock1 = MagicMock(spec=EventHandler)
+ handler_mock2 = MagicMock(spec=EventHandler)
+
+ publisher.register_event_handlers([handler_mock1, handler_mock2])
+
+ event = SampleEvent("hello, event!")
+ publisher.publish_event(event)
+
+ handler_mock1.handle_event.assert_called_once_with(event)
+ handler_mock2.handle_event.assert_called_once_with(event)
+
+
+def test_stage_context_manager() -> None:
+ """Test stage context manager."""
+ publisher = DefaultEventPublisher()
+
+ handler_mock = MagicMock(spec=EventHandler)
+ publisher.register_event_handler(handler_mock)
+
+ events = (SampleEvent("hello"), SampleEvent("goodbye"))
+ with stage(publisher, events):
+ print("perform actions")
+
+ assert handler_mock.handle_event.call_count == 2
+ calls = [call(event) for event in events]
+ handler_mock.handle_event.assert_has_calls(calls)
+
+
+def test_action_context_manager() -> None:
+ """Test action stage context manager."""
+ publisher = DefaultEventPublisher()
+
+ handler_mock = MagicMock(spec=EventHandler)
+ publisher.register_event_handler(handler_mock)
+
+ with action(publisher, "Sample action"):
+ print("perform actions")
+
+ assert handler_mock.handle_event.call_count == 2
+ calls = handler_mock.handle_event.mock_calls
+
+ action_started = calls[0].args[0]
+ action_finished = calls[1].args[0]
+
+ assert isinstance(action_started, ActionStartedEvent)
+ assert isinstance(action_finished, ActionFinishedEvent)
+
+ assert action_finished.parent_event_id == action_started.event_id
+
+
+def test_debug_event_handler(capsys: pytest.CaptureFixture) -> None:
+ """Test debugging event handler."""
+ publisher = DefaultEventPublisher()
+
+ publisher.register_event_handler(DebugEventHandler())
+ publisher.register_event_handler(DebugEventHandler(with_stacktrace=True))
+
+ messages = ["Sample event 1", "Sample event 2"]
+ for message in messages:
+ publisher.publish_event(SampleEvent(message))
+
+ captured = capsys.readouterr()
+ for message in messages:
+ assert message in captured.out
+
+ assert "traceback.print_stack" in captured.err
+
+
+def test_event_dispatcher(capsys: pytest.CaptureFixture) -> None:
+ """Test event dispatcher."""
+
+ class SampleEventHandler(EventDispatcher):
+ """Sample event handler."""
+
+ def on_sample_event( # pylint: disable=no-self-use
+ self, _event: SampleEvent
+ ) -> None:
+ """Event handler for SampleEvent."""
+ print("Got sample event")
+
+ publisher = DefaultEventPublisher()
+ publisher.register_event_handler(SampleEventHandler())
+ publisher.publish_event(SampleEvent("Sample event"))
+
+ captured = capsys.readouterr()
+ assert captured.out.strip() == "Got sample event"
+
+
+def test_system_events_handler(capsys: pytest.CaptureFixture) -> None:
+ """Test system events handler."""
+
+ class CustomSystemEventHandler(SystemEventsHandler):
+ """Custom system event handler."""
+
+ def on_execution_started(self, event: ExecutionStartedEvent) -> None:
+ """Handle ExecutionStarted event."""
+ print("Execution started")
+
+ def on_execution_finished(self, event: ExecutionFinishedEvent) -> None:
+ """Handle ExecutionFinished event."""
+ print("Execution finished")
+
+ publisher = DefaultEventPublisher()
+ publisher.register_event_handler(CustomSystemEventHandler())
+
+ publisher.publish_event(ExecutionStartedEvent())
+ publisher.publish_event(SampleEvent("Hello world!"))
+ publisher.publish_event(ExecutionFinishedEvent())
+
+ captured = capsys.readouterr()
+ assert captured.out.strip() == "Execution started\nExecution finished"
+
+
+def test_compare_without_id() -> None:
+ """Test event comparison without event_id."""
+ event1 = SampleEvent("message")
+ event2 = SampleEvent("message")
+
+ assert event1 != event2
+ assert event1.compare_without_id(event2)
+
+ assert not event1.compare_without_id("message") # type: ignore