From e3bef50932abdc2e32fa5636fb7a496b149e72d9 Mon Sep 17 00:00:00 2001 From: Ruomei Yan Date: Thu, 19 Jan 2023 14:52:36 +0000 Subject: MLIA-775 Refactor metadata related classes - define Metadata base class with dictionary data and abstract method - mlia, tosa, model and metadatadisplay classes are all inherited from base class - update unit tests - update function report_metadata into more generalized format Change-Id: Id49e15283eebdca705045eda81db637d82f85453 --- src/mlia/core/metadata.py | 53 ++++++++++++++++++++++++++++++----------------- 1 file changed, 34 insertions(+), 19 deletions(-) (limited to 'src/mlia/core') diff --git a/src/mlia/core/metadata.py b/src/mlia/core/metadata.py index f0a0e03..4b4ba3e 100644 --- a/src/mlia/core/metadata.py +++ b/src/mlia/core/metadata.py @@ -1,37 +1,52 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Classes for metadata.""" +from __future__ import annotations + +from abc import ABC +from abc import abstractmethod from pathlib import Path from mlia.utils.misc import get_file_checksum from mlia.utils.misc import get_pkg_version -class Metadata: # pylint: disable=too-few-public-methods - """Base class metadata.""" +class Metadata(ABC): # pylint: disable=too-few-public-methods + """Base class for possbile metadata.""" - def __init__(self, data_name: str) -> None: + def __init__(self, name: str) -> None: """Init Metadata.""" - self.version = self.get_version(data_name) + self.name = name + self.data_dict = self.get_metadata() - def get_version(self, data_name: str) -> str: - """Get version of the python package.""" - return get_pkg_version(data_name) + @abstractmethod + def get_metadata(self) -> dict: + """Fill and return the metadata dictionary.""" -class MLIAMetadata(Metadata): # pylint: disable=too-few-public-methods - """MLIA metadata.""" +class ModelMetadata(Metadata): + """Model metadata.""" + def __init__(self, model_path: Path, name: str = "Model") -> None: + """Metadata for model zoo.""" + self.model_path = model_path + super().__init__(name) -class ModelMetadata: # pylint: disable=too-few-public-methods - """Model metadata.""" + def get_metadata(self) -> dict: + """Fill in metadata for model file.""" + return { + "model_name": self.model_path.name, + "model_checksum": get_file_checksum(self.model_path), + } + + +class MLIAMetadata(Metadata): # pylint: disable=too-few-public-methods + """MLIA metadata.""" - def __init__(self, path_name: Path) -> None: - """Init ModelMetadata.""" - self.model_name = path_name.name - self.path_name = path_name - self.checksum = self.get_checksum() + def __init__(self, name: str = "MLIA") -> None: + """Init MLIAMetadata.""" + super().__init__(name) - def get_checksum(self) -> str: - """Get checksum of the model.""" - return get_file_checksum(self.path_name) + def get_metadata(self) -> dict: + """Get mlia version.""" + return {"mlia_version": get_pkg_version("mlia")} -- cgit v1.2.1