From e3bef50932abdc2e32fa5636fb7a496b149e72d9 Mon Sep 17 00:00:00 2001 From: Ruomei Yan Date: Thu, 19 Jan 2023 14:52:36 +0000 Subject: MLIA-775 Refactor metadata related classes - define Metadata base class with dictionary data and abstract method - mlia, tosa, model and metadatadisplay classes are all inherited from base class - update unit tests - update function report_metadata into more generalized format Change-Id: Id49e15283eebdca705045eda81db637d82f85453 --- src/mlia/core/metadata.py | 53 ++++++++++++++++++----------- src/mlia/target/tosa/advisor.py | 9 +---- src/mlia/target/tosa/metadata.py | 9 +++++ src/mlia/target/tosa/reporters.py | 68 ++++++++++++++++--------------------- tests/test_target_tosa_reporters.py | 17 +++------- 5 files changed, 79 insertions(+), 77 deletions(-) diff --git a/src/mlia/core/metadata.py b/src/mlia/core/metadata.py index f0a0e03..4b4ba3e 100644 --- a/src/mlia/core/metadata.py +++ b/src/mlia/core/metadata.py @@ -1,37 +1,52 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Classes for metadata.""" +from __future__ import annotations + +from abc import ABC +from abc import abstractmethod from pathlib import Path from mlia.utils.misc import get_file_checksum from mlia.utils.misc import get_pkg_version -class Metadata: # pylint: disable=too-few-public-methods - """Base class metadata.""" +class Metadata(ABC): # pylint: disable=too-few-public-methods + """Base class for possbile metadata.""" - def __init__(self, data_name: str) -> None: + def __init__(self, name: str) -> None: """Init Metadata.""" - self.version = self.get_version(data_name) + self.name = name + self.data_dict = self.get_metadata() - def get_version(self, data_name: str) -> str: - """Get version of the python package.""" - return get_pkg_version(data_name) + @abstractmethod + def get_metadata(self) -> dict: + """Fill and return the metadata dictionary.""" -class MLIAMetadata(Metadata): # pylint: disable=too-few-public-methods - """MLIA metadata.""" +class ModelMetadata(Metadata): + """Model metadata.""" + def __init__(self, model_path: Path, name: str = "Model") -> None: + """Metadata for model zoo.""" + self.model_path = model_path + super().__init__(name) -class ModelMetadata: # pylint: disable=too-few-public-methods - """Model metadata.""" + def get_metadata(self) -> dict: + """Fill in metadata for model file.""" + return { + "model_name": self.model_path.name, + "model_checksum": get_file_checksum(self.model_path), + } + + +class MLIAMetadata(Metadata): # pylint: disable=too-few-public-methods + """MLIA metadata.""" - def __init__(self, path_name: Path) -> None: - """Init ModelMetadata.""" - self.model_name = path_name.name - self.path_name = path_name - self.checksum = self.get_checksum() + def __init__(self, name: str = "MLIA") -> None: + """Init MLIAMetadata.""" + super().__init__(name) - def get_checksum(self) -> str: - """Get checksum of the model.""" - return get_file_checksum(self.path_name) + def get_metadata(self) -> dict: + """Get mlia version.""" + return {"mlia_version": get_pkg_version("mlia")} diff --git a/src/mlia/target/tosa/advisor.py b/src/mlia/target/tosa/advisor.py index 5fb18ed..7666df4 100644 --- a/src/mlia/target/tosa/advisor.py +++ b/src/mlia/target/tosa/advisor.py @@ -16,8 +16,6 @@ from mlia.core.context import ExecutionContext from mlia.core.data_analysis import DataAnalyzer from mlia.core.data_collection import DataCollector from mlia.core.events import Event -from mlia.core.metadata import MLIAMetadata -from mlia.core.metadata import ModelMetadata from mlia.target.registry import profile from mlia.target.tosa.advice_generation import TOSAAdviceProducer from mlia.target.tosa.config import TOSAConfiguration @@ -25,7 +23,6 @@ from mlia.target.tosa.data_analysis import TOSADataAnalyzer from mlia.target.tosa.data_collection import TOSAOperatorCompatibility from mlia.target.tosa.events import TOSAAdvisorStartedEvent from mlia.target.tosa.handlers import TOSAEventHandler -from mlia.target.tosa.metadata import TOSAMetadata from mlia.target.tosa.reporters import MetadataDisplay @@ -69,11 +66,7 @@ class TOSAInferenceAdvisor(DefaultInferenceAdvisor): TOSAAdvisorStartedEvent( model, cast(TOSAConfiguration, profile(target_profile)), - MetadataDisplay( - TOSAMetadata("tosa-checker"), - MLIAMetadata("mlia"), - ModelMetadata(model), - ), + MetadataDisplay(model), ) ] diff --git a/src/mlia/target/tosa/metadata.py b/src/mlia/target/tosa/metadata.py index 5575207..8e1f5ca 100644 --- a/src/mlia/target/tosa/metadata.py +++ b/src/mlia/target/tosa/metadata.py @@ -2,7 +2,16 @@ # SPDX-License-Identifier: Apache-2.0 """TOSA package metadata.""" from mlia.core.metadata import Metadata +from mlia.utils.misc import get_pkg_version class TOSAMetadata(Metadata): # pylint: disable=too-few-public-methods """TOSA metadata.""" + + def __init__(self) -> None: + """Init TOSAMetadata.""" + super().__init__("tosa-checker") + + def get_metadata(self) -> dict: + """Returen tosa version.""" + return {"tosa_version": get_pkg_version(self.name)} diff --git a/src/mlia/target/tosa/reporters.py b/src/mlia/target/tosa/reporters.py index 9575978..f54c06b 100644 --- a/src/mlia/target/tosa/reporters.py +++ b/src/mlia/target/tosa/reporters.py @@ -3,12 +3,14 @@ """Reports module.""" from __future__ import annotations +from pathlib import Path from typing import Any from typing import Callable from mlia.backend.tosa_checker.compat import Operator from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo from mlia.core.advice_generation import Advice +from mlia.core.metadata import Metadata from mlia.core.metadata import MLIAMetadata from mlia.core.metadata import ModelMetadata from mlia.core.reporters import report_advice @@ -26,20 +28,25 @@ from mlia.utils.console import style_improvement from mlia.utils.types import is_list_of -class MetadataDisplay: # pylint: disable=too-few-public-methods - """TOSA metadata.""" +class MetadataDisplay(Metadata): # pylint: disable=too-few-public-methods + """TOSA metadata display items.""" - def __init__( - self, - tosa_meta: TOSAMetadata, - mlia_meta: MLIAMetadata, - model_meta: ModelMetadata, - ) -> None: - """Init TOSAMetadata.""" - self.tosa_version = tosa_meta.version - self.mlia_version = mlia_meta.version - self.model_check_sum = model_meta.checksum - self.model_name = model_meta.model_name + def __init__(self, model_path: Path) -> None: + """Init MetadataDisplay.""" + self.model_path = model_path + super().__init__("Metadata") + + def get_metadata(self) -> dict: + """Combine all necessary elements for display.""" + all_data = { + data_dict.name: data_dict.data_dict + for data_dict in ( + TOSAMetadata(), + MLIAMetadata(), + ModelMetadata(self.model_path), + ) + } + return all_data def report_target(target_config: TOSAConfiguration) -> Report: @@ -54,31 +61,16 @@ def report_target(target_config: TOSAConfiguration) -> Report: def report_metadata(data: MetadataDisplay) -> Report: - """Generate report for the package version.""" - return NestedReport( - "Metadata", - "metadata", - [ - ReportItem( - "TOSA checker", - alias="tosa-checker", - nested_items=[ReportItem("version", "version", data.tosa_version)], - ), - ReportItem( - "MLIA", - "MLIA", - nested_items=[ReportItem("version", "version", data.mlia_version)], - ), - ReportItem( - "Model", - "Model", - nested_items=[ - ReportItem("name", "name", data.model_name), - ReportItem("checksum", "checksum", data.model_check_sum), - ], - ), - ], - ) + """Generate report for the metadata.""" + items: list[ReportItem] = [] + + for key, sub_dict in data.data_dict.items(): + nested_items = [ + ReportItem(key, alias=key, value=val) for key, val in sub_dict.items() + ] + items.append(ReportItem(key, alias=key, nested_items=nested_items)) + + return NestedReport("Metadata", "metadata", items) def report_tosa_operators(ops: list[Operator]) -> Report: diff --git a/tests/test_target_tosa_reporters.py b/tests/test_target_tosa_reporters.py index 0578b1a..5f26c20 100644 --- a/tests/test_target_tosa_reporters.py +++ b/tests/test_target_tosa_reporters.py @@ -6,11 +6,8 @@ from unittest.mock import MagicMock import pytest -from mlia.core.metadata import MLIAMetadata -from mlia.core.metadata import ModelMetadata from mlia.core.reporting import Report from mlia.target.tosa.config import TOSAConfiguration -from mlia.target.tosa.metadata import TOSAMetadata from mlia.target.tosa.reporters import MetadataDisplay from mlia.target.tosa.reporters import report_target from mlia.target.tosa.reporters import tosa_formatters @@ -28,18 +25,14 @@ def test_tosa_formatters( """Test function tosa_formatters() with valid input.""" mock_version = MagicMock() monkeypatch.setattr( - "mlia.core.metadata.get_pkg_version", + "mlia.target.tosa.metadata.get_pkg_version", MagicMock(return_value=mock_version), ) - data = MetadataDisplay( - TOSAMetadata("tosa-checker"), - MLIAMetadata("mlia"), - ModelMetadata(test_tflite_model), - ) - formatter = tosa_formatters(data) - report = formatter(data) - assert data.tosa_version == mock_version + display_data = MetadataDisplay(test_tflite_model) + formatter = tosa_formatters(MetadataDisplay(test_tflite_model)) + report = formatter(display_data) + assert display_data.data_dict["tosa-checker"]["tosa_version"] == mock_version assert isinstance(report, Report) -- cgit v1.2.1