aboutsummaryrefslogtreecommitdiff
path: root/tests/test_core_reporting.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_core_reporting.py')
-rw-r--r--tests/test_core_reporting.py105
1 files changed, 100 insertions, 5 deletions
diff --git a/tests/test_core_reporting.py b/tests/test_core_reporting.py
index 71eaf85..7a68d4b 100644
--- a/tests/test_core_reporting.py
+++ b/tests/test_core_reporting.py
@@ -3,9 +3,13 @@
"""Tests for reporting module."""
from __future__ import annotations
-import io
import json
from enum import Enum
+from unittest.mock import ANY
+from unittest.mock import call
+from unittest.mock import MagicMock
+from unittest.mock import Mock
+from unittest.mock import patch
import numpy as np
import pytest
@@ -14,13 +18,15 @@ from mlia.core.reporting import BytesCell
from mlia.core.reporting import Cell
from mlia.core.reporting import ClockCell
from mlia.core.reporting import Column
+from mlia.core.reporting import CustomJSONEncoder
from mlia.core.reporting import CyclesCell
from mlia.core.reporting import Format
-from mlia.core.reporting import json_reporter
+from mlia.core.reporting import JSONReporter
from mlia.core.reporting import NestedReport
from mlia.core.reporting import ReportItem
from mlia.core.reporting import SingleRow
from mlia.core.reporting import Table
+from mlia.core.reporting import TextReporter
from mlia.utils.console import remove_ascii_codes
@@ -364,10 +370,9 @@ def test_custom_json_serialization() -> None:
alias="sample_table",
)
- output = io.StringIO()
- json_reporter(table, output)
+ output = json.dumps(table.to_json(), indent=4, cls=CustomJSONEncoder)
- assert json.loads(output.getvalue()) == {
+ assert json.loads(output) == {
"sample_table": [
{"column1": "value1"},
{"column1": 10.0},
@@ -375,3 +380,93 @@ def test_custom_json_serialization() -> None:
{"column1": 10},
]
}
+
+
+class TestTextReporter:
+ """Test TextReporter methods."""
+
+ def test_text_reporter(self) -> None:
+ """Test TextReporter."""
+ format_resolver = MagicMock()
+ reporter = TextReporter(format_resolver)
+ assert reporter.output_format == "plain_text"
+
+ def test_submit(self) -> None:
+ """Test TextReporter submit."""
+ format_resolver = MagicMock()
+ reporter = TextReporter(format_resolver)
+ reporter.submit("test")
+ assert reporter.data == [("test", ANY)]
+
+ reporter.submit("test2", delay_print=True)
+ assert reporter.data == [("test", ANY), ("test2", ANY)]
+ assert reporter.delayed == [("test2", ANY)]
+
+ def test_print_delayed(self) -> None:
+ """Test TextReporter print_delayed."""
+ with patch(
+ "mlia.core.reporting.TextReporter.produce_report"
+ ) as mock_produce_report:
+ format_resolver = MagicMock()
+ reporter = TextReporter(format_resolver)
+ reporter.submit("test", delay_print=True)
+ reporter.print_delayed()
+ assert reporter.data == [("test", ANY)]
+ assert not reporter.delayed
+ mock_produce_report.assert_called()
+
+ def test_produce_report(self) -> None:
+ """Test TextReporter produce_report."""
+ format_resolver = MagicMock()
+ reporter = TextReporter(format_resolver)
+
+ with patch("mlia.core.reporting.logger") as mock_logger:
+ mock_formatter = MagicMock()
+ reporter.produce_report("test", mock_formatter)
+ mock_formatter.assert_has_calls([call("test"), call().to_plain_text()])
+ mock_logger.info.assert_called()
+
+
+class TestJSONReporter:
+ """Test JSONReporter methods."""
+
+ def test_text_reporter(self) -> None:
+ """Test JSONReporter."""
+ format_resolver = MagicMock()
+ reporter = JSONReporter(format_resolver)
+ assert reporter.output_format == "json"
+
+ def test_submit(self) -> None:
+ """Test JSONReporter submit."""
+ format_resolver = MagicMock()
+ reporter = JSONReporter(format_resolver)
+ reporter.submit("test")
+ assert reporter.data == [("test", ANY)]
+
+ reporter.submit("test2")
+ assert reporter.data == [("test", ANY), ("test2", ANY)]
+
+ def test_generate_report(self) -> None:
+ """Test JSONReporter generate_report."""
+ format_resolver = MagicMock()
+ reporter = JSONReporter(format_resolver)
+ reporter.submit("test")
+
+ with patch(
+ "mlia.core.reporting.JSONReporter.produce_report"
+ ) as mock_produce_report:
+ reporter.generate_report()
+ mock_produce_report.assert_called()
+
+ @patch("builtins.print")
+ def test_produce_report(self, mock_print: Mock) -> None:
+ """Test JSONReporter produce_report."""
+ format_resolver = MagicMock()
+ reporter = JSONReporter(format_resolver)
+
+ with patch("json.dumps") as mock_dumps:
+ mock_formatter = MagicMock()
+ reporter.produce_report("test", mock_formatter)
+ mock_formatter.assert_has_calls([call("test"), call().to_json()])
+ mock_dumps.assert_called()
+ mock_print.assert_called()