From e3bef50932abdc2e32fa5636fb7a496b149e72d9 Mon Sep 17 00:00:00 2001 From: Ruomei Yan Date: Thu, 19 Jan 2023 14:52:36 +0000 Subject: MLIA-775 Refactor metadata related classes - define Metadata base class with dictionary data and abstract method - mlia, tosa, model and metadatadisplay classes are all inherited from base class - update unit tests - update function report_metadata into more generalized format Change-Id: Id49e15283eebdca705045eda81db637d82f85453 --- src/mlia/target/tosa/advisor.py | 9 +----- src/mlia/target/tosa/metadata.py | 9 ++++++ src/mlia/target/tosa/reporters.py | 68 +++++++++++++++++---------------------- 3 files changed, 40 insertions(+), 46 deletions(-) (limited to 'src/mlia/target/tosa') diff --git a/src/mlia/target/tosa/advisor.py b/src/mlia/target/tosa/advisor.py index 5fb18ed..7666df4 100644 --- a/src/mlia/target/tosa/advisor.py +++ b/src/mlia/target/tosa/advisor.py @@ -16,8 +16,6 @@ from mlia.core.context import ExecutionContext from mlia.core.data_analysis import DataAnalyzer from mlia.core.data_collection import DataCollector from mlia.core.events import Event -from mlia.core.metadata import MLIAMetadata -from mlia.core.metadata import ModelMetadata from mlia.target.registry import profile from mlia.target.tosa.advice_generation import TOSAAdviceProducer from mlia.target.tosa.config import TOSAConfiguration @@ -25,7 +23,6 @@ from mlia.target.tosa.data_analysis import TOSADataAnalyzer from mlia.target.tosa.data_collection import TOSAOperatorCompatibility from mlia.target.tosa.events import TOSAAdvisorStartedEvent from mlia.target.tosa.handlers import TOSAEventHandler -from mlia.target.tosa.metadata import TOSAMetadata from mlia.target.tosa.reporters import MetadataDisplay @@ -69,11 +66,7 @@ class TOSAInferenceAdvisor(DefaultInferenceAdvisor): TOSAAdvisorStartedEvent( model, cast(TOSAConfiguration, profile(target_profile)), - MetadataDisplay( - TOSAMetadata("tosa-checker"), - MLIAMetadata("mlia"), - ModelMetadata(model), - ), + MetadataDisplay(model), ) ] diff --git a/src/mlia/target/tosa/metadata.py b/src/mlia/target/tosa/metadata.py index 5575207..8e1f5ca 100644 --- a/src/mlia/target/tosa/metadata.py +++ b/src/mlia/target/tosa/metadata.py @@ -2,7 +2,16 @@ # SPDX-License-Identifier: Apache-2.0 """TOSA package metadata.""" from mlia.core.metadata import Metadata +from mlia.utils.misc import get_pkg_version class TOSAMetadata(Metadata): # pylint: disable=too-few-public-methods """TOSA metadata.""" + + def __init__(self) -> None: + """Init TOSAMetadata.""" + super().__init__("tosa-checker") + + def get_metadata(self) -> dict: + """Returen tosa version.""" + return {"tosa_version": get_pkg_version(self.name)} diff --git a/src/mlia/target/tosa/reporters.py b/src/mlia/target/tosa/reporters.py index 9575978..f54c06b 100644 --- a/src/mlia/target/tosa/reporters.py +++ b/src/mlia/target/tosa/reporters.py @@ -3,12 +3,14 @@ """Reports module.""" from __future__ import annotations +from pathlib import Path from typing import Any from typing import Callable from mlia.backend.tosa_checker.compat import Operator from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo from mlia.core.advice_generation import Advice +from mlia.core.metadata import Metadata from mlia.core.metadata import MLIAMetadata from mlia.core.metadata import ModelMetadata from mlia.core.reporters import report_advice @@ -26,20 +28,25 @@ from mlia.utils.console import style_improvement from mlia.utils.types import is_list_of -class MetadataDisplay: # pylint: disable=too-few-public-methods - """TOSA metadata.""" +class MetadataDisplay(Metadata): # pylint: disable=too-few-public-methods + """TOSA metadata display items.""" - def __init__( - self, - tosa_meta: TOSAMetadata, - mlia_meta: MLIAMetadata, - model_meta: ModelMetadata, - ) -> None: - """Init TOSAMetadata.""" - self.tosa_version = tosa_meta.version - self.mlia_version = mlia_meta.version - self.model_check_sum = model_meta.checksum - self.model_name = model_meta.model_name + def __init__(self, model_path: Path) -> None: + """Init MetadataDisplay.""" + self.model_path = model_path + super().__init__("Metadata") + + def get_metadata(self) -> dict: + """Combine all necessary elements for display.""" + all_data = { + data_dict.name: data_dict.data_dict + for data_dict in ( + TOSAMetadata(), + MLIAMetadata(), + ModelMetadata(self.model_path), + ) + } + return all_data def report_target(target_config: TOSAConfiguration) -> Report: @@ -54,31 +61,16 @@ def report_target(target_config: TOSAConfiguration) -> Report: def report_metadata(data: MetadataDisplay) -> Report: - """Generate report for the package version.""" - return NestedReport( - "Metadata", - "metadata", - [ - ReportItem( - "TOSA checker", - alias="tosa-checker", - nested_items=[ReportItem("version", "version", data.tosa_version)], - ), - ReportItem( - "MLIA", - "MLIA", - nested_items=[ReportItem("version", "version", data.mlia_version)], - ), - ReportItem( - "Model", - "Model", - nested_items=[ - ReportItem("name", "name", data.model_name), - ReportItem("checksum", "checksum", data.model_check_sum), - ], - ), - ], - ) + """Generate report for the metadata.""" + items: list[ReportItem] = [] + + for key, sub_dict in data.data_dict.items(): + nested_items = [ + ReportItem(key, alias=key, value=val) for key, val in sub_dict.items() + ] + items.append(ReportItem(key, alias=key, nested_items=nested_items)) + + return NestedReport("Metadata", "metadata", items) def report_tosa_operators(ops: list[Operator]) -> Report: -- cgit v1.2.1