diff options
Diffstat (limited to 'src/mlia/core/metadata.py')
-rw-r--r-- | src/mlia/core/metadata.py | 53 |
1 files changed, 34 insertions, 19 deletions
diff --git a/src/mlia/core/metadata.py b/src/mlia/core/metadata.py index f0a0e03..4b4ba3e 100644 --- a/src/mlia/core/metadata.py +++ b/src/mlia/core/metadata.py @@ -1,37 +1,52 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Classes for metadata.""" +from __future__ import annotations + +from abc import ABC +from abc import abstractmethod from pathlib import Path from mlia.utils.misc import get_file_checksum from mlia.utils.misc import get_pkg_version -class Metadata: # pylint: disable=too-few-public-methods - """Base class metadata.""" +class Metadata(ABC): # pylint: disable=too-few-public-methods + """Base class for possbile metadata.""" - def __init__(self, data_name: str) -> None: + def __init__(self, name: str) -> None: """Init Metadata.""" - self.version = self.get_version(data_name) + self.name = name + self.data_dict = self.get_metadata() - def get_version(self, data_name: str) -> str: - """Get version of the python package.""" - return get_pkg_version(data_name) + @abstractmethod + def get_metadata(self) -> dict: + """Fill and return the metadata dictionary.""" -class MLIAMetadata(Metadata): # pylint: disable=too-few-public-methods - """MLIA metadata.""" +class ModelMetadata(Metadata): + """Model metadata.""" + def __init__(self, model_path: Path, name: str = "Model") -> None: + """Metadata for model zoo.""" + self.model_path = model_path + super().__init__(name) -class ModelMetadata: # pylint: disable=too-few-public-methods - """Model metadata.""" + def get_metadata(self) -> dict: + """Fill in metadata for model file.""" + return { + "model_name": self.model_path.name, + "model_checksum": get_file_checksum(self.model_path), + } + + +class MLIAMetadata(Metadata): # pylint: disable=too-few-public-methods + """MLIA metadata.""" - def __init__(self, path_name: Path) -> None: - """Init ModelMetadata.""" - self.model_name = path_name.name - self.path_name = path_name - self.checksum = self.get_checksum() + def __init__(self, name: str = "MLIA") -> None: + """Init MLIAMetadata.""" + super().__init__(name) - def get_checksum(self) -> str: - """Get checksum of the model.""" - return get_file_checksum(self.path_name) + def get_metadata(self) -> dict: + """Get mlia version.""" + return {"mlia_version": get_pkg_version("mlia")} |