From 4eb3fef8e5876c69dc6bac70fdc010805d5b97f2 Mon Sep 17 00:00:00 2001 From: Ruomei Yan Date: Tue, 13 Dec 2022 22:02:21 +0000 Subject: MLIA-741/2 Report test results - add version extraction function in compat.py - create Metadata, MLIAMetadata, TOSAMetadata and MetadataDisplay classes - update the reporting functions so tosa and mlia version will be displayed in output json - update unit test test_configure_and_get_tosa_advisor to mock the get_events function - update the copyright information of all changed/added files - handle exception and report to json when program crashes - write new context managers for capturing stderr and stdout - support reporting stderr to json output - support reporting model checksum and model name to json output - made changes in test_e2e.py handling {model_name} replacement in --output - add unit tests Change-Id: I6629fd1c5754378e6accd488217c83d87c7eb6f1 --- src/mlia/backend/tosa_checker/compat.py | 56 +++++++++++++++++--- src/mlia/core/metadata.py | 37 +++++++++++++ src/mlia/target/tosa/advisor.py | 14 ++++- src/mlia/target/tosa/events.py | 6 ++- src/mlia/target/tosa/handlers.py | 3 +- src/mlia/target/tosa/metadata.py | 8 +++ src/mlia/target/tosa/reporters.py | 93 ++++++++++++++++++++++++++++++++- src/mlia/utils/logging.py | 37 +++++++++++-- src/mlia/utils/misc.py | 24 ++++++++- tests/test_backend_tosa_compat.py | 38 +++++++++----- tests/test_target_tosa_advisor.py | 15 +++++- tests/test_target_tosa_reporters.py | 52 ++++++++++++++++++ tests/test_utils_logging.py | 15 +++++- tests/test_utils_misc.py | 10 +++- tests_e2e/test_e2e.py | 15 ++++-- 15 files changed, 383 insertions(+), 40 deletions(-) create mode 100644 src/mlia/core/metadata.py create mode 100644 src/mlia/target/tosa/metadata.py create mode 100644 tests/test_target_tosa_reporters.py diff --git a/src/mlia/backend/tosa_checker/compat.py b/src/mlia/backend/tosa_checker/compat.py index bd21774..81f3015 100644 --- a/src/mlia/backend/tosa_checker/compat.py +++ b/src/mlia/backend/tosa_checker/compat.py @@ -1,8 +1,9 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """TOSA compatibility module.""" from __future__ import annotations +import sys from dataclasses import dataclass from typing import Any from typing import cast @@ -10,6 +11,7 @@ from typing import Protocol from mlia.backend.errors import BackendUnavailableError from mlia.core.typing import PathOrFileLike +from mlia.utils.logging import capture_raw_output class TOSAChecker(Protocol): @@ -37,25 +39,65 @@ class TOSACompatibilityInfo: tosa_compatible: bool operators: list[Operator] + exception: Exception | None = None + errors: list[str] | None = None + std_out: list[str] | None = None def get_tosa_compatibility_info( tflite_model_path: PathOrFileLike, ) -> TOSACompatibilityInfo: """Return list of the operators.""" - checker = get_tosa_checker(tflite_model_path) + # Capture the possible exception in running get_tosa_checker + try: + with capture_raw_output(sys.stdout) as std_output_pkg, capture_raw_output( + sys.stderr + ) as stderr_output_pkg: + checker = get_tosa_checker(tflite_model_path) + except Exception as exc: # pylint: disable=broad-except + return TOSACompatibilityInfo( + tosa_compatible=False, + operators=[], + exception=exc, + errors=None, + std_out=None, + ) + # Capture the possible BackendUnavailableError when tosa-checker is not available if checker is None: raise BackendUnavailableError( "Backend tosa-checker is not available", "tosa-checker" ) - ops = [ - Operator(item.location, item.name, item.is_tosa_compatible) - for item in checker._get_tosa_compatibility_for_ops() # pylint: disable=protected-access - ] + # Capture the possible exception when checking ops compatibility + try: + with capture_raw_output(sys.stdout) as std_output_ops, capture_raw_output( + sys.stderr + ) as stderr_output_ops: + ops = [ + Operator(item.location, item.name, item.is_tosa_compatible) + for item in checker._get_tosa_compatibility_for_ops() # pylint: disable=protected-access + ] + except Exception as exc: # pylint: disable=broad-except + return TOSACompatibilityInfo( + tosa_compatible=False, + operators=[], + exception=exc, + errors=None, + std_out=None, + ) - return TOSACompatibilityInfo(checker.is_tosa_compatible(), ops) + # Concatenate all possbile stderr/stdout + stderr_output = stderr_output_pkg + stderr_output_ops + std_output = std_output_pkg + std_output_ops + + return TOSACompatibilityInfo( + tosa_compatible=checker.is_tosa_compatible(), + operators=ops, + exception=None, + errors=stderr_output, + std_out=std_output, + ) def get_tosa_checker(tflite_model_path: PathOrFileLike) -> TOSAChecker | None: diff --git a/src/mlia/core/metadata.py b/src/mlia/core/metadata.py new file mode 100644 index 0000000..f0a0e03 --- /dev/null +++ b/src/mlia/core/metadata.py @@ -0,0 +1,37 @@ +# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Classes for metadata.""" +from pathlib import Path + +from mlia.utils.misc import get_file_checksum +from mlia.utils.misc import get_pkg_version + + +class Metadata: # pylint: disable=too-few-public-methods + """Base class metadata.""" + + def __init__(self, data_name: str) -> None: + """Init Metadata.""" + self.version = self.get_version(data_name) + + def get_version(self, data_name: str) -> str: + """Get version of the python package.""" + return get_pkg_version(data_name) + + +class MLIAMetadata(Metadata): # pylint: disable=too-few-public-methods + """MLIA metadata.""" + + +class ModelMetadata: # pylint: disable=too-few-public-methods + """Model metadata.""" + + def __init__(self, path_name: Path) -> None: + """Init ModelMetadata.""" + self.model_name = path_name.name + self.path_name = path_name + self.checksum = self.get_checksum() + + def get_checksum(self) -> str: + """Get checksum of the model.""" + return get_file_checksum(self.path_name) diff --git a/src/mlia/target/tosa/advisor.py b/src/mlia/target/tosa/advisor.py index 4851113..0da44db 100644 --- a/src/mlia/target/tosa/advisor.py +++ b/src/mlia/target/tosa/advisor.py @@ -16,12 +16,16 @@ from mlia.core.context import ExecutionContext from mlia.core.data_analysis import DataAnalyzer from mlia.core.data_collection import DataCollector from mlia.core.events import Event +from mlia.core.metadata import MLIAMetadata +from mlia.core.metadata import ModelMetadata from mlia.target.tosa.advice_generation import TOSAAdviceProducer from mlia.target.tosa.config import TOSAConfiguration from mlia.target.tosa.data_analysis import TOSADataAnalyzer from mlia.target.tosa.data_collection import TOSAOperatorCompatibility from mlia.target.tosa.events import TOSAAdvisorStartedEvent from mlia.target.tosa.handlers import TOSAEventHandler +from mlia.target.tosa.metadata import TOSAMetadata +from mlia.target.tosa.reporters import MetadataDisplay class TOSAInferenceAdvisor(DefaultInferenceAdvisor): @@ -61,7 +65,15 @@ class TOSAInferenceAdvisor(DefaultInferenceAdvisor): target_profile = self.get_target_profile(context) return [ - TOSAAdvisorStartedEvent(model, TOSAConfiguration(target_profile)), + TOSAAdvisorStartedEvent( + model, + TOSAConfiguration(target_profile), + MetadataDisplay( + TOSAMetadata("tosa-checker"), + MLIAMetadata("mlia"), + ModelMetadata(model), + ), + ) ] diff --git a/src/mlia/target/tosa/events.py b/src/mlia/target/tosa/events.py index 67d499d..cbfd199 100644 --- a/src/mlia/target/tosa/events.py +++ b/src/mlia/target/tosa/events.py @@ -1,12 +1,15 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """TOSA advisor events.""" +from __future__ import annotations + from dataclasses import dataclass from pathlib import Path from mlia.core.events import Event from mlia.core.events import EventDispatcher from mlia.target.tosa.config import TOSAConfiguration +from mlia.target.tosa.reporters import MetadataDisplay @dataclass @@ -15,6 +18,7 @@ class TOSAAdvisorStartedEvent(Event): model: Path device: TOSAConfiguration + tosa_metadata: MetadataDisplay | None class TOSAAdvisorEventHandler(EventDispatcher): diff --git a/src/mlia/target/tosa/handlers.py b/src/mlia/target/tosa/handlers.py index 1037ba1..7f80f77 100644 --- a/src/mlia/target/tosa/handlers.py +++ b/src/mlia/target/tosa/handlers.py @@ -27,10 +27,11 @@ class TOSAEventHandler(WorkflowEventsHandler, TOSAAdvisorEventHandler): def on_tosa_advisor_started(self, event: TOSAAdvisorStartedEvent) -> None: """Handle TOSAAdvisorStartedEvent event.""" self.reporter.submit(event.device) + self.reporter.submit(event.tosa_metadata) def on_collected_data(self, event: CollectedDataEvent) -> None: """Handle CollectedDataEvent event.""" data_item = event.data_item if isinstance(data_item, TOSACompatibilityInfo): - self.reporter.submit(data_item.operators, delay_print=True) + self.reporter.submit(data_item, delay_print=True) diff --git a/src/mlia/target/tosa/metadata.py b/src/mlia/target/tosa/metadata.py new file mode 100644 index 0000000..5575207 --- /dev/null +++ b/src/mlia/target/tosa/metadata.py @@ -0,0 +1,8 @@ +# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""TOSA package metadata.""" +from mlia.core.metadata import Metadata + + +class TOSAMetadata(Metadata): # pylint: disable=too-few-public-methods + """TOSA metadata.""" diff --git a/src/mlia/target/tosa/reporters.py b/src/mlia/target/tosa/reporters.py index 01fbb97..283f61f 100644 --- a/src/mlia/target/tosa/reporters.py +++ b/src/mlia/target/tosa/reporters.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Reports module.""" from __future__ import annotations @@ -7,20 +7,41 @@ from typing import Any from typing import Callable from mlia.backend.tosa_checker.compat import Operator +from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo from mlia.core.advice_generation import Advice +from mlia.core.metadata import MLIAMetadata +from mlia.core.metadata import ModelMetadata from mlia.core.reporters import report_advice from mlia.core.reporting import Cell from mlia.core.reporting import Column +from mlia.core.reporting import CompoundReport from mlia.core.reporting import Format from mlia.core.reporting import NestedReport from mlia.core.reporting import Report from mlia.core.reporting import ReportItem from mlia.core.reporting import Table from mlia.target.tosa.config import TOSAConfiguration +from mlia.target.tosa.metadata import TOSAMetadata from mlia.utils.console import style_improvement from mlia.utils.types import is_list_of +class MetadataDisplay: # pylint: disable=too-few-public-methods + """TOSA metadata.""" + + def __init__( + self, + tosa_meta: TOSAMetadata, + mlia_meta: MLIAMetadata, + model_meta: ModelMetadata, + ) -> None: + """Init TOSAMetadata.""" + self.tosa_version = tosa_meta.version + self.mlia_version = mlia_meta.version + self.model_check_sum = model_meta.checksum + self.model_name = model_meta.model_name + + def report_device(device: TOSAConfiguration) -> Report: """Generate report for the device.""" return NestedReport( @@ -32,6 +53,34 @@ def report_device(device: TOSAConfiguration) -> Report: ) +def report_metadata(data: MetadataDisplay) -> Report: + """Generate report for the package version.""" + return NestedReport( + "Metadata", + "metadata", + [ + ReportItem( + "TOSA checker", + alias="tosa-checker", + nested_items=[ReportItem("version", "version", data.tosa_version)], + ), + ReportItem( + "MLIA", + "MLIA", + nested_items=[ReportItem("version", "version", data.mlia_version)], + ), + ReportItem( + "Model", + "Model", + nested_items=[ + ReportItem("name", "name", data.model_name), + ReportItem("checksum", "checksum", data.model_check_sum), + ], + ), + ], + ) + + def report_tosa_operators(ops: list[Operator]) -> Report: """Generate report for the operators.""" return Table( @@ -69,6 +118,42 @@ def report_tosa_operators(ops: list[Operator]) -> Report: ) +def report_tosa_exception(exc: Exception | None) -> Report: + """Generate report for exception thrown by tosa.""" + return NestedReport( + "TOSA exception", + "exception", + [ + ReportItem("TOSA exception", alias="exception", value=repr(exc)), + ], + ) + + +def report_tosa_errors(err: list[str] | None) -> Report: + """Generate report for errors thrown by tosa.""" + message = "".join(err) if err else None + return NestedReport( + "TOSA stderr", + "stderr", + [ + ReportItem( + "TOSA stderr", + alias="stderr", + value=message, + ), + ], + ) + + +def report_tosa_compatibility(compat_info: TOSACompatibilityInfo) -> Report: + """Generate combined report for all compatibility info.""" + report_ops = report_tosa_operators(compat_info.operators) + report_exception = report_tosa_exception(compat_info.exception) + + report_errors = report_tosa_errors(compat_info.errors) + return CompoundReport([report_ops, report_exception, report_errors]) + + def tosa_formatters(data: Any) -> Callable[[Any], Report]: """Find appropriate formatter for the provided data.""" if is_list_of(data, Advice): @@ -77,7 +162,13 @@ def tosa_formatters(data: Any) -> Callable[[Any], Report]: if isinstance(data, TOSAConfiguration): return report_device + if isinstance(data, MetadataDisplay): + return report_metadata + if is_list_of(data, Operator): return report_tosa_operators + if isinstance(data, TOSACompatibilityInfo): + return report_tosa_compatibility + raise Exception(f"Unable to find appropriate formatter for {data}") diff --git a/src/mlia/utils/logging.py b/src/mlia/utils/logging.py index cf7ad27..0659dcf 100644 --- a/src/mlia/utils/logging.py +++ b/src/mlia/utils/logging.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Logging utility functions.""" from __future__ import annotations @@ -54,10 +54,10 @@ def redirect_output( @contextmanager -def redirect_raw( - logger: logging.Logger, output: TextIO, log_level: int +def process_raw_output( + consumer: Callable[[str], None], output: TextIO ) -> Generator[None, None, None]: - """Redirect output using file descriptors.""" + """Process output on file descriptor level.""" with tempfile.TemporaryFile(mode="r+") as tmp: old_output_fd: int | None = None try: @@ -73,7 +73,21 @@ def redirect_raw( tmp.seek(0) for line in tmp.readlines(): - logger.log(log_level, line.rstrip()) + consumer(line) + + +@contextmanager +def redirect_raw( + logger: logging.Logger, output: TextIO, log_level: int +) -> Generator[None, None, None]: + """Redirect output using file descriptors.""" + + def consumer(line: str) -> None: + """Redirect output to the logger.""" + logger.log(log_level, line.rstrip()) + + with process_raw_output(consumer, output): + yield @contextmanager @@ -94,6 +108,19 @@ def redirect_raw_output( yield +@contextmanager +def capture_raw_output(output: TextIO) -> Generator[list[str], None, None]: + """Capture output as list of strings.""" + result: list[str] = [] + + def consumer(line: str) -> None: + """Save output for later processing.""" + result.append(line) + + with process_raw_output(consumer, output): + yield result + + class LogFilter(logging.Filter): """Configurable log filter.""" diff --git a/src/mlia/utils/misc.py b/src/mlia/utils/misc.py index de95448..dd2f007 100644 --- a/src/mlia/utils/misc.py +++ b/src/mlia/utils/misc.py @@ -1,9 +1,31 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Various util functions.""" +from importlib import metadata +from pathlib import Path + +from mlia.utils.filesystem import sha256 + + +class MetadataError(Exception): + """Metadata error.""" def yes(prompt: str) -> bool: """Return true if user confirms the action.""" response = input(f"{prompt} [y/n]: ") return response in ["y", "Y"] + + +def get_pkg_version(pkg_name: str) -> str: + """Return the version of python package.""" + try: + pkg_version = metadata.version(pkg_name) + except FileNotFoundError as exc: + raise MetadataError(exc) from exc + return pkg_version + + +def get_file_checksum(input_path: Path) -> str: + """Retrun the checksum of the input model.""" + return sha256(input_path) diff --git a/tests/test_backend_tosa_compat.py b/tests/test_backend_tosa_compat.py index 52aaa44..5a80b4b 100644 --- a/tests/test_backend_tosa_compat.py +++ b/tests/test_backend_tosa_compat.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for TOSA compatibility.""" from __future__ import annotations @@ -39,12 +39,13 @@ def test_compatibility_check_should_fail_if_checker_not_available( @pytest.mark.parametrize( - "is_tosa_compatible, operators, expected_result", + "is_tosa_compatible, operators, exception, expected_result", [ [ True, [], - TOSACompatibilityInfo(True, []), + None, + TOSACompatibilityInfo(True, [], None, None, None), ], [ True, @@ -55,18 +56,16 @@ def test_compatibility_check_should_fail_if_checker_not_available( is_tosa_compatible=True, ) ], - TOSACompatibilityInfo(True, [Operator("op_location", "op_name", True)]), + None, + TOSACompatibilityInfo( + True, [Operator("op_location", "op_name", True)], None, [], [] + ), ], [ False, - [ - SimpleNamespace( - location="op_location", - name="op_name", - is_tosa_compatible=False, - ) - ], - TOSACompatibilityInfo(False, [Operator("op_location", "op_name", False)]), + [], + ValueError("error_test"), + TOSACompatibilityInfo(False, [], ValueError("error_test"), [], []), ], ], ) @@ -75,6 +74,7 @@ def test_get_tosa_compatibility_info( test_tflite_model: Path, is_tosa_compatible: bool, operators: Any, + exception: Exception | None, expected_result: TOSACompatibilityInfo, ) -> None: """Test getting TOSA compatibility information.""" @@ -83,7 +83,17 @@ def test_get_tosa_compatibility_info( mock_checker._get_tosa_compatibility_for_ops.return_value = ( # pylint: disable=protected-access operators ) - + if exception: + mock_checker._get_tosa_compatibility_for_ops.side_effect = ( # pylint: disable=protected-access + exception + ) replace_get_tosa_checker_with_mock(monkeypatch, mock_checker) - assert get_tosa_compatibility_info(test_tflite_model) == expected_result + returned_compatibility_info = get_tosa_compatibility_info(test_tflite_model) + assert repr(returned_compatibility_info.exception) == repr( + expected_result.exception + ) + assert ( + returned_compatibility_info.tosa_compatible == expected_result.tosa_compatible + ) + assert returned_compatibility_info.operators == expected_result.operators diff --git a/tests/test_target_tosa_advisor.py b/tests/test_target_tosa_advisor.py index 32a6b77..9646c96 100644 --- a/tests/test_target_tosa_advisor.py +++ b/tests/test_target_tosa_advisor.py @@ -1,7 +1,10 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for TOSA advisor.""" from pathlib import Path +from unittest.mock import MagicMock + +import pytest from mlia.core.context import ExecutionContext from mlia.core.workflow import DefaultWorkflowExecutor @@ -9,15 +12,23 @@ from mlia.target.tosa.advisor import configure_and_get_tosa_advisor from mlia.target.tosa.advisor import TOSAInferenceAdvisor -def test_configure_and_get_tosa_advisor(test_tflite_model: Path) -> None: +def test_configure_and_get_tosa_advisor( + monkeypatch: pytest.MonkeyPatch, test_tflite_model: Path +) -> None: """Test TOSA advisor configuration.""" ctx = ExecutionContext() + get_events_mock = MagicMock() + monkeypatch.setattr( + "mlia.target.tosa.advisor.TOSAInferenceAdvisor.get_events", + MagicMock(return_value=get_events_mock), + ) advisor = configure_and_get_tosa_advisor(ctx, "tosa", test_tflite_model) workflow = advisor.configure(ctx) assert isinstance(advisor, TOSAInferenceAdvisor) + assert advisor.get_events(ctx) == get_events_mock assert ctx.event_handlers is not None assert ctx.config_parameters == { "tosa_inference_advisor": { diff --git a/tests/test_target_tosa_reporters.py b/tests/test_target_tosa_reporters.py new file mode 100644 index 0000000..59da270 --- /dev/null +++ b/tests/test_target_tosa_reporters.py @@ -0,0 +1,52 @@ +# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for tosa-checker reporters.""" +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + +from mlia.core.metadata import MLIAMetadata +from mlia.core.metadata import ModelMetadata +from mlia.core.reporting import Report +from mlia.target.tosa.config import TOSAConfiguration +from mlia.target.tosa.metadata import TOSAMetadata +from mlia.target.tosa.reporters import MetadataDisplay +from mlia.target.tosa.reporters import report_device +from mlia.target.tosa.reporters import tosa_formatters + + +def test_tosa_report_device() -> None: + """Test function report_device().""" + report = report_device(TOSAConfiguration("tosa")) + assert report.to_plain_text() + + +def test_tosa_formatters( + monkeypatch: pytest.MonkeyPatch, test_tflite_model: Path +) -> None: + """Test function tosa_formatters() with valid input.""" + mock_version = MagicMock() + monkeypatch.setattr( + "mlia.core.metadata.get_pkg_version", + MagicMock(return_value=mock_version), + ) + + data = MetadataDisplay( + TOSAMetadata("tosa-checker"), + MLIAMetadata("mlia"), + ModelMetadata(test_tflite_model), + ) + formatter = tosa_formatters(data) + report = formatter(data) + assert data.tosa_version == mock_version + assert isinstance(report, Report) + + +def test_tosa_formatters_invalid_data() -> None: + """Test tosa_formatters() with invalid input.""" + with pytest.raises( + Exception, + match=r"^Unable to find appropriate formatter for .*", + ): + tosa_formatters(12) diff --git a/tests/test_utils_logging.py b/tests/test_utils_logging.py index ac835c6..c02e8b0 100644 --- a/tests/test_utils_logging.py +++ b/tests/test_utils_logging.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the logging utility functions.""" from __future__ import annotations @@ -13,6 +13,7 @@ from unittest.mock import MagicMock import pytest +from mlia.utils.logging import capture_raw_output from mlia.utils.logging import create_log_handler from mlia.utils.logging import redirect_output from mlia.utils.logging import redirect_raw_output @@ -84,3 +85,15 @@ def test_output_redirection(redirect_context_manager: Callable) -> None: print("after redirect") logger_mock.log.assert_called_once_with(logging.INFO, "output redirected") + + +def test_output_and_error_capture() -> None: + """Test output/error capturing.""" + with capture_raw_output(sys.stdout) as std_output, capture_raw_output( + sys.stderr + ) as stderr_output: + print("hello from stdout") + print("hello from stderr", file=sys.stderr) + + assert std_output == ["hello from stdout\n"] + assert stderr_output == ["hello from stderr\n"] diff --git a/tests/test_utils_misc.py b/tests/test_utils_misc.py index 011d09e..ae91850 100644 --- a/tests/test_utils_misc.py +++ b/tests/test_utils_misc.py @@ -1,10 +1,11 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for misc util functions.""" from unittest.mock import MagicMock import pytest +from mlia.utils.misc import get_pkg_version from mlia.utils.misc import yes @@ -23,3 +24,10 @@ def test_yes( """Test yes function.""" monkeypatch.setattr("builtins.input", MagicMock(return_value=response)) assert yes("some_prompt") == expected_result + + +@pytest.mark.parametrize("response", ["some version", FileNotFoundError()]) +def test_get_pkg_version(monkeypatch: pytest.MonkeyPatch, response: str) -> None: + """Test get_tosa_version.""" + monkeypatch.setattr("importlib.metadata.version", MagicMock(return_value=response)) + assert get_pkg_version("any name") == response diff --git a/tests_e2e/test_e2e.py b/tests_e2e/test_e2e.py index 26f5d29..35fd707 100644 --- a/tests_e2e/test_e2e.py +++ b/tests_e2e/test_e2e.py @@ -218,11 +218,16 @@ def get_all_commands_combinations(executions: Any) -> Generator[list[str], None, ExecutionConfiguration.from_dict(exec_info) for exec_info in executions ) - return ( - command_combination - for exec_config in exec_configs - for command_combination in exec_config.all_combinations - ) + parser = get_args_parser() + for exec_config in exec_configs: + for command_combination in exec_config.all_combinations: + for idx, param in enumerate(command_combination): + if "{model_name}" in param: + args = parser.parse_args(command_combination) + model_name = Path(args.model).stem + param = param.replace("{model_name}", model_name) + command_combination[idx] = param + yield command_combination def check_args(args: list[str], no_skip: bool) -> None: -- cgit v1.2.1