From 4eb3fef8e5876c69dc6bac70fdc010805d5b97f2 Mon Sep 17 00:00:00 2001 From: Ruomei Yan Date: Tue, 13 Dec 2022 22:02:21 +0000 Subject: MLIA-741/2 Report test results - add version extraction function in compat.py - create Metadata, MLIAMetadata, TOSAMetadata and MetadataDisplay classes - update the reporting functions so tosa and mlia version will be displayed in output json - update unit test test_configure_and_get_tosa_advisor to mock the get_events function - update the copyright information of all changed/added files - handle exception and report to json when program crashes - write new context managers for capturing stderr and stdout - support reporting stderr to json output - support reporting model checksum and model name to json output - made changes in test_e2e.py handling {model_name} replacement in --output - add unit tests Change-Id: I6629fd1c5754378e6accd488217c83d87c7eb6f1 --- src/mlia/backend/tosa_checker/compat.py | 56 +++++++++++++++++--- src/mlia/core/metadata.py | 37 +++++++++++++ src/mlia/target/tosa/advisor.py | 14 ++++- src/mlia/target/tosa/events.py | 6 ++- src/mlia/target/tosa/handlers.py | 3 +- src/mlia/target/tosa/metadata.py | 8 +++ src/mlia/target/tosa/reporters.py | 93 ++++++++++++++++++++++++++++++++- src/mlia/utils/logging.py | 37 +++++++++++-- src/mlia/utils/misc.py | 24 ++++++++- 9 files changed, 261 insertions(+), 17 deletions(-) create mode 100644 src/mlia/core/metadata.py create mode 100644 src/mlia/target/tosa/metadata.py (limited to 'src') diff --git a/src/mlia/backend/tosa_checker/compat.py b/src/mlia/backend/tosa_checker/compat.py index bd21774..81f3015 100644 --- a/src/mlia/backend/tosa_checker/compat.py +++ b/src/mlia/backend/tosa_checker/compat.py @@ -1,8 +1,9 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """TOSA compatibility module.""" from __future__ import annotations +import sys from dataclasses import dataclass from typing import Any from typing import cast @@ -10,6 +11,7 @@ from typing import Protocol from mlia.backend.errors import BackendUnavailableError from mlia.core.typing import PathOrFileLike +from mlia.utils.logging import capture_raw_output class TOSAChecker(Protocol): @@ -37,25 +39,65 @@ class TOSACompatibilityInfo: tosa_compatible: bool operators: list[Operator] + exception: Exception | None = None + errors: list[str] | None = None + std_out: list[str] | None = None def get_tosa_compatibility_info( tflite_model_path: PathOrFileLike, ) -> TOSACompatibilityInfo: """Return list of the operators.""" - checker = get_tosa_checker(tflite_model_path) + # Capture the possible exception in running get_tosa_checker + try: + with capture_raw_output(sys.stdout) as std_output_pkg, capture_raw_output( + sys.stderr + ) as stderr_output_pkg: + checker = get_tosa_checker(tflite_model_path) + except Exception as exc: # pylint: disable=broad-except + return TOSACompatibilityInfo( + tosa_compatible=False, + operators=[], + exception=exc, + errors=None, + std_out=None, + ) + # Capture the possible BackendUnavailableError when tosa-checker is not available if checker is None: raise BackendUnavailableError( "Backend tosa-checker is not available", "tosa-checker" ) - ops = [ - Operator(item.location, item.name, item.is_tosa_compatible) - for item in checker._get_tosa_compatibility_for_ops() # pylint: disable=protected-access - ] + # Capture the possible exception when checking ops compatibility + try: + with capture_raw_output(sys.stdout) as std_output_ops, capture_raw_output( + sys.stderr + ) as stderr_output_ops: + ops = [ + Operator(item.location, item.name, item.is_tosa_compatible) + for item in checker._get_tosa_compatibility_for_ops() # pylint: disable=protected-access + ] + except Exception as exc: # pylint: disable=broad-except + return TOSACompatibilityInfo( + tosa_compatible=False, + operators=[], + exception=exc, + errors=None, + std_out=None, + ) - return TOSACompatibilityInfo(checker.is_tosa_compatible(), ops) + # Concatenate all possbile stderr/stdout + stderr_output = stderr_output_pkg + stderr_output_ops + std_output = std_output_pkg + std_output_ops + + return TOSACompatibilityInfo( + tosa_compatible=checker.is_tosa_compatible(), + operators=ops, + exception=None, + errors=stderr_output, + std_out=std_output, + ) def get_tosa_checker(tflite_model_path: PathOrFileLike) -> TOSAChecker | None: diff --git a/src/mlia/core/metadata.py b/src/mlia/core/metadata.py new file mode 100644 index 0000000..f0a0e03 --- /dev/null +++ b/src/mlia/core/metadata.py @@ -0,0 +1,37 @@ +# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Classes for metadata.""" +from pathlib import Path + +from mlia.utils.misc import get_file_checksum +from mlia.utils.misc import get_pkg_version + + +class Metadata: # pylint: disable=too-few-public-methods + """Base class metadata.""" + + def __init__(self, data_name: str) -> None: + """Init Metadata.""" + self.version = self.get_version(data_name) + + def get_version(self, data_name: str) -> str: + """Get version of the python package.""" + return get_pkg_version(data_name) + + +class MLIAMetadata(Metadata): # pylint: disable=too-few-public-methods + """MLIA metadata.""" + + +class ModelMetadata: # pylint: disable=too-few-public-methods + """Model metadata.""" + + def __init__(self, path_name: Path) -> None: + """Init ModelMetadata.""" + self.model_name = path_name.name + self.path_name = path_name + self.checksum = self.get_checksum() + + def get_checksum(self) -> str: + """Get checksum of the model.""" + return get_file_checksum(self.path_name) diff --git a/src/mlia/target/tosa/advisor.py b/src/mlia/target/tosa/advisor.py index 4851113..0da44db 100644 --- a/src/mlia/target/tosa/advisor.py +++ b/src/mlia/target/tosa/advisor.py @@ -16,12 +16,16 @@ from mlia.core.context import ExecutionContext from mlia.core.data_analysis import DataAnalyzer from mlia.core.data_collection import DataCollector from mlia.core.events import Event +from mlia.core.metadata import MLIAMetadata +from mlia.core.metadata import ModelMetadata from mlia.target.tosa.advice_generation import TOSAAdviceProducer from mlia.target.tosa.config import TOSAConfiguration from mlia.target.tosa.data_analysis import TOSADataAnalyzer from mlia.target.tosa.data_collection import TOSAOperatorCompatibility from mlia.target.tosa.events import TOSAAdvisorStartedEvent from mlia.target.tosa.handlers import TOSAEventHandler +from mlia.target.tosa.metadata import TOSAMetadata +from mlia.target.tosa.reporters import MetadataDisplay class TOSAInferenceAdvisor(DefaultInferenceAdvisor): @@ -61,7 +65,15 @@ class TOSAInferenceAdvisor(DefaultInferenceAdvisor): target_profile = self.get_target_profile(context) return [ - TOSAAdvisorStartedEvent(model, TOSAConfiguration(target_profile)), + TOSAAdvisorStartedEvent( + model, + TOSAConfiguration(target_profile), + MetadataDisplay( + TOSAMetadata("tosa-checker"), + MLIAMetadata("mlia"), + ModelMetadata(model), + ), + ) ] diff --git a/src/mlia/target/tosa/events.py b/src/mlia/target/tosa/events.py index 67d499d..cbfd199 100644 --- a/src/mlia/target/tosa/events.py +++ b/src/mlia/target/tosa/events.py @@ -1,12 +1,15 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """TOSA advisor events.""" +from __future__ import annotations + from dataclasses import dataclass from pathlib import Path from mlia.core.events import Event from mlia.core.events import EventDispatcher from mlia.target.tosa.config import TOSAConfiguration +from mlia.target.tosa.reporters import MetadataDisplay @dataclass @@ -15,6 +18,7 @@ class TOSAAdvisorStartedEvent(Event): model: Path device: TOSAConfiguration + tosa_metadata: MetadataDisplay | None class TOSAAdvisorEventHandler(EventDispatcher): diff --git a/src/mlia/target/tosa/handlers.py b/src/mlia/target/tosa/handlers.py index 1037ba1..7f80f77 100644 --- a/src/mlia/target/tosa/handlers.py +++ b/src/mlia/target/tosa/handlers.py @@ -27,10 +27,11 @@ class TOSAEventHandler(WorkflowEventsHandler, TOSAAdvisorEventHandler): def on_tosa_advisor_started(self, event: TOSAAdvisorStartedEvent) -> None: """Handle TOSAAdvisorStartedEvent event.""" self.reporter.submit(event.device) + self.reporter.submit(event.tosa_metadata) def on_collected_data(self, event: CollectedDataEvent) -> None: """Handle CollectedDataEvent event.""" data_item = event.data_item if isinstance(data_item, TOSACompatibilityInfo): - self.reporter.submit(data_item.operators, delay_print=True) + self.reporter.submit(data_item, delay_print=True) diff --git a/src/mlia/target/tosa/metadata.py b/src/mlia/target/tosa/metadata.py new file mode 100644 index 0000000..5575207 --- /dev/null +++ b/src/mlia/target/tosa/metadata.py @@ -0,0 +1,8 @@ +# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""TOSA package metadata.""" +from mlia.core.metadata import Metadata + + +class TOSAMetadata(Metadata): # pylint: disable=too-few-public-methods + """TOSA metadata.""" diff --git a/src/mlia/target/tosa/reporters.py b/src/mlia/target/tosa/reporters.py index 01fbb97..283f61f 100644 --- a/src/mlia/target/tosa/reporters.py +++ b/src/mlia/target/tosa/reporters.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Reports module.""" from __future__ import annotations @@ -7,20 +7,41 @@ from typing import Any from typing import Callable from mlia.backend.tosa_checker.compat import Operator +from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo from mlia.core.advice_generation import Advice +from mlia.core.metadata import MLIAMetadata +from mlia.core.metadata import ModelMetadata from mlia.core.reporters import report_advice from mlia.core.reporting import Cell from mlia.core.reporting import Column +from mlia.core.reporting import CompoundReport from mlia.core.reporting import Format from mlia.core.reporting import NestedReport from mlia.core.reporting import Report from mlia.core.reporting import ReportItem from mlia.core.reporting import Table from mlia.target.tosa.config import TOSAConfiguration +from mlia.target.tosa.metadata import TOSAMetadata from mlia.utils.console import style_improvement from mlia.utils.types import is_list_of +class MetadataDisplay: # pylint: disable=too-few-public-methods + """TOSA metadata.""" + + def __init__( + self, + tosa_meta: TOSAMetadata, + mlia_meta: MLIAMetadata, + model_meta: ModelMetadata, + ) -> None: + """Init TOSAMetadata.""" + self.tosa_version = tosa_meta.version + self.mlia_version = mlia_meta.version + self.model_check_sum = model_meta.checksum + self.model_name = model_meta.model_name + + def report_device(device: TOSAConfiguration) -> Report: """Generate report for the device.""" return NestedReport( @@ -32,6 +53,34 @@ def report_device(device: TOSAConfiguration) -> Report: ) +def report_metadata(data: MetadataDisplay) -> Report: + """Generate report for the package version.""" + return NestedReport( + "Metadata", + "metadata", + [ + ReportItem( + "TOSA checker", + alias="tosa-checker", + nested_items=[ReportItem("version", "version", data.tosa_version)], + ), + ReportItem( + "MLIA", + "MLIA", + nested_items=[ReportItem("version", "version", data.mlia_version)], + ), + ReportItem( + "Model", + "Model", + nested_items=[ + ReportItem("name", "name", data.model_name), + ReportItem("checksum", "checksum", data.model_check_sum), + ], + ), + ], + ) + + def report_tosa_operators(ops: list[Operator]) -> Report: """Generate report for the operators.""" return Table( @@ -69,6 +118,42 @@ def report_tosa_operators(ops: list[Operator]) -> Report: ) +def report_tosa_exception(exc: Exception | None) -> Report: + """Generate report for exception thrown by tosa.""" + return NestedReport( + "TOSA exception", + "exception", + [ + ReportItem("TOSA exception", alias="exception", value=repr(exc)), + ], + ) + + +def report_tosa_errors(err: list[str] | None) -> Report: + """Generate report for errors thrown by tosa.""" + message = "".join(err) if err else None + return NestedReport( + "TOSA stderr", + "stderr", + [ + ReportItem( + "TOSA stderr", + alias="stderr", + value=message, + ), + ], + ) + + +def report_tosa_compatibility(compat_info: TOSACompatibilityInfo) -> Report: + """Generate combined report for all compatibility info.""" + report_ops = report_tosa_operators(compat_info.operators) + report_exception = report_tosa_exception(compat_info.exception) + + report_errors = report_tosa_errors(compat_info.errors) + return CompoundReport([report_ops, report_exception, report_errors]) + + def tosa_formatters(data: Any) -> Callable[[Any], Report]: """Find appropriate formatter for the provided data.""" if is_list_of(data, Advice): @@ -77,7 +162,13 @@ def tosa_formatters(data: Any) -> Callable[[Any], Report]: if isinstance(data, TOSAConfiguration): return report_device + if isinstance(data, MetadataDisplay): + return report_metadata + if is_list_of(data, Operator): return report_tosa_operators + if isinstance(data, TOSACompatibilityInfo): + return report_tosa_compatibility + raise Exception(f"Unable to find appropriate formatter for {data}") diff --git a/src/mlia/utils/logging.py b/src/mlia/utils/logging.py index cf7ad27..0659dcf 100644 --- a/src/mlia/utils/logging.py +++ b/src/mlia/utils/logging.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Logging utility functions.""" from __future__ import annotations @@ -54,10 +54,10 @@ def redirect_output( @contextmanager -def redirect_raw( - logger: logging.Logger, output: TextIO, log_level: int +def process_raw_output( + consumer: Callable[[str], None], output: TextIO ) -> Generator[None, None, None]: - """Redirect output using file descriptors.""" + """Process output on file descriptor level.""" with tempfile.TemporaryFile(mode="r+") as tmp: old_output_fd: int | None = None try: @@ -73,7 +73,21 @@ def redirect_raw( tmp.seek(0) for line in tmp.readlines(): - logger.log(log_level, line.rstrip()) + consumer(line) + + +@contextmanager +def redirect_raw( + logger: logging.Logger, output: TextIO, log_level: int +) -> Generator[None, None, None]: + """Redirect output using file descriptors.""" + + def consumer(line: str) -> None: + """Redirect output to the logger.""" + logger.log(log_level, line.rstrip()) + + with process_raw_output(consumer, output): + yield @contextmanager @@ -94,6 +108,19 @@ def redirect_raw_output( yield +@contextmanager +def capture_raw_output(output: TextIO) -> Generator[list[str], None, None]: + """Capture output as list of strings.""" + result: list[str] = [] + + def consumer(line: str) -> None: + """Save output for later processing.""" + result.append(line) + + with process_raw_output(consumer, output): + yield result + + class LogFilter(logging.Filter): """Configurable log filter.""" diff --git a/src/mlia/utils/misc.py b/src/mlia/utils/misc.py index de95448..dd2f007 100644 --- a/src/mlia/utils/misc.py +++ b/src/mlia/utils/misc.py @@ -1,9 +1,31 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Various util functions.""" +from importlib import metadata +from pathlib import Path + +from mlia.utils.filesystem import sha256 + + +class MetadataError(Exception): + """Metadata error.""" def yes(prompt: str) -> bool: """Return true if user confirms the action.""" response = input(f"{prompt} [y/n]: ") return response in ["y", "Y"] + + +def get_pkg_version(pkg_name: str) -> str: + """Return the version of python package.""" + try: + pkg_version = metadata.version(pkg_name) + except FileNotFoundError as exc: + raise MetadataError(exc) from exc + return pkg_version + + +def get_file_checksum(input_path: Path) -> str: + """Retrun the checksum of the input model.""" + return sha256(input_path) -- cgit v1.2.1