diff options
Diffstat (limited to 'tests/test_core_reporting.py')
-rw-r--r-- | tests/test_core_reporting.py | 105 |
1 files changed, 100 insertions, 5 deletions
diff --git a/tests/test_core_reporting.py b/tests/test_core_reporting.py index 71eaf85..7a68d4b 100644 --- a/tests/test_core_reporting.py +++ b/tests/test_core_reporting.py @@ -3,9 +3,13 @@ """Tests for reporting module.""" from __future__ import annotations -import io import json from enum import Enum +from unittest.mock import ANY +from unittest.mock import call +from unittest.mock import MagicMock +from unittest.mock import Mock +from unittest.mock import patch import numpy as np import pytest @@ -14,13 +18,15 @@ from mlia.core.reporting import BytesCell from mlia.core.reporting import Cell from mlia.core.reporting import ClockCell from mlia.core.reporting import Column +from mlia.core.reporting import CustomJSONEncoder from mlia.core.reporting import CyclesCell from mlia.core.reporting import Format -from mlia.core.reporting import json_reporter +from mlia.core.reporting import JSONReporter from mlia.core.reporting import NestedReport from mlia.core.reporting import ReportItem from mlia.core.reporting import SingleRow from mlia.core.reporting import Table +from mlia.core.reporting import TextReporter from mlia.utils.console import remove_ascii_codes @@ -364,10 +370,9 @@ def test_custom_json_serialization() -> None: alias="sample_table", ) - output = io.StringIO() - json_reporter(table, output) + output = json.dumps(table.to_json(), indent=4, cls=CustomJSONEncoder) - assert json.loads(output.getvalue()) == { + assert json.loads(output) == { "sample_table": [ {"column1": "value1"}, {"column1": 10.0}, @@ -375,3 +380,93 @@ def test_custom_json_serialization() -> None: {"column1": 10}, ] } + + +class TestTextReporter: + """Test TextReporter methods.""" + + def test_text_reporter(self) -> None: + """Test TextReporter.""" + format_resolver = MagicMock() + reporter = TextReporter(format_resolver) + assert reporter.output_format == "plain_text" + + def test_submit(self) -> None: + """Test TextReporter submit.""" + format_resolver = MagicMock() + reporter = TextReporter(format_resolver) + reporter.submit("test") + assert reporter.data == [("test", ANY)] + + reporter.submit("test2", delay_print=True) + assert reporter.data == [("test", ANY), ("test2", ANY)] + assert reporter.delayed == [("test2", ANY)] + + def test_print_delayed(self) -> None: + """Test TextReporter print_delayed.""" + with patch( + "mlia.core.reporting.TextReporter.produce_report" + ) as mock_produce_report: + format_resolver = MagicMock() + reporter = TextReporter(format_resolver) + reporter.submit("test", delay_print=True) + reporter.print_delayed() + assert reporter.data == [("test", ANY)] + assert not reporter.delayed + mock_produce_report.assert_called() + + def test_produce_report(self) -> None: + """Test TextReporter produce_report.""" + format_resolver = MagicMock() + reporter = TextReporter(format_resolver) + + with patch("mlia.core.reporting.logger") as mock_logger: + mock_formatter = MagicMock() + reporter.produce_report("test", mock_formatter) + mock_formatter.assert_has_calls([call("test"), call().to_plain_text()]) + mock_logger.info.assert_called() + + +class TestJSONReporter: + """Test JSONReporter methods.""" + + def test_text_reporter(self) -> None: + """Test JSONReporter.""" + format_resolver = MagicMock() + reporter = JSONReporter(format_resolver) + assert reporter.output_format == "json" + + def test_submit(self) -> None: + """Test JSONReporter submit.""" + format_resolver = MagicMock() + reporter = JSONReporter(format_resolver) + reporter.submit("test") + assert reporter.data == [("test", ANY)] + + reporter.submit("test2") + assert reporter.data == [("test", ANY), ("test2", ANY)] + + def test_generate_report(self) -> None: + """Test JSONReporter generate_report.""" + format_resolver = MagicMock() + reporter = JSONReporter(format_resolver) + reporter.submit("test") + + with patch( + "mlia.core.reporting.JSONReporter.produce_report" + ) as mock_produce_report: + reporter.generate_report() + mock_produce_report.assert_called() + + @patch("builtins.print") + def test_produce_report(self, mock_print: Mock) -> None: + """Test JSONReporter produce_report.""" + format_resolver = MagicMock() + reporter = JSONReporter(format_resolver) + + with patch("json.dumps") as mock_dumps: + mock_formatter = MagicMock() + reporter.produce_report("test", mock_formatter) + mock_formatter.assert_has_calls([call("test"), call().to_json()]) + mock_dumps.assert_called() + mock_print.assert_called() |