diff options
28 files changed, 394 insertions, 488 deletions
diff --git a/src/mlia/api.py b/src/mlia/api.py index 2cabf37..8105276 100644 --- a/src/mlia/api.py +++ b/src/mlia/api.py @@ -9,7 +9,6 @@ from typing import Any from mlia.core.advisor import InferenceAdvisor from mlia.core.common import AdviceCategory -from mlia.core.common import FormattedFilePath from mlia.core.context import ExecutionContext from mlia.target.cortex_a.advisor import configure_and_get_cortexa_advisor from mlia.target.ethos_u.advisor import configure_and_get_ethosu_advisor @@ -24,7 +23,6 @@ def get_advice( model: str | Path, category: set[str], optimization_targets: list[dict[str, Any]] | None = None, - output: FormattedFilePath | None = None, context: ExecutionContext | None = None, backends: list[str] | None = None, ) -> None: @@ -42,8 +40,6 @@ def get_advice( category "compatibility" is used by default. :param optimization_targets: optional model optimization targets that could be used for generating advice in "optimization" category. - :param output: path to the report file. If provided, MLIA will save - report in this location. :param context: optional parameter which represents execution context, could be used for advanced use cases :param backends: A list of backends that should be used for the given @@ -57,11 +53,9 @@ def get_advice( >>> get_advice("ethos-u55-256", "path/to/the/model", {"optimization", "compatibility"}) - Getting the advice for the category "performance" and save result report in file - "report.json" + Getting the advice for the category "performance". - >>> get_advice("ethos-u55-256", "path/to/the/model", {"performance"}, - output=FormattedFilePath("report.json") + >>> get_advice("ethos-u55-256", "path/to/the/model", {"performance"}) """ advice_category = AdviceCategory.from_string(category) @@ -76,7 +70,6 @@ def get_advice( context, target_profile, model, - output, optimization_targets=optimization_targets, backends=backends, ) @@ -88,7 +81,6 @@ def get_advisor( context: ExecutionContext, target_profile: str, model: str | Path, - output: FormattedFilePath | None = None, **extra_args: Any, ) -> InferenceAdvisor: """Find appropriate advisor for the target.""" @@ -109,6 +101,5 @@ def get_advisor( context, target_profile, model, - output, **extra_args, ) diff --git a/src/mlia/backend/tosa_checker/compat.py b/src/mlia/backend/tosa_checker/compat.py index 81f3015..1c410d3 100644 --- a/src/mlia/backend/tosa_checker/compat.py +++ b/src/mlia/backend/tosa_checker/compat.py @@ -5,12 +5,12 @@ from __future__ import annotations import sys from dataclasses import dataclass +from pathlib import Path from typing import Any from typing import cast from typing import Protocol from mlia.backend.errors import BackendUnavailableError -from mlia.core.typing import PathOrFileLike from mlia.utils.logging import capture_raw_output @@ -45,7 +45,7 @@ class TOSACompatibilityInfo: def get_tosa_compatibility_info( - tflite_model_path: PathOrFileLike, + tflite_model_path: str | Path, ) -> TOSACompatibilityInfo: """Return list of the operators.""" # Capture the possible exception in running get_tosa_checker @@ -100,7 +100,7 @@ def get_tosa_compatibility_info( ) -def get_tosa_checker(tflite_model_path: PathOrFileLike) -> TOSAChecker | None: +def get_tosa_checker(tflite_model_path: str | Path) -> TOSAChecker | None: """Return instance of the TOSA checker.""" try: import tosa_checker as tc # pylint: disable=import-outside-toplevel diff --git a/src/mlia/cli/commands.py b/src/mlia/cli/commands.py index d2242ba..c17d571 100644 --- a/src/mlia/cli/commands.py +++ b/src/mlia/cli/commands.py @@ -7,10 +7,10 @@ functionality. Before running them from scripts 'logging' module should be configured. Function 'setup_logging' from module -'mli.cli.logging' could be used for that, e.g. +'mli.core.logging' could be used for that, e.g. >>> from mlia.api import ExecutionContext ->>> from mlia.cli.logging import setup_logging +>>> from mlia.core.logging import setup_logging >>> setup_logging(verbose=True) >>> import mlia.cli.commands as mlia >>> mlia.check(ExecutionContext(), "ethos-u55-256", @@ -27,7 +27,6 @@ from mlia.cli.command_validators import validate_backend from mlia.cli.command_validators import validate_check_target_profile from mlia.cli.config import get_installation_manager from mlia.cli.options import parse_optimization_parameters -from mlia.cli.options import parse_output_parameters from mlia.utils.console import create_section_header logger = logging.getLogger(__name__) @@ -41,8 +40,6 @@ def check( model: str | None = None, compatibility: bool = False, performance: bool = False, - output: Path | None = None, - json: bool = False, backend: list[str] | None = None, ) -> None: """Generate a full report on the input model. @@ -61,7 +58,6 @@ def check( :param model: path to the Keras model :param compatibility: flag that identifies whether to run compatibility checks :param performance: flag that identifies whether to run performance checks - :param output: path to the file where the report will be saved :param backend: list of the backends to use for evaluation Example: @@ -69,18 +65,15 @@ def check( and operator compatibility. >>> from mlia.api import ExecutionContext - >>> from mlia.cli.logging import setup_logging + >>> from mlia.core.logging import setup_logging >>> setup_logging() >>> from mlia.cli.commands import check >>> check(ExecutionContext(), "ethos-u55-256", - "model.h5", compatibility=True, performance=True, - output="report.json") + "model.h5", compatibility=True, performance=True) """ if not model: raise Exception("Model is not provided") - formatted_output = parse_output_parameters(output, json) - # Set category based on checks to perform (i.e. "compatibility" and/or # "performance"). # If no check type is specified, "compatibility" is the default category. @@ -98,7 +91,6 @@ def check( target_profile, model, category, - output=formatted_output, context=ctx, backends=validated_backend, ) @@ -113,8 +105,6 @@ def optimize( # pylint: disable=too-many-arguments pruning_target: float | None, clustering_target: int | None, layers_to_optimize: list[str] | None = None, - output: Path | None = None, - json: bool = False, backend: list[str] | None = None, ) -> None: """Show the performance improvements (if any) after applying the optimizations. @@ -133,15 +123,13 @@ def optimize( # pylint: disable=too-many-arguments :param pruning_target: pruning optimization target :param layers_to_optimize: list of the layers of the model which should be optimized, if None then all layers are used - :param output: path to the file where the report will be saved - :param json: set the output format to json :param backend: list of the backends to use for evaluation Example: Run command for the target profile ethos-u55-256 and the provided TensorFlow Lite model and print report on the standard output - >>> from mlia.cli.logging import setup_logging + >>> from mlia.core.logging import setup_logging >>> from mlia.api import ExecutionContext >>> setup_logging() >>> from mlia.cli.commands import optimize @@ -161,7 +149,6 @@ def optimize( # pylint: disable=too-many-arguments ) ) - formatted_output = parse_output_parameters(output, json) validated_backend = validate_backend(target_profile, backend) get_advice( @@ -169,7 +156,6 @@ def optimize( # pylint: disable=too-many-arguments model, {"optimization"}, optimization_targets=opt_params, - output=formatted_output, context=ctx, backends=validated_backend, ) diff --git a/src/mlia/cli/main.py b/src/mlia/cli/main.py index 1102d45..7ce7dc9 100644 --- a/src/mlia/cli/main.py +++ b/src/mlia/cli/main.py @@ -19,7 +19,6 @@ from mlia.cli.commands import check from mlia.cli.commands import optimize from mlia.cli.common import CommandInfo from mlia.cli.helpers import CLIActionResolver -from mlia.cli.logging import setup_logging from mlia.cli.options import add_backend_install_options from mlia.cli.options import add_backend_options from mlia.cli.options import add_backend_uninstall_options @@ -30,9 +29,11 @@ from mlia.cli.options import add_model_options from mlia.cli.options import add_multi_optimization_options from mlia.cli.options import add_output_options from mlia.cli.options import add_target_options +from mlia.cli.options import get_output_format from mlia.core.context import ExecutionContext from mlia.core.errors import ConfigurationError from mlia.core.errors import InternalError +from mlia.core.logging import setup_logging from mlia.target.registry import registry as target_registry @@ -162,10 +163,13 @@ def setup_context( ctx = ExecutionContext( verbose="debug" in args and args.debug, action_resolver=CLIActionResolver(vars(args)), + output_format=get_output_format(args), ) + setup_logging(ctx.logs_path, ctx.verbose, ctx.output_format) + # these parameters should not be passed into command function - skipped_params = ["func", "command", "debug"] + skipped_params = ["func", "command", "debug", "json"] # pass these parameters only if command expects them expected_params = [context_var_name] @@ -186,7 +190,6 @@ def setup_context( def run_command(args: argparse.Namespace) -> int: """Run command.""" ctx, func_args = setup_context(args) - setup_logging(ctx.logs_path, ctx.verbose) logger.debug( "*** This is the beginning of the command '%s' execution ***", args.command diff --git a/src/mlia/cli/options.py b/src/mlia/cli/options.py index 5aca3b3..e01f107 100644 --- a/src/mlia/cli/options.py +++ b/src/mlia/cli/options.py @@ -12,7 +12,7 @@ from mlia.cli.config import DEFAULT_CLUSTERING_TARGET from mlia.cli.config import DEFAULT_PRUNING_TARGET from mlia.cli.config import get_available_backends from mlia.cli.config import is_corstone_backend -from mlia.core.common import FormattedFilePath +from mlia.core.typing import OutputFormat from mlia.utils.filesystem import get_supported_profile_names @@ -89,16 +89,9 @@ def add_output_options(parser: argparse.ArgumentParser) -> None: """Add output specific options.""" output_group = parser.add_argument_group("output options") output_group.add_argument( - "-o", - "--output", - type=Path, - help=("Name of the file where the report will be saved."), - ) - - output_group.add_argument( "--json", action="store_true", - help=("Format to use for the output (requires --output argument to be set)."), + help=("Print the output in JSON format."), ) @@ -209,22 +202,6 @@ def add_backend_options( ) -def parse_output_parameters(path: Path | None, json: bool) -> FormattedFilePath | None: - """Parse and return path and file format as FormattedFilePath.""" - if not path and json: - raise argparse.ArgumentError( - None, - "To enable JSON output you need to specify the output path. " - "(e.g. --output out.json --json)", - ) - if not path: - return None - if json: - return FormattedFilePath(path, "json") - - return FormattedFilePath(path, "plain_text") - - def parse_optimization_parameters( pruning: bool = False, clustering: bool = False, @@ -301,3 +278,11 @@ def get_target_profile_opts(device_args: dict | None) -> list[str]: for name in non_default for item in construct_param(params_name[name], device_args[name]) ] + + +def get_output_format(args: argparse.Namespace) -> OutputFormat: + """Return the OutputFormat depending on the CLI flags.""" + output_format: OutputFormat = "plain_text" + if "json" in args and args.json: + output_format = "json" + return output_format diff --git a/src/mlia/core/common.py b/src/mlia/core/common.py index 53df001..baaed50 100644 --- a/src/mlia/core/common.py +++ b/src/mlia/core/common.py @@ -13,9 +13,6 @@ from enum import auto from enum import Flag from typing import Any -from mlia.core.typing import OutputFormat -from mlia.core.typing import PathOrFileLike - # This type is used as type alias for the items which are being passed around # in advisor workflow. There are no restrictions on the type of the # object. This alias used only to emphasize the nature of the input/output @@ -23,36 +20,6 @@ from mlia.core.typing import PathOrFileLike DataItem = Any -class FormattedFilePath: - """Class used to keep track of the format that a path points to.""" - - def __init__(self, path: PathOrFileLike, fmt: OutputFormat = "plain_text") -> None: - """Init FormattedFilePath.""" - self._path = path - self._fmt = fmt - - @property - def fmt(self) -> OutputFormat: - """Return file format.""" - return self._fmt - - @property - def path(self) -> PathOrFileLike: - """Return file path.""" - return self._path - - def __eq__(self, other: object) -> bool: - """Check for equality with other objects.""" - if isinstance(other, FormattedFilePath): - return other.fmt == self.fmt and other.path == self.path - - return False - - def __repr__(self) -> str: - """Represent object.""" - return f"FormattedFilePath {self.path=}, {self.fmt=}" - - class AdviceCategory(Flag): """Advice category. diff --git a/src/mlia/core/context.py b/src/mlia/core/context.py index 94aa885..f8442a3 100644 --- a/src/mlia/core/context.py +++ b/src/mlia/core/context.py @@ -23,6 +23,7 @@ from mlia.core.events import EventHandler from mlia.core.events import EventPublisher from mlia.core.helpers import ActionResolver from mlia.core.helpers import APIActionResolver +from mlia.core.typing import OutputFormat logger = logging.getLogger(__name__) @@ -68,6 +69,11 @@ class Context(ABC): def action_resolver(self) -> ActionResolver: """Return action resolver.""" + @property + @abstractmethod + def output_format(self) -> OutputFormat: + """Return the output format.""" + @abstractmethod def update( self, @@ -106,6 +112,7 @@ class ExecutionContext(Context): logs_dir: str = "logs", models_dir: str = "models", action_resolver: ActionResolver | None = None, + output_format: OutputFormat = "plain_text", ) -> None: """Init execution context. @@ -139,6 +146,7 @@ class ExecutionContext(Context): self.logs_dir = logs_dir self.models_dir = models_dir self._action_resolver = action_resolver or APIActionResolver() + self._output_format = output_format @property def working_dir(self) -> Path: @@ -197,6 +205,11 @@ class ExecutionContext(Context): """Return path to the logs directory.""" return self._working_dir_path / self.logs_dir + @property + def output_format(self) -> OutputFormat: + """Return the output format.""" + return self._output_format + def update( self, *, @@ -221,7 +234,8 @@ class ExecutionContext(Context): f"ExecutionContext: working_dir={self._working_dir_path}, " f"advice_category={category}, " f"config_parameters={self.config_parameters}, " - f"verbose={self.verbose}" + f"verbose={self.verbose}, " + f"output_format={self.output_format}" ) diff --git a/src/mlia/core/handlers.py b/src/mlia/core/handlers.py index 6e50934..24d4881 100644 --- a/src/mlia/core/handlers.py +++ b/src/mlia/core/handlers.py @@ -9,7 +9,6 @@ from typing import Callable from mlia.core.advice_generation import Advice from mlia.core.advice_generation import AdviceEvent -from mlia.core.common import FormattedFilePath from mlia.core.events import ActionFinishedEvent from mlia.core.events import ActionStartedEvent from mlia.core.events import AdviceStageFinishedEvent @@ -25,9 +24,11 @@ from mlia.core.events import EventDispatcher from mlia.core.events import ExecutionFailedEvent from mlia.core.events import ExecutionFinishedEvent from mlia.core.events import ExecutionStartedEvent +from mlia.core.mixins import ContextMixin +from mlia.core.reporting import JSONReporter from mlia.core.reporting import Report from mlia.core.reporting import Reporter -from mlia.core.typing import PathOrFileLike +from mlia.core.reporting import TextReporter from mlia.utils.console import create_section_header @@ -92,26 +93,27 @@ _ADV_EXECUTION_STARTED = create_section_header("ML Inference Advisor started") _MODEL_ANALYSIS_MSG = create_section_header("Model Analysis") _MODEL_ANALYSIS_RESULTS_MSG = create_section_header("Model Analysis Results") _ADV_GENERATION_MSG = create_section_header("Advice Generation") -_REPORT_GENERATION_MSG = create_section_header("Report Generation") -class WorkflowEventsHandler(SystemEventsHandler): +class WorkflowEventsHandler(SystemEventsHandler, ContextMixin): """Event handler for the system events.""" + reporter: Reporter + def __init__( self, formatter_resolver: Callable[[Any], Callable[[Any], Report]], - output: FormattedFilePath | None = None, ) -> None: """Init event handler.""" - output_format = output.fmt if output else "plain_text" - self.reporter = Reporter(formatter_resolver, output_format) - self.output = output.path if output else None - + self.formatter_resolver = formatter_resolver self.advice: list[Advice] = [] def on_execution_started(self, event: ExecutionStartedEvent) -> None: """Handle ExecutionStarted event.""" + if self.context.output_format == "json": + self.reporter = JSONReporter(self.formatter_resolver) + else: + self.reporter = TextReporter(self.formatter_resolver) logger.info(_ADV_EXECUTION_STARTED) def on_execution_failed(self, event: ExecutionFailedEvent) -> None: @@ -132,12 +134,6 @@ class WorkflowEventsHandler(SystemEventsHandler): """Handle DataCollectorSkipped event.""" logger.info("Skipped: %s", event.reason) - @staticmethod - def report_generated(output: PathOrFileLike) -> None: - """Log report generation.""" - logger.info(_REPORT_GENERATION_MSG) - logger.info("Report(s) and advice list saved to: %s", output) - def on_data_analysis_stage_finished( self, event: DataAnalysisStageFinishedEvent ) -> None: @@ -160,7 +156,4 @@ class WorkflowEventsHandler(SystemEventsHandler): table_style="no_borders", ) - self.reporter.generate_report(self.output) - - if self.output is not None: - self.report_generated(self.output) + self.reporter.generate_report() diff --git a/src/mlia/cli/logging.py b/src/mlia/core/logging.py index 5c5c4b8..686f8ab 100644 --- a/src/mlia/cli/logging.py +++ b/src/mlia/core/logging.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """CLI logging configuration.""" from __future__ import annotations @@ -8,18 +8,20 @@ import sys from pathlib import Path from typing import Iterable +from mlia.core.typing import OutputFormat from mlia.utils.logging import attach_handlers from mlia.utils.logging import create_log_handler -from mlia.utils.logging import LogFilter +from mlia.utils.logging import NoASCIIFormatter -_CONSOLE_DEBUG_FORMAT = "%(name)s - %(message)s" +_CONSOLE_DEBUG_FORMAT = "%(name)s - %(levelname)s - %(message)s" _FILE_DEBUG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" def setup_logging( logs_dir: str | Path | None = None, verbose: bool = False, + output_format: OutputFormat = "plain_text", log_filename: str = "mlia.log", ) -> None: """Set up logging. @@ -30,6 +32,8 @@ def setup_logging( debug information. If the path is not provided then no log files will be created during execution :param verbose: enable extended logging for the tools loggers + :param output_format: specify the out format needed for setting up the right + logging system :param log_filename: name of the log file in the logs directory """ mlia_logger, tensorflow_logger, py_warnings_logger = ( @@ -42,7 +46,7 @@ def setup_logging( for logger in [mlia_logger, tensorflow_logger]: logger.setLevel(logging.DEBUG) - mlia_handlers = _get_mlia_handlers(logs_dir, log_filename, verbose) + mlia_handlers = _get_mlia_handlers(logs_dir, log_filename, verbose, output_format) attach_handlers(mlia_handlers, [mlia_logger]) tools_handlers = _get_tools_handlers(logs_dir, log_filename, verbose) @@ -53,28 +57,47 @@ def _get_mlia_handlers( logs_dir: str | Path | None, log_filename: str, verbose: bool, + output_format: OutputFormat, ) -> Iterable[logging.Handler]: """Get handlers for the MLIA loggers.""" - yield create_log_handler( - stream=sys.stdout, - log_level=logging.INFO, - ) - - if verbose: - mlia_verbose_handler = create_log_handler( - stream=sys.stdout, - log_level=logging.DEBUG, - log_format=_CONSOLE_DEBUG_FORMAT, - log_filter=LogFilter.equals(logging.DEBUG), + # MLIA needs output to standard output via the logging system only when the + # format is plain text. When the user specifies the "json" output format, + # MLIA disables completely the logging system for the console output and it + # relies on the print() function. This is needed because the output might + # be corrupted with spurious messages in the standard output. + if output_format == "plain_text": + if verbose: + log_level = logging.DEBUG + log_format = _CONSOLE_DEBUG_FORMAT + else: + log_level = logging.INFO + log_format = None + + # Create log handler for stdout + yield create_log_handler( + stream=sys.stdout, log_level=log_level, log_format=log_format + ) + else: + # In case of non plain text output, we need to inform the user if an + # error happens during execution. + yield create_log_handler( + stream=sys.stderr, + log_level=logging.ERROR, ) - yield mlia_verbose_handler + # If the logs directory is specified, MLIA stores all output (according to + # the logging level) into the file and removing the colouring of the + # console output. if logs_dir: + if verbose: + log_level = logging.DEBUG + else: + log_level = logging.INFO + yield create_log_handler( file_path=_get_log_file(logs_dir, log_filename), - log_level=logging.DEBUG, - log_format=_FILE_DEBUG_FORMAT, - log_filter=LogFilter.skip(logging.INFO), + log_level=log_level, + log_format=NoASCIIFormatter(fmt=_FILE_DEBUG_FORMAT), delay=True, ) diff --git a/src/mlia/core/reporting.py b/src/mlia/core/reporting.py index ad63d62..7b9ce5c 100644 --- a/src/mlia/core/reporting.py +++ b/src/mlia/core/reporting.py @@ -8,30 +8,21 @@ import logging from abc import ABC from abc import abstractmethod from collections import defaultdict -from contextlib import contextmanager -from contextlib import ExitStack from dataclasses import dataclass from enum import Enum from functools import partial -from io import TextIOWrapper -from pathlib import Path from textwrap import fill from textwrap import indent from typing import Any from typing import Callable -from typing import cast from typing import Collection -from typing import Generator from typing import Iterable import numpy as np -from mlia.core.typing import FileLike from mlia.core.typing import OutputFormat -from mlia.core.typing import PathOrFileLike from mlia.utils.console import apply_style from mlia.utils.console import produce_table -from mlia.utils.logging import LoggerWriter from mlia.utils.types import is_list_of logger = logging.getLogger(__name__) @@ -505,76 +496,48 @@ class CustomJSONEncoder(json.JSONEncoder): return json.JSONEncoder.default(self, o) -def json_reporter(report: Report, output: FileLike, **kwargs: Any) -> None: - """Produce report in json format.""" - json_str = json.dumps(report.to_json(**kwargs), indent=4, cls=CustomJSONEncoder) - print(json_str, file=output) - - -def text_reporter(report: Report, output: FileLike, **kwargs: Any) -> None: - """Produce report in text format.""" - print(report.to_plain_text(**kwargs), file=output) +class Reporter(ABC): + """Reporter class.""" + def __init__( + self, + formatter_resolver: Callable[[Any], Callable[[Any], Report]], + ) -> None: + """Init reporter instance.""" + self.formatter_resolver = formatter_resolver + self.data: list[tuple[Any, Callable[[Any], Report]]] = [] -def produce_report( - data: Any, - formatter: Callable[[Any], Report], - fmt: OutputFormat = "plain_text", - output: PathOrFileLike | None = None, - **kwargs: Any, -) -> None: - """Produce report based on provided data.""" - # check if provided format value is supported - formats = {"json": json_reporter, "plain_text": text_reporter} - if fmt not in formats: - raise Exception(f"Unknown format {fmt}") + @abstractmethod + def submit(self, data_item: Any, **kwargs: Any) -> None: + """Submit data for the report.""" - if output is None: - output = cast(TextIOWrapper, LoggerWriter(logger, logging.INFO)) + def print_delayed(self) -> None: + """Print delayed reports.""" - with ExitStack() as exit_stack: - if isinstance(output, (str, Path)): - # open file and add it to the ExitStack context manager - # in that case it will be automatically closed - stream = exit_stack.enter_context(open(output, "w", encoding="utf-8")) - else: - stream = cast(TextIOWrapper, output) + def generate_report(self) -> None: + """Generate report.""" - # convert data into serializable form - formatted_data = formatter(data) - # find handler for the format - format_handler = formats[fmt] - # produce report in requested format - format_handler(formatted_data, stream, **kwargs) + @abstractmethod + def produce_report( + self, data: Any, formatter: Callable[[Any], Report], **kwargs: Any + ) -> None: + """Produce report based on provided data.""" -class Reporter: +class TextReporter(Reporter): """Reporter class.""" def __init__( self, formatter_resolver: Callable[[Any], Callable[[Any], Report]], - output_format: OutputFormat = "plain_text", - print_as_submitted: bool = True, ) -> None: """Init reporter instance.""" - self.formatter_resolver = formatter_resolver - self.output_format = output_format - self.print_as_submitted = print_as_submitted - - self.data: list[tuple[Any, Callable[[Any], Report]]] = [] + super().__init__(formatter_resolver) self.delayed: list[tuple[Any, Callable[[Any], Report]]] = [] + self.output_format: OutputFormat = "plain_text" def submit(self, data_item: Any, delay_print: bool = False, **kwargs: Any) -> None: """Submit data for the report.""" - if self.print_as_submitted and not delay_print: - produce_report( - data_item, - self.formatter_resolver(data_item), - fmt="plain_text", - **kwargs, - ) - formatter = _apply_format_parameters( self.formatter_resolver(data_item), self.output_format, **kwargs ) @@ -582,51 +545,70 @@ class Reporter: if delay_print: self.delayed.append((data_item, formatter)) + else: + self.produce_report( + data_item, + self.formatter_resolver(data_item), + **kwargs, + ) def print_delayed(self) -> None: """Print delayed reports.""" - if not self.delayed: - return + if self.delayed: + data, formatters = zip(*self.delayed) + self.produce_report( + data, + formatter=CompoundFormatter(formatters), + ) + self.delayed = [] - data, formatters = zip(*self.delayed) - produce_report( - data, - formatter=CompoundFormatter(formatters), - fmt="plain_text", + def produce_report( + self, data: Any, formatter: Callable[[Any], Report], **kwargs: Any + ) -> None: + """Produce report based on provided data.""" + formatted_data = formatter(data) + logger.info(formatted_data.to_plain_text(**kwargs)) + + +class JSONReporter(Reporter): + """Reporter class.""" + + def __init__( + self, + formatter_resolver: Callable[[Any], Callable[[Any], Report]], + ) -> None: + """Init reporter instance.""" + super().__init__(formatter_resolver) + self.output_format: OutputFormat = "json" + + def submit(self, data_item: Any, **kwargs: Any) -> None: + """Submit data for the report.""" + formatter = _apply_format_parameters( + self.formatter_resolver(data_item), self.output_format, **kwargs ) - self.delayed = [] + self.data.append((data_item, formatter)) - def generate_report(self, output: PathOrFileLike | None) -> None: + def generate_report(self) -> None: """Generate report.""" - already_printed = ( - self.print_as_submitted - and self.output_format == "plain_text" - and output is None - ) - if not self.data or already_printed: + if not self.data: return data, formatters = zip(*self.data) - produce_report( + self.produce_report( data, formatter=CompoundFormatter(formatters), - fmt=self.output_format, - output=output, ) - -@contextmanager -def get_reporter( - output_format: OutputFormat, - output: PathOrFileLike | None, - formatter_resolver: Callable[[Any], Callable[[Any], Report]], -) -> Generator[Reporter, None, None]: - """Get reporter and generate report.""" - reporter = Reporter(formatter_resolver, output_format) - - yield reporter - - reporter.generate_report(output) + def produce_report( + self, data: Any, formatter: Callable[[Any], Report], **kwargs: Any + ) -> None: + """Produce report based on provided data.""" + formatted_data = formatter(data) + print( + json.dumps( + formatted_data.to_json(**kwargs), indent=4, cls=CustomJSONEncoder + ), + ) def _apply_format_parameters( diff --git a/src/mlia/core/typing.py b/src/mlia/core/typing.py index 8244ff5..ea334c9 100644 --- a/src/mlia/core/typing.py +++ b/src/mlia/core/typing.py @@ -1,12 +1,7 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Module for custom type hints.""" -from pathlib import Path from typing import Literal -from typing import TextIO -from typing import Union -FileLike = TextIO -PathOrFileLike = Union[str, Path, FileLike] OutputFormat = Literal["plain_text", "json"] diff --git a/src/mlia/core/workflow.py b/src/mlia/core/workflow.py index d862a86..9f8ac83 100644 --- a/src/mlia/core/workflow.py +++ b/src/mlia/core/workflow.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Module for executors. @@ -198,6 +198,7 @@ class DefaultWorkflowExecutor(WorkflowExecutor): self.collectors, self.analyzers, self.producers, + self.context.event_handlers or [], ) if isinstance(comp, ContextMixin) ) diff --git a/src/mlia/target/cortex_a/advisor.py b/src/mlia/target/cortex_a/advisor.py index b649f0d..1249d93 100644 --- a/src/mlia/target/cortex_a/advisor.py +++ b/src/mlia/target/cortex_a/advisor.py @@ -10,7 +10,6 @@ from mlia.core.advice_generation import AdviceProducer from mlia.core.advisor import DefaultInferenceAdvisor from mlia.core.advisor import InferenceAdvisor from mlia.core.common import AdviceCategory -from mlia.core.common import FormattedFilePath from mlia.core.context import Context from mlia.core.context import ExecutionContext from mlia.core.data_analysis import DataAnalyzer @@ -67,12 +66,11 @@ def configure_and_get_cortexa_advisor( context: ExecutionContext, target_profile: str, model: str | Path, - output: FormattedFilePath | None = None, **_extra_args: Any, ) -> InferenceAdvisor: """Create and configure Cortex-A advisor.""" if context.event_handlers is None: - context.event_handlers = [CortexAEventHandler(output)] + context.event_handlers = [CortexAEventHandler()] if context.config_parameters is None: context.config_parameters = _get_config_parameters(model, target_profile) diff --git a/src/mlia/target/cortex_a/handlers.py b/src/mlia/target/cortex_a/handlers.py index d6acde5..1a74da7 100644 --- a/src/mlia/target/cortex_a/handlers.py +++ b/src/mlia/target/cortex_a/handlers.py @@ -5,7 +5,6 @@ from __future__ import annotations import logging -from mlia.core.common import FormattedFilePath from mlia.core.events import CollectedDataEvent from mlia.core.handlers import WorkflowEventsHandler from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo @@ -20,9 +19,9 @@ logger = logging.getLogger(__name__) class CortexAEventHandler(WorkflowEventsHandler, CortexAAdvisorEventHandler): """CLI event handler.""" - def __init__(self, output: FormattedFilePath | None = None) -> None: + def __init__(self) -> None: """Init event handler.""" - super().__init__(cortex_a_formatters, output) + super().__init__(cortex_a_formatters) def on_collected_data(self, event: CollectedDataEvent) -> None: """Handle CollectedDataEvent event.""" diff --git a/src/mlia/target/ethos_u/advisor.py b/src/mlia/target/ethos_u/advisor.py index 640c3e1..ce4e0fc 100644 --- a/src/mlia/target/ethos_u/advisor.py +++ b/src/mlia/target/ethos_u/advisor.py @@ -10,7 +10,6 @@ from mlia.core.advice_generation import AdviceProducer from mlia.core.advisor import DefaultInferenceAdvisor from mlia.core.advisor import InferenceAdvisor from mlia.core.common import AdviceCategory -from mlia.core.common import FormattedFilePath from mlia.core.context import Context from mlia.core.context import ExecutionContext from mlia.core.data_analysis import DataAnalyzer @@ -126,12 +125,11 @@ def configure_and_get_ethosu_advisor( context: ExecutionContext, target_profile: str, model: str | Path, - output: FormattedFilePath | None = None, **extra_args: Any, ) -> InferenceAdvisor: """Create and configure Ethos-U advisor.""" if context.event_handlers is None: - context.event_handlers = [EthosUEventHandler(output)] + context.event_handlers = [EthosUEventHandler()] if context.config_parameters is None: context.config_parameters = _get_config_parameters( diff --git a/src/mlia/target/ethos_u/handlers.py b/src/mlia/target/ethos_u/handlers.py index 91f6015..9873014 100644 --- a/src/mlia/target/ethos_u/handlers.py +++ b/src/mlia/target/ethos_u/handlers.py @@ -6,7 +6,6 @@ from __future__ import annotations import logging from mlia.backend.vela.compat import Operators -from mlia.core.common import FormattedFilePath from mlia.core.events import CollectedDataEvent from mlia.core.handlers import WorkflowEventsHandler from mlia.target.ethos_u.events import EthosUAdvisorEventHandler @@ -21,9 +20,9 @@ logger = logging.getLogger(__name__) class EthosUEventHandler(WorkflowEventsHandler, EthosUAdvisorEventHandler): """CLI event handler.""" - def __init__(self, output: FormattedFilePath | None = None) -> None: + def __init__(self) -> None: """Init event handler.""" - super().__init__(ethos_u_formatters, output) + super().__init__(ethos_u_formatters) def on_collected_data(self, event: CollectedDataEvent) -> None: """Handle CollectedDataEvent event.""" diff --git a/src/mlia/target/tosa/advisor.py b/src/mlia/target/tosa/advisor.py index 0da44db..b60e824 100644 --- a/src/mlia/target/tosa/advisor.py +++ b/src/mlia/target/tosa/advisor.py @@ -10,7 +10,6 @@ from mlia.core.advice_generation import AdviceCategory from mlia.core.advice_generation import AdviceProducer from mlia.core.advisor import DefaultInferenceAdvisor from mlia.core.advisor import InferenceAdvisor -from mlia.core.common import FormattedFilePath from mlia.core.context import Context from mlia.core.context import ExecutionContext from mlia.core.data_analysis import DataAnalyzer @@ -81,12 +80,11 @@ def configure_and_get_tosa_advisor( context: ExecutionContext, target_profile: str, model: str | Path, - output: FormattedFilePath | None = None, **_extra_args: Any, ) -> InferenceAdvisor: """Create and configure TOSA advisor.""" if context.event_handlers is None: - context.event_handlers = [TOSAEventHandler(output)] + context.event_handlers = [TOSAEventHandler()] if context.config_parameters is None: context.config_parameters = _get_config_parameters(model, target_profile) diff --git a/src/mlia/target/tosa/handlers.py b/src/mlia/target/tosa/handlers.py index 7f80f77..f222823 100644 --- a/src/mlia/target/tosa/handlers.py +++ b/src/mlia/target/tosa/handlers.py @@ -7,7 +7,6 @@ from __future__ import annotations import logging from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo -from mlia.core.common import FormattedFilePath from mlia.core.events import CollectedDataEvent from mlia.core.handlers import WorkflowEventsHandler from mlia.target.tosa.events import TOSAAdvisorEventHandler @@ -20,9 +19,9 @@ logger = logging.getLogger(__name__) class TOSAEventHandler(WorkflowEventsHandler, TOSAAdvisorEventHandler): """Event handler for TOSA advisor.""" - def __init__(self, output: FormattedFilePath | None = None) -> None: + def __init__(self) -> None: """Init event handler.""" - super().__init__(tosa_formatters, output) + super().__init__(tosa_formatters) def on_tosa_advisor_started(self, event: TOSAAdvisorStartedEvent) -> None: """Handle TOSAAdvisorStartedEvent event.""" diff --git a/src/mlia/utils/logging.py b/src/mlia/utils/logging.py index 0659dcf..17f6cae 100644 --- a/src/mlia/utils/logging.py +++ b/src/mlia/utils/logging.py @@ -18,6 +18,8 @@ from typing import Generator from typing import Iterable from typing import TextIO +from mlia.utils.console import remove_ascii_codes + class LoggerWriter: """Redirect printed messages to the logger.""" @@ -152,12 +154,21 @@ class LogFilter(logging.Filter): return cls(skip_by_level) +class NoASCIIFormatter(logging.Formatter): + """Custom Formatter for logging into file.""" + + def format(self, record: logging.LogRecord) -> str: + """Overwrite format method to remove ascii codes from record.""" + result = super().format(record) + return remove_ascii_codes(result) + + def create_log_handler( *, file_path: Path | None = None, stream: Any | None = None, log_level: int | None = None, - log_format: str | None = None, + log_format: str | logging.Formatter | None = None, log_filter: logging.Filter | None = None, delay: bool = True, ) -> logging.Handler: @@ -176,7 +187,9 @@ def create_log_handler( handler.setLevel(log_level) if log_format: - handler.setFormatter(logging.Formatter(log_format)) + if isinstance(log_format, str): + log_format = logging.Formatter(log_format) + handler.setFormatter(log_format) if log_filter: handler.addFilter(log_filter) diff --git a/tests/test_api.py b/tests/test_api.py index 0bbc3ae..251d5ac 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -20,7 +20,11 @@ from mlia.target.tosa.advisor import TOSAInferenceAdvisor def test_get_advice_no_target_provided(test_keras_model: Path) -> None: """Test getting advice when no target provided.""" with pytest.raises(Exception, match="Target profile is not provided"): - get_advice(None, test_keras_model, {"compatibility"}) # type: ignore + get_advice( + None, # type:ignore + test_keras_model, + {"compatibility"}, + ) def test_get_advice_wrong_category(test_keras_model: Path) -> None: diff --git a/tests/test_backend_tosa_compat.py b/tests/test_backend_tosa_compat.py index 5a80b4b..0b6eaf5 100644 --- a/tests/test_backend_tosa_compat.py +++ b/tests/test_backend_tosa_compat.py @@ -27,7 +27,7 @@ def replace_get_tosa_checker_with_mock( def test_compatibility_check_should_fail_if_checker_not_available( - monkeypatch: pytest.MonkeyPatch, test_tflite_model: Path + monkeypatch: pytest.MonkeyPatch, test_tflite_model: str | Path ) -> None: """Test that compatibility check should fail if TOSA checker is not available.""" replace_get_tosa_checker_with_mock(monkeypatch, None) @@ -71,7 +71,7 @@ def test_compatibility_check_should_fail_if_checker_not_available( ) def test_get_tosa_compatibility_info( monkeypatch: pytest.MonkeyPatch, - test_tflite_model: Path, + test_tflite_model: str | Path, is_tosa_compatible: bool, operators: Any, exception: Exception | None, diff --git a/tests/test_cli_main.py b/tests/test_cli_main.py index 5a9c0c9..9db5341 100644 --- a/tests/test_cli_main.py +++ b/tests/test_cli_main.py @@ -5,7 +5,6 @@ from __future__ import annotations import argparse from functools import wraps -from pathlib import Path from typing import Any from typing import Callable from unittest.mock import ANY @@ -122,8 +121,6 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable: model="sample_model.tflite", compatibility=False, performance=False, - output=None, - json=False, backend=None, ), ], @@ -135,8 +132,6 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable: model="sample_model.tflite", compatibility=False, performance=False, - output=None, - json=False, backend=None, ), ], @@ -153,8 +148,6 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable: ctx=ANY, target_profile="ethos-u55-256", model="sample_model.h5", - output=None, - json=False, compatibility=True, performance=True, backend=None, @@ -167,9 +160,6 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable: "--performance", "--target-profile", "ethos-u55-256", - "--output", - "result.json", - "--json", ], call( ctx=ANY, @@ -177,8 +167,6 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable: model="sample_model.h5", performance=True, compatibility=False, - output=Path("result.json"), - json=True, backend=None, ), ], @@ -196,8 +184,6 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable: model="sample_model.h5", compatibility=False, performance=True, - output=None, - json=False, backend=None, ), ], @@ -218,8 +204,6 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable: clustering=True, pruning_target=None, clustering_target=None, - output=None, - json=False, backend=None, ), ], @@ -244,8 +228,6 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable: clustering=True, pruning_target=0.5, clustering_target=32, - output=None, - json=False, backend=None, ), ], @@ -267,8 +249,6 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable: clustering=False, pruning_target=None, clustering_target=None, - output=None, - json=False, backend=["some_backend"], ), ], @@ -286,8 +266,6 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable: model="sample_model.h5", compatibility=True, performance=False, - output=None, - json=False, backend=None, ), ], diff --git a/tests/test_cli_options.py b/tests/test_cli_options.py index a889a93..94c3111 100644 --- a/tests/test_cli_options.py +++ b/tests/test_cli_options.py @@ -5,16 +5,14 @@ from __future__ import annotations import argparse from contextlib import ExitStack as does_not_raise -from pathlib import Path from typing import Any import pytest -from mlia.cli.options import add_output_options +from mlia.cli.options import get_output_format from mlia.cli.options import get_target_profile_opts from mlia.cli.options import parse_optimization_parameters -from mlia.cli.options import parse_output_parameters -from mlia.core.common import FormattedFilePath +from mlia.core.typing import OutputFormat @pytest.mark.parametrize( @@ -164,54 +162,24 @@ def test_get_target_opts(args: dict | None, expected_opts: list[str]) -> None: @pytest.mark.parametrize( - "output_parameters, expected_path", - [ - [["--output", "report.json"], "report.json"], - [["--output", "REPORT.JSON"], "REPORT.JSON"], - [["--output", "some_folder/report.json"], "some_folder/report.json"], - ], -) -def test_output_options(output_parameters: list[str], expected_path: str) -> None: - """Test output options resolving.""" - parser = argparse.ArgumentParser() - add_output_options(parser) - - args = parser.parse_args(output_parameters) - assert str(args.output) == expected_path - - -@pytest.mark.parametrize( - "path, json, expected_error, output", + "args, expected_output_format", [ [ - None, - True, - pytest.raises( - argparse.ArgumentError, - match=r"To enable JSON output you need to specify the output path. " - r"\(e.g. --output out.json --json\)", - ), - None, + {}, + "plain_text", ], - [None, False, does_not_raise(), None], [ - Path("test_path"), - False, - does_not_raise(), - FormattedFilePath(Path("test_path"), "plain_text"), + {"json": True}, + "json", ], [ - Path("test_path"), - True, - does_not_raise(), - FormattedFilePath(Path("test_path"), "json"), + {"json": False}, + "plain_text", ], ], ) -def test_parse_output_parameters( - path: Path | None, json: bool, expected_error: Any, output: FormattedFilePath | None -) -> None: - """Test parsing for output parameters.""" - with expected_error: - formatted_output = parse_output_parameters(path, json) - assert formatted_output == output +def test_get_output_format(args: dict, expected_output_format: OutputFormat) -> None: + """Test get_output_format function.""" + arguments = argparse.Namespace(**args) + output_format = get_output_format(arguments) + assert output_format == expected_output_format diff --git a/tests/test_core_context.py b/tests/test_core_context.py index dcdbef3..0e7145f 100644 --- a/tests/test_core_context.py +++ b/tests/test_core_context.py @@ -58,6 +58,7 @@ def test_execution_context(tmpdir: str) -> None: verbose=True, logs_dir="logs_directory", models_dir="models_directory", + output_format="json", ) assert context.advice_category == category @@ -68,12 +69,14 @@ def test_execution_context(tmpdir: str) -> None: expected_model_path = Path(tmpdir) / "models_directory/sample.model" assert context.get_model_path("sample.model") == expected_model_path assert context.verbose is True + assert context.output_format == "json" assert str(context) == ( f"ExecutionContext: " f"working_dir={tmpdir}, " "advice_category={'COMPATIBILITY'}, " "config_parameters={'param': 'value'}, " - "verbose=True" + "verbose=True, " + "output_format=json" ) context_with_default_params = ExecutionContext(working_dir=tmpdir) @@ -88,11 +91,13 @@ def test_execution_context(tmpdir: str) -> None: default_model_path = context_with_default_params.get_model_path("sample.model") expected_default_model_path = Path(tmpdir) / "models/sample.model" assert default_model_path == expected_default_model_path + assert context_with_default_params.output_format == "plain_text" expected_str = ( f"ExecutionContext: working_dir={tmpdir}, " "advice_category={'COMPATIBILITY'}, " "config_parameters=None, " - "verbose=False" + "verbose=False, " + "output_format=plain_text" ) assert str(context_with_default_params) == expected_str diff --git a/tests/test_cli_logging.py b/tests/test_core_logging.py index 1e2cc85..e021e26 100644 --- a/tests/test_cli_logging.py +++ b/tests/test_core_logging.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the module cli.logging.""" from __future__ import annotations @@ -8,7 +8,7 @@ from pathlib import Path import pytest -from mlia.cli.logging import setup_logging +from mlia.core.logging import setup_logging from tests.utils.logging import clear_loggers @@ -33,20 +33,21 @@ def teardown_function() -> None: ( None, True, - """mlia.backend.manager - backends debug -cli info -mlia.cli - cli debug + """mlia.backend.manager - DEBUG - backends debug +mlia.cli - INFO - cli info +mlia.cli - DEBUG - cli debug """, None, ), ( "logs", True, - """mlia.backend.manager - backends debug -cli info -mlia.cli - cli debug + """mlia.backend.manager - DEBUG - backends debug +mlia.cli - INFO - cli info +mlia.cli - DEBUG - cli debug """, """mlia.backend.manager - DEBUG - backends debug +mlia.cli - INFO - cli info mlia.cli - DEBUG - cli debug """, ), diff --git a/tests/test_core_reporting.py b/tests/test_core_reporting.py index 71eaf85..7a68d4b 100644 --- a/tests/test_core_reporting.py +++ b/tests/test_core_reporting.py @@ -3,9 +3,13 @@ """Tests for reporting module.""" from __future__ import annotations -import io import json from enum import Enum +from unittest.mock import ANY +from unittest.mock import call +from unittest.mock import MagicMock +from unittest.mock import Mock +from unittest.mock import patch import numpy as np import pytest @@ -14,13 +18,15 @@ from mlia.core.reporting import BytesCell from mlia.core.reporting import Cell from mlia.core.reporting import ClockCell from mlia.core.reporting import Column +from mlia.core.reporting import CustomJSONEncoder from mlia.core.reporting import CyclesCell from mlia.core.reporting import Format -from mlia.core.reporting import json_reporter +from mlia.core.reporting import JSONReporter from mlia.core.reporting import NestedReport from mlia.core.reporting import ReportItem from mlia.core.reporting import SingleRow from mlia.core.reporting import Table +from mlia.core.reporting import TextReporter from mlia.utils.console import remove_ascii_codes @@ -364,10 +370,9 @@ def test_custom_json_serialization() -> None: alias="sample_table", ) - output = io.StringIO() - json_reporter(table, output) + output = json.dumps(table.to_json(), indent=4, cls=CustomJSONEncoder) - assert json.loads(output.getvalue()) == { + assert json.loads(output) == { "sample_table": [ {"column1": "value1"}, {"column1": 10.0}, @@ -375,3 +380,93 @@ def test_custom_json_serialization() -> None: {"column1": 10}, ] } + + +class TestTextReporter: + """Test TextReporter methods.""" + + def test_text_reporter(self) -> None: + """Test TextReporter.""" + format_resolver = MagicMock() + reporter = TextReporter(format_resolver) + assert reporter.output_format == "plain_text" + + def test_submit(self) -> None: + """Test TextReporter submit.""" + format_resolver = MagicMock() + reporter = TextReporter(format_resolver) + reporter.submit("test") + assert reporter.data == [("test", ANY)] + + reporter.submit("test2", delay_print=True) + assert reporter.data == [("test", ANY), ("test2", ANY)] + assert reporter.delayed == [("test2", ANY)] + + def test_print_delayed(self) -> None: + """Test TextReporter print_delayed.""" + with patch( + "mlia.core.reporting.TextReporter.produce_report" + ) as mock_produce_report: + format_resolver = MagicMock() + reporter = TextReporter(format_resolver) + reporter.submit("test", delay_print=True) + reporter.print_delayed() + assert reporter.data == [("test", ANY)] + assert not reporter.delayed + mock_produce_report.assert_called() + + def test_produce_report(self) -> None: + """Test TextReporter produce_report.""" + format_resolver = MagicMock() + reporter = TextReporter(format_resolver) + + with patch("mlia.core.reporting.logger") as mock_logger: + mock_formatter = MagicMock() + reporter.produce_report("test", mock_formatter) + mock_formatter.assert_has_calls([call("test"), call().to_plain_text()]) + mock_logger.info.assert_called() + + +class TestJSONReporter: + """Test JSONReporter methods.""" + + def test_text_reporter(self) -> None: + """Test JSONReporter.""" + format_resolver = MagicMock() + reporter = JSONReporter(format_resolver) + assert reporter.output_format == "json" + + def test_submit(self) -> None: + """Test JSONReporter submit.""" + format_resolver = MagicMock() + reporter = JSONReporter(format_resolver) + reporter.submit("test") + assert reporter.data == [("test", ANY)] + + reporter.submit("test2") + assert reporter.data == [("test", ANY), ("test2", ANY)] + + def test_generate_report(self) -> None: + """Test JSONReporter generate_report.""" + format_resolver = MagicMock() + reporter = JSONReporter(format_resolver) + reporter.submit("test") + + with patch( + "mlia.core.reporting.JSONReporter.produce_report" + ) as mock_produce_report: + reporter.generate_report() + mock_produce_report.assert_called() + + @patch("builtins.print") + def test_produce_report(self, mock_print: Mock) -> None: + """Test JSONReporter produce_report.""" + format_resolver = MagicMock() + reporter = JSONReporter(format_resolver) + + with patch("json.dumps") as mock_dumps: + mock_formatter = MagicMock() + reporter.produce_report("test", mock_formatter) + mock_formatter.assert_has_calls([call("test"), call().to_json()]) + mock_dumps.assert_called() + mock_print.assert_called() diff --git a/tests/test_target_ethos_u_reporters.py b/tests/test_target_ethos_u_reporters.py index 7f372bf..ee7ea52 100644 --- a/tests/test_target_ethos_u_reporters.py +++ b/tests/test_target_ethos_u_reporters.py @@ -1,106 +1,21 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for reports module.""" from __future__ import annotations -import json -import sys -from contextlib import ExitStack as doesnt_raise -from pathlib import Path -from typing import Any -from typing import Callable -from typing import Literal - import pytest from mlia.backend.vela.compat import NpuSupported from mlia.backend.vela.compat import Operator -from mlia.backend.vela.compat import Operators -from mlia.core.reporting import get_reporter -from mlia.core.reporting import produce_report from mlia.core.reporting import Report -from mlia.core.reporting import Reporter from mlia.core.reporting import Table from mlia.target.ethos_u.config import EthosUConfiguration -from mlia.target.ethos_u.performance import MemoryUsage -from mlia.target.ethos_u.performance import NPUCycles -from mlia.target.ethos_u.performance import PerformanceMetrics -from mlia.target.ethos_u.reporters import ethos_u_formatters from mlia.target.ethos_u.reporters import report_device_details from mlia.target.ethos_u.reporters import report_operators -from mlia.target.ethos_u.reporters import report_perf_metrics from mlia.utils.console import remove_ascii_codes @pytest.mark.parametrize( - "data, formatters", - [ - ( - [Operator("test_operator", "test_type", NpuSupported(False, []))], - [report_operators], - ), - ( - PerformanceMetrics( - EthosUConfiguration("ethos-u55-256"), - NPUCycles(0, 0, 0, 0, 0, 0), - MemoryUsage(0, 0, 0, 0, 0), - ), - [report_perf_metrics], - ), - ], -) -@pytest.mark.parametrize( - "fmt, output, expected_error", - [ - [ - "unknown_format", - sys.stdout, - pytest.raises(Exception, match="Unknown format unknown_format"), - ], - [ - "plain_text", - sys.stdout, - doesnt_raise(), - ], - [ - "json", - sys.stdout, - doesnt_raise(), - ], - [ - "plain_text", - "report.txt", - doesnt_raise(), - ], - [ - "json", - "report.json", - doesnt_raise(), - ], - ], -) -def test_report( - data: Any, - formatters: list[Callable], - fmt: Literal["plain_text", "json"], - output: Any, - expected_error: Any, - tmp_path: Path, -) -> None: - """Test report function.""" - if is_file := isinstance(output, str): - output = tmp_path / output - - for formatter in formatters: - with expected_error: - produce_report(data, formatter, fmt, output) - - if is_file: - assert output.is_file() - assert output.stat().st_size > 0 - - -@pytest.mark.parametrize( "ops, expected_plain_text, expected_json_dict", [ ( @@ -314,40 +229,3 @@ def test_report_device_details( json_dict = report.to_json() assert json_dict == expected_json_dict - - -def test_get_reporter(tmp_path: Path) -> None: - """Test reporter functionality.""" - ops = Operators( - [ - Operator( - "npu_supported", - "op_type", - NpuSupported(True, []), - ), - ] - ) - - output = tmp_path / "output.json" - with get_reporter("json", output, ethos_u_formatters) as reporter: - assert isinstance(reporter, Reporter) - - with pytest.raises( - Exception, match="Unable to find appropriate formatter for some_data" - ): - reporter.submit("some_data") - - reporter.submit(ops) - - with open(output, encoding="utf-8") as file: - json_data = json.load(file) - - assert json_data == { - "operators_stats": [ - { - "npu_unsupported_ratio": 0.0, - "num_of_npu_supported_operators": 1, - "num_of_operators": 1, - } - ] - } diff --git a/tests_e2e/test_e2e.py b/tests_e2e/test_e2e.py index 35fd707..beddaed 100644 --- a/tests_e2e/test_e2e.py +++ b/tests_e2e/test_e2e.py @@ -16,6 +16,7 @@ from pathlib import Path from typing import Any from typing import Generator from typing import Iterable +from typing import Sequence import pytest @@ -75,30 +76,48 @@ class ExecutionConfiguration: ) -def launch_and_wait(cmd: list[str], stdin: Any | None = None) -> None: +def launch_and_wait( + cmd: list[str], + output_file: Path | None = None, + print_output: bool = True, + stdin: Any | None = None, +) -> None: """Launch command and wait for the completion.""" with subprocess.Popen( # nosec cmd, stdin=stdin, stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, # redirect command stderr to stdout + stderr=subprocess.PIPE, universal_newlines=True, bufsize=1, ) as process: - if process.stdout is None: + # Process stdout + if process.stdout: + # Store the output in a variable + output = process.stdout.read() + # Save the output into a file + if output_file: + output_file.write_text(output) + print(f"Output saved to {output_file}") + # Show the output to stdout + if print_output: + print(output) + else: raise Exception("Unable to get process output") - # redirect output of the process into current process stdout - for line in process.stdout: - print(line, end="") - + # Wait for the process to terminate process.wait() if (exit_code := process.poll()) != 0: raise Exception(f"Command failed with exit_code {exit_code}") -def run_command(cmd: list[str], cmd_input: str | None = None) -> None: +def run_command( + cmd: list[str], + output_file: Path | None = None, + print_output: bool = True, + cmd_input: str | None = None, +) -> None: """Run command.""" print(f"Run command: {' '.join(cmd)}") @@ -118,7 +137,7 @@ def run_command(cmd: list[str], cmd_input: str | None = None) -> None: cmd_input_file.write(cmd_input) cmd_input_file.seek(0) - launch_and_wait(cmd, cmd_input_file) + launch_and_wait(cmd, output_file, print_output, cmd_input_file) def get_config_file() -> Path: @@ -209,10 +228,15 @@ def get_config_content(config_file: Path) -> Any: executions = json_data.get("executions", []) assert is_list_of(executions, dict), "List of the dictionaries expected" - return executions + settings = json_data.get("settings", {}) + assert isinstance(settings, dict) + + return settings, executions -def get_all_commands_combinations(executions: Any) -> Generator[list[str], None, None]: +def get_all_commands_combinations( + executions: Any, +) -> Generator[dict[str, Sequence[str]], None, None]: """Return all commands combinations.""" exec_configs = ( ExecutionConfiguration.from_dict(exec_info) for exec_info in executions @@ -221,13 +245,12 @@ def get_all_commands_combinations(executions: Any) -> Generator[list[str], None, parser = get_args_parser() for exec_config in exec_configs: for command_combination in exec_config.all_combinations: - for idx, param in enumerate(command_combination): - if "{model_name}" in param: - args = parser.parse_args(command_combination) - model_name = Path(args.model).stem - param = param.replace("{model_name}", model_name) - command_combination[idx] = param - yield command_combination + args = parser.parse_args(command_combination) + model_name = Path(args.model).stem + yield { + "model_name": model_name, + "command_combination": command_combination, + } def check_args(args: list[str], no_skip: bool) -> None: @@ -249,21 +272,31 @@ def check_args(args: list[str], no_skip: bool) -> None: pytest.skip(f"Missing backend(s): {','.join(missing_backends)}") -def get_execution_definitions() -> Generator[list[str], None, None]: +def get_execution_definitions( + executions: dict, +) -> Generator[dict[str, Sequence[str]], None, None]: """Collect all execution definitions from configuration file.""" - config_file = get_config_file() - executions = get_config_content(config_file) - executions = resolve_parameters(executions) - - return get_all_commands_combinations(executions) + resolved_executions = resolve_parameters(executions) + return get_all_commands_combinations(resolved_executions) class TestEndToEnd: """End to end command tests.""" - @pytest.mark.parametrize("command", get_execution_definitions(), ids=str) - def test_e2e(self, command: list[str], no_skip: bool) -> None: + configuration_file = get_config_file() + settings, executions = get_config_content(configuration_file) + + @pytest.mark.parametrize( + "command_data", get_execution_definitions(executions), ids=str + ) + def test_e2e(self, command_data: dict[str, list[str]], no_skip: bool) -> None: """Test MLIA command with the provided parameters.""" + command = command_data["command_combination"] + model_name = command_data["model_name"] check_args(command, no_skip) mlia_command = ["mlia", *command] - run_command(mlia_command) + print_output = self.settings.get("print_output", True) + output_file = self.settings.get("output_file", None) + if output_file: + output_file = Path(output_file.replace("{model_name}", model_name)) + run_command(mlia_command, output_file, print_output) |