aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDmitrii Agibov <dmitrii.agibov@arm.com>2023-01-30 14:42:24 +0000
committerDmitrii Agibov <dmitrii.agibov@arm.com>2023-01-31 09:25:07 +0000
commitd1b2374cda6811a93d1174400fc2eecd7100a8c3 (patch)
treeec54c94317fb2d2bdbabec351d1c52f555bb8cc0
parentb0c1ddad3db24ffcdb06a52b75eb2c87879e7ad9 (diff)
downloadmlia-d1b2374cda6811a93d1174400fc2eecd7100a8c3.tar.gz
MLIA-785 Enable export into json for enums
Change-Id: I8e4d5d04f6b1b252dae872ea76d2bd8c41f4b376
-rw-r--r--src/mlia/core/reporting.py6
-rw-r--r--tests/test_core_reporting.py40
2 files changed, 45 insertions, 1 deletions
diff --git a/src/mlia/core/reporting.py b/src/mlia/core/reporting.py
index 19644b2..ad63d62 100644
--- a/src/mlia/core/reporting.py
+++ b/src/mlia/core/reporting.py
@@ -11,6 +11,7 @@ from collections import defaultdict
from contextlib import contextmanager
from contextlib import ExitStack
from dataclasses import dataclass
+from enum import Enum
from functools import partial
from io import TextIOWrapper
from pathlib import Path
@@ -491,13 +492,16 @@ class CustomJSONEncoder(json.JSONEncoder):
"""Custom JSON encoder."""
def default(self, o: Any) -> Any:
- """Support numpy types."""
+ """Support custom types."""
if isinstance(o, np.integer):
return int(o)
if isinstance(o, np.floating):
return float(o)
+ if isinstance(o, Enum) and isinstance(o.value, str):
+ return o.value
+
return json.JSONEncoder.default(self, o)
diff --git a/tests/test_core_reporting.py b/tests/test_core_reporting.py
index 7b26173..71eaf85 100644
--- a/tests/test_core_reporting.py
+++ b/tests/test_core_reporting.py
@@ -3,6 +3,11 @@
"""Tests for reporting module."""
from __future__ import annotations
+import io
+import json
+from enum import Enum
+
+import numpy as np
import pytest
from mlia.core.reporting import BytesCell
@@ -11,6 +16,7 @@ from mlia.core.reporting import ClockCell
from mlia.core.reporting import Column
from mlia.core.reporting import CyclesCell
from mlia.core.reporting import Format
+from mlia.core.reporting import json_reporter
from mlia.core.reporting import NestedReport
from mlia.core.reporting import ReportItem
from mlia.core.reporting import SingleRow
@@ -335,3 +341,37 @@ Single row example:
alias="simple_row_example",
)
wrong_single_row.to_plain_text()
+
+
+def test_custom_json_serialization() -> None:
+ """Test JSON serialization for custom types."""
+
+ class TestEnum(Enum):
+ """Test enum."""
+
+ VALUE1 = "value1"
+ VALUE2 = "value2"
+
+ table = Table(
+ [Column("Column1", alias="column1")],
+ rows=[
+ [TestEnum.VALUE1],
+ [np.float64(10)],
+ [np.int64(10)],
+ [10],
+ ],
+ name="sample_table",
+ alias="sample_table",
+ )
+
+ output = io.StringIO()
+ json_reporter(table, output)
+
+ assert json.loads(output.getvalue()) == {
+ "sample_table": [
+ {"column1": "value1"},
+ {"column1": 10.0},
+ {"column1": 10},
+ {"column1": 10},
+ ]
+ }