aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/core
diff options
context:
space:
mode:
authorDmitrii Agibov <dmitrii.agibov@arm.com>2023-01-30 14:42:24 +0000
committerDmitrii Agibov <dmitrii.agibov@arm.com>2023-01-31 09:25:07 +0000
commitd1b2374cda6811a93d1174400fc2eecd7100a8c3 (patch)
treeec54c94317fb2d2bdbabec351d1c52f555bb8cc0 /src/mlia/core
parentb0c1ddad3db24ffcdb06a52b75eb2c87879e7ad9 (diff)
downloadmlia-d1b2374cda6811a93d1174400fc2eecd7100a8c3.tar.gz
MLIA-785 Enable export into json for enums
Change-Id: I8e4d5d04f6b1b252dae872ea76d2bd8c41f4b376
Diffstat (limited to 'src/mlia/core')
-rw-r--r--src/mlia/core/reporting.py6
1 files changed, 5 insertions, 1 deletions
diff --git a/src/mlia/core/reporting.py b/src/mlia/core/reporting.py
index 19644b2..ad63d62 100644
--- a/src/mlia/core/reporting.py
+++ b/src/mlia/core/reporting.py
@@ -11,6 +11,7 @@ from collections import defaultdict
from contextlib import contextmanager
from contextlib import ExitStack
from dataclasses import dataclass
+from enum import Enum
from functools import partial
from io import TextIOWrapper
from pathlib import Path
@@ -491,13 +492,16 @@ class CustomJSONEncoder(json.JSONEncoder):
"""Custom JSON encoder."""
def default(self, o: Any) -> Any:
- """Support numpy types."""
+ """Support custom types."""
if isinstance(o, np.integer):
return int(o)
if isinstance(o, np.floating):
return float(o)
+ if isinstance(o, Enum) and isinstance(o.value, str):
+ return o.value
+
return json.JSONEncoder.default(self, o)