aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRuomei Yan <ruomei.yan@arm.com>2023-01-19 14:52:36 +0000
committerRuomei Yan <ruomei.yan@arm.com>2023-02-13 14:24:02 +0000
commite3bef50932abdc2e32fa5636fb7a496b149e72d9 (patch)
treed32e184038ef0cd9e206810abe6dd44d84067163
parentcceb6fbd548d6414928f779e5572d325beb3c604 (diff)
downloadmlia-e3bef50932abdc2e32fa5636fb7a496b149e72d9.tar.gz
MLIA-775 Refactor metadata related classes
- define Metadata base class with dictionary data and abstract method - mlia, tosa, model and metadatadisplay classes are all inherited from base class - update unit tests - update function report_metadata into more generalized format Change-Id: Id49e15283eebdca705045eda81db637d82f85453
-rw-r--r--src/mlia/core/metadata.py53
-rw-r--r--src/mlia/target/tosa/advisor.py9
-rw-r--r--src/mlia/target/tosa/metadata.py9
-rw-r--r--src/mlia/target/tosa/reporters.py68
-rw-r--r--tests/test_target_tosa_reporters.py17
5 files changed, 79 insertions, 77 deletions
diff --git a/src/mlia/core/metadata.py b/src/mlia/core/metadata.py
index f0a0e03..4b4ba3e 100644
--- a/src/mlia/core/metadata.py
+++ b/src/mlia/core/metadata.py
@@ -1,37 +1,52 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Classes for metadata."""
+from __future__ import annotations
+
+from abc import ABC
+from abc import abstractmethod
from pathlib import Path
from mlia.utils.misc import get_file_checksum
from mlia.utils.misc import get_pkg_version
-class Metadata: # pylint: disable=too-few-public-methods
- """Base class metadata."""
+class Metadata(ABC): # pylint: disable=too-few-public-methods
+ """Base class for possbile metadata."""
- def __init__(self, data_name: str) -> None:
+ def __init__(self, name: str) -> None:
"""Init Metadata."""
- self.version = self.get_version(data_name)
+ self.name = name
+ self.data_dict = self.get_metadata()
- def get_version(self, data_name: str) -> str:
- """Get version of the python package."""
- return get_pkg_version(data_name)
+ @abstractmethod
+ def get_metadata(self) -> dict:
+ """Fill and return the metadata dictionary."""
-class MLIAMetadata(Metadata): # pylint: disable=too-few-public-methods
- """MLIA metadata."""
+class ModelMetadata(Metadata):
+ """Model metadata."""
+ def __init__(self, model_path: Path, name: str = "Model") -> None:
+ """Metadata for model zoo."""
+ self.model_path = model_path
+ super().__init__(name)
-class ModelMetadata: # pylint: disable=too-few-public-methods
- """Model metadata."""
+ def get_metadata(self) -> dict:
+ """Fill in metadata for model file."""
+ return {
+ "model_name": self.model_path.name,
+ "model_checksum": get_file_checksum(self.model_path),
+ }
+
+
+class MLIAMetadata(Metadata): # pylint: disable=too-few-public-methods
+ """MLIA metadata."""
- def __init__(self, path_name: Path) -> None:
- """Init ModelMetadata."""
- self.model_name = path_name.name
- self.path_name = path_name
- self.checksum = self.get_checksum()
+ def __init__(self, name: str = "MLIA") -> None:
+ """Init MLIAMetadata."""
+ super().__init__(name)
- def get_checksum(self) -> str:
- """Get checksum of the model."""
- return get_file_checksum(self.path_name)
+ def get_metadata(self) -> dict:
+ """Get mlia version."""
+ return {"mlia_version": get_pkg_version("mlia")}
diff --git a/src/mlia/target/tosa/advisor.py b/src/mlia/target/tosa/advisor.py
index 5fb18ed..7666df4 100644
--- a/src/mlia/target/tosa/advisor.py
+++ b/src/mlia/target/tosa/advisor.py
@@ -16,8 +16,6 @@ from mlia.core.context import ExecutionContext
from mlia.core.data_analysis import DataAnalyzer
from mlia.core.data_collection import DataCollector
from mlia.core.events import Event
-from mlia.core.metadata import MLIAMetadata
-from mlia.core.metadata import ModelMetadata
from mlia.target.registry import profile
from mlia.target.tosa.advice_generation import TOSAAdviceProducer
from mlia.target.tosa.config import TOSAConfiguration
@@ -25,7 +23,6 @@ from mlia.target.tosa.data_analysis import TOSADataAnalyzer
from mlia.target.tosa.data_collection import TOSAOperatorCompatibility
from mlia.target.tosa.events import TOSAAdvisorStartedEvent
from mlia.target.tosa.handlers import TOSAEventHandler
-from mlia.target.tosa.metadata import TOSAMetadata
from mlia.target.tosa.reporters import MetadataDisplay
@@ -69,11 +66,7 @@ class TOSAInferenceAdvisor(DefaultInferenceAdvisor):
TOSAAdvisorStartedEvent(
model,
cast(TOSAConfiguration, profile(target_profile)),
- MetadataDisplay(
- TOSAMetadata("tosa-checker"),
- MLIAMetadata("mlia"),
- ModelMetadata(model),
- ),
+ MetadataDisplay(model),
)
]
diff --git a/src/mlia/target/tosa/metadata.py b/src/mlia/target/tosa/metadata.py
index 5575207..8e1f5ca 100644
--- a/src/mlia/target/tosa/metadata.py
+++ b/src/mlia/target/tosa/metadata.py
@@ -2,7 +2,16 @@
# SPDX-License-Identifier: Apache-2.0
"""TOSA package metadata."""
from mlia.core.metadata import Metadata
+from mlia.utils.misc import get_pkg_version
class TOSAMetadata(Metadata): # pylint: disable=too-few-public-methods
"""TOSA metadata."""
+
+ def __init__(self) -> None:
+ """Init TOSAMetadata."""
+ super().__init__("tosa-checker")
+
+ def get_metadata(self) -> dict:
+ """Returen tosa version."""
+ return {"tosa_version": get_pkg_version(self.name)}
diff --git a/src/mlia/target/tosa/reporters.py b/src/mlia/target/tosa/reporters.py
index 9575978..f54c06b 100644
--- a/src/mlia/target/tosa/reporters.py
+++ b/src/mlia/target/tosa/reporters.py
@@ -3,12 +3,14 @@
"""Reports module."""
from __future__ import annotations
+from pathlib import Path
from typing import Any
from typing import Callable
from mlia.backend.tosa_checker.compat import Operator
from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo
from mlia.core.advice_generation import Advice
+from mlia.core.metadata import Metadata
from mlia.core.metadata import MLIAMetadata
from mlia.core.metadata import ModelMetadata
from mlia.core.reporters import report_advice
@@ -26,20 +28,25 @@ from mlia.utils.console import style_improvement
from mlia.utils.types import is_list_of
-class MetadataDisplay: # pylint: disable=too-few-public-methods
- """TOSA metadata."""
+class MetadataDisplay(Metadata): # pylint: disable=too-few-public-methods
+ """TOSA metadata display items."""
- def __init__(
- self,
- tosa_meta: TOSAMetadata,
- mlia_meta: MLIAMetadata,
- model_meta: ModelMetadata,
- ) -> None:
- """Init TOSAMetadata."""
- self.tosa_version = tosa_meta.version
- self.mlia_version = mlia_meta.version
- self.model_check_sum = model_meta.checksum
- self.model_name = model_meta.model_name
+ def __init__(self, model_path: Path) -> None:
+ """Init MetadataDisplay."""
+ self.model_path = model_path
+ super().__init__("Metadata")
+
+ def get_metadata(self) -> dict:
+ """Combine all necessary elements for display."""
+ all_data = {
+ data_dict.name: data_dict.data_dict
+ for data_dict in (
+ TOSAMetadata(),
+ MLIAMetadata(),
+ ModelMetadata(self.model_path),
+ )
+ }
+ return all_data
def report_target(target_config: TOSAConfiguration) -> Report:
@@ -54,31 +61,16 @@ def report_target(target_config: TOSAConfiguration) -> Report:
def report_metadata(data: MetadataDisplay) -> Report:
- """Generate report for the package version."""
- return NestedReport(
- "Metadata",
- "metadata",
- [
- ReportItem(
- "TOSA checker",
- alias="tosa-checker",
- nested_items=[ReportItem("version", "version", data.tosa_version)],
- ),
- ReportItem(
- "MLIA",
- "MLIA",
- nested_items=[ReportItem("version", "version", data.mlia_version)],
- ),
- ReportItem(
- "Model",
- "Model",
- nested_items=[
- ReportItem("name", "name", data.model_name),
- ReportItem("checksum", "checksum", data.model_check_sum),
- ],
- ),
- ],
- )
+ """Generate report for the metadata."""
+ items: list[ReportItem] = []
+
+ for key, sub_dict in data.data_dict.items():
+ nested_items = [
+ ReportItem(key, alias=key, value=val) for key, val in sub_dict.items()
+ ]
+ items.append(ReportItem(key, alias=key, nested_items=nested_items))
+
+ return NestedReport("Metadata", "metadata", items)
def report_tosa_operators(ops: list[Operator]) -> Report:
diff --git a/tests/test_target_tosa_reporters.py b/tests/test_target_tosa_reporters.py
index 0578b1a..5f26c20 100644
--- a/tests/test_target_tosa_reporters.py
+++ b/tests/test_target_tosa_reporters.py
@@ -6,11 +6,8 @@ from unittest.mock import MagicMock
import pytest
-from mlia.core.metadata import MLIAMetadata
-from mlia.core.metadata import ModelMetadata
from mlia.core.reporting import Report
from mlia.target.tosa.config import TOSAConfiguration
-from mlia.target.tosa.metadata import TOSAMetadata
from mlia.target.tosa.reporters import MetadataDisplay
from mlia.target.tosa.reporters import report_target
from mlia.target.tosa.reporters import tosa_formatters
@@ -28,18 +25,14 @@ def test_tosa_formatters(
"""Test function tosa_formatters() with valid input."""
mock_version = MagicMock()
monkeypatch.setattr(
- "mlia.core.metadata.get_pkg_version",
+ "mlia.target.tosa.metadata.get_pkg_version",
MagicMock(return_value=mock_version),
)
- data = MetadataDisplay(
- TOSAMetadata("tosa-checker"),
- MLIAMetadata("mlia"),
- ModelMetadata(test_tflite_model),
- )
- formatter = tosa_formatters(data)
- report = formatter(data)
- assert data.tosa_version == mock_version
+ display_data = MetadataDisplay(test_tflite_model)
+ formatter = tosa_formatters(MetadataDisplay(test_tflite_model))
+ report = formatter(display_data)
+ assert display_data.data_dict["tosa-checker"]["tosa_version"] == mock_version
assert isinstance(report, Report)