aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/core/metadata.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/core/metadata.py')
-rw-r--r--src/mlia/core/metadata.py53
1 files changed, 34 insertions, 19 deletions
diff --git a/src/mlia/core/metadata.py b/src/mlia/core/metadata.py
index f0a0e03..4b4ba3e 100644
--- a/src/mlia/core/metadata.py
+++ b/src/mlia/core/metadata.py
@@ -1,37 +1,52 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Classes for metadata."""
+from __future__ import annotations
+
+from abc import ABC
+from abc import abstractmethod
from pathlib import Path
from mlia.utils.misc import get_file_checksum
from mlia.utils.misc import get_pkg_version
-class Metadata: # pylint: disable=too-few-public-methods
- """Base class metadata."""
+class Metadata(ABC): # pylint: disable=too-few-public-methods
+ """Base class for possbile metadata."""
- def __init__(self, data_name: str) -> None:
+ def __init__(self, name: str) -> None:
"""Init Metadata."""
- self.version = self.get_version(data_name)
+ self.name = name
+ self.data_dict = self.get_metadata()
- def get_version(self, data_name: str) -> str:
- """Get version of the python package."""
- return get_pkg_version(data_name)
+ @abstractmethod
+ def get_metadata(self) -> dict:
+ """Fill and return the metadata dictionary."""
-class MLIAMetadata(Metadata): # pylint: disable=too-few-public-methods
- """MLIA metadata."""
+class ModelMetadata(Metadata):
+ """Model metadata."""
+ def __init__(self, model_path: Path, name: str = "Model") -> None:
+ """Metadata for model zoo."""
+ self.model_path = model_path
+ super().__init__(name)
-class ModelMetadata: # pylint: disable=too-few-public-methods
- """Model metadata."""
+ def get_metadata(self) -> dict:
+ """Fill in metadata for model file."""
+ return {
+ "model_name": self.model_path.name,
+ "model_checksum": get_file_checksum(self.model_path),
+ }
+
+
+class MLIAMetadata(Metadata): # pylint: disable=too-few-public-methods
+ """MLIA metadata."""
- def __init__(self, path_name: Path) -> None:
- """Init ModelMetadata."""
- self.model_name = path_name.name
- self.path_name = path_name
- self.checksum = self.get_checksum()
+ def __init__(self, name: str = "MLIA") -> None:
+ """Init MLIAMetadata."""
+ super().__init__(name)
- def get_checksum(self) -> str:
- """Get checksum of the model."""
- return get_file_checksum(self.path_name)
+ def get_metadata(self) -> dict:
+ """Get mlia version."""
+ return {"mlia_version": get_pkg_version("mlia")}