author      Diego Russo <diego.russo@arm.com>    2023-02-02 22:04:05 +0000
committer   Diego Russo <diego.russo@arm.com>    2023-02-03 17:48:11 +0000
commit      f1eaff3c9790464bed3183ff76555cf815166f50 (patch)
tree        bba288bb051925a692e1e1f998f7cc48df755df6
parent      d1b2374cda6811a93d1174400fc2eecd7100a8c3 (diff)
download    mlia-f1eaff3c9790464bed3183ff76555cf815166f50.tar.gz
MLIA-782 Remove --output parameter
* Remove --output parameter from the argument parser
* Remove FormattedFilePath class and its usages across the codebase
* Move logging module from cli to core
* The output format is now injected into the execution context and used across MLIA
* Depending on the output format, TextReporter and JSONReporter have been created and are used accordingly
* All output to standard output and/or the log file is driven via the logging module: the only case where print() is used is when the --json parameter is specified. This is needed because all other output (including that of third party applications) must be disabled, otherwise it might corrupt the JSON output on standard output.
* Debug information is logged into the log file and printed to stdout when the output format is plain_text
* Update E2E tests and config to cope with the new mechanism of outputting JSON data to standard output

Change-Id: I4395800b0b1af4d24406a828d780bdeef98cd413
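The core of the new mechanism is the reporter selection done by WorkflowEventsHandler when execution starts. A minimal sketch of that logic, assuming the mlia.core.reporting classes added in this change (make_reporter is a hypothetical helper used only for illustration, not part of the codebase):

    from mlia.core.reporting import JSONReporter, TextReporter

    def make_reporter(context, formatter_resolver):
        """Hypothetical helper mirroring WorkflowEventsHandler.on_execution_started."""
        if context.output_format == "json":
            # JSON reports are printed straight to stdout via print();
            # console logging is limited to errors on stderr so stdout stays valid JSON.
            return JSONReporter(formatter_resolver)
        # Plain text reports go through the logging system (stdout handler).
        return TextReporter(formatter_resolver)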
-rw-r--r--src/mlia/api.py13
-rw-r--r--src/mlia/backend/tosa_checker/compat.py6
-rw-r--r--src/mlia/cli/commands.py24
-rw-r--r--src/mlia/cli/main.py9
-rw-r--r--src/mlia/cli/options.py35
-rw-r--r--src/mlia/core/common.py33
-rw-r--r--src/mlia/core/context.py16
-rw-r--r--src/mlia/core/handlers.py31
-rw-r--r--src/mlia/core/logging.py (renamed from src/mlia/cli/logging.py)61
-rw-r--r--src/mlia/core/reporting.py166
-rw-r--r--src/mlia/core/typing.py7
-rw-r--r--src/mlia/core/workflow.py3
-rw-r--r--src/mlia/target/cortex_a/advisor.py4
-rw-r--r--src/mlia/target/cortex_a/handlers.py5
-rw-r--r--src/mlia/target/ethos_u/advisor.py4
-rw-r--r--src/mlia/target/ethos_u/handlers.py5
-rw-r--r--src/mlia/target/tosa/advisor.py4
-rw-r--r--src/mlia/target/tosa/handlers.py5
-rw-r--r--src/mlia/utils/logging.py17
-rw-r--r--tests/test_api.py6
-rw-r--r--tests/test_backend_tosa_compat.py4
-rw-r--r--tests/test_cli_main.py22
-rw-r--r--tests/test_cli_options.py60
-rw-r--r--tests/test_core_context.py9
-rw-r--r--tests/test_core_logging.py (renamed from tests/test_cli_logging.py)17
-rw-r--r--tests/test_core_reporting.py105
-rw-r--r--tests/test_target_ethos_u_reporters.py124
-rw-r--r--tests_e2e/test_e2e.py87
28 files changed, 394 insertions, 488 deletions
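For API users, a report destination is no longer passed in; the format is chosen on the execution context and JSON goes to standard output. A hedged usage sketch based on the updated get_advice signature (target profile and model path are illustrative):

    from mlia.api import get_advice
    from mlia.core.context import ExecutionContext
    from mlia.core.logging import setup_logging

    # The output format now lives on the execution context ("plain_text" or "json").
    ctx = ExecutionContext(output_format="json")
    setup_logging(ctx.logs_path, ctx.verbose, ctx.output_format)

    # The JSON report is printed to stdout by JSONReporter;
    # redirect stdout if a report file is still needed.
    get_advice("ethos-u55-256", "path/to/the/model", {"performance"}, context=ctx)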
diff --git a/src/mlia/api.py b/src/mlia/api.py
index 2cabf37..8105276 100644
--- a/src/mlia/api.py
+++ b/src/mlia/api.py
@@ -9,7 +9,6 @@ from typing import Any
from mlia.core.advisor import InferenceAdvisor
from mlia.core.common import AdviceCategory
-from mlia.core.common import FormattedFilePath
from mlia.core.context import ExecutionContext
from mlia.target.cortex_a.advisor import configure_and_get_cortexa_advisor
from mlia.target.ethos_u.advisor import configure_and_get_ethosu_advisor
@@ -24,7 +23,6 @@ def get_advice(
model: str | Path,
category: set[str],
optimization_targets: list[dict[str, Any]] | None = None,
- output: FormattedFilePath | None = None,
context: ExecutionContext | None = None,
backends: list[str] | None = None,
) -> None:
@@ -42,8 +40,6 @@ def get_advice(
category "compatibility" is used by default.
:param optimization_targets: optional model optimization targets that
could be used for generating advice in "optimization" category.
- :param output: path to the report file. If provided, MLIA will save
- report in this location.
:param context: optional parameter which represents execution context,
could be used for advanced use cases
:param backends: A list of backends that should be used for the given
@@ -57,11 +53,9 @@ def get_advice(
>>> get_advice("ethos-u55-256", "path/to/the/model",
{"optimization", "compatibility"})
- Getting the advice for the category "performance" and save result report in file
- "report.json"
+ Getting the advice for the category "performance".
- >>> get_advice("ethos-u55-256", "path/to/the/model", {"performance"},
- output=FormattedFilePath("report.json")
+ >>> get_advice("ethos-u55-256", "path/to/the/model", {"performance"})
"""
advice_category = AdviceCategory.from_string(category)
@@ -76,7 +70,6 @@ def get_advice(
context,
target_profile,
model,
- output,
optimization_targets=optimization_targets,
backends=backends,
)
@@ -88,7 +81,6 @@ def get_advisor(
context: ExecutionContext,
target_profile: str,
model: str | Path,
- output: FormattedFilePath | None = None,
**extra_args: Any,
) -> InferenceAdvisor:
"""Find appropriate advisor for the target."""
@@ -109,6 +101,5 @@ def get_advisor(
context,
target_profile,
model,
- output,
**extra_args,
)
diff --git a/src/mlia/backend/tosa_checker/compat.py b/src/mlia/backend/tosa_checker/compat.py
index 81f3015..1c410d3 100644
--- a/src/mlia/backend/tosa_checker/compat.py
+++ b/src/mlia/backend/tosa_checker/compat.py
@@ -5,12 +5,12 @@ from __future__ import annotations
import sys
from dataclasses import dataclass
+from pathlib import Path
from typing import Any
from typing import cast
from typing import Protocol
from mlia.backend.errors import BackendUnavailableError
-from mlia.core.typing import PathOrFileLike
from mlia.utils.logging import capture_raw_output
@@ -45,7 +45,7 @@ class TOSACompatibilityInfo:
def get_tosa_compatibility_info(
- tflite_model_path: PathOrFileLike,
+ tflite_model_path: str | Path,
) -> TOSACompatibilityInfo:
"""Return list of the operators."""
# Capture the possible exception in running get_tosa_checker
@@ -100,7 +100,7 @@ def get_tosa_compatibility_info(
)
-def get_tosa_checker(tflite_model_path: PathOrFileLike) -> TOSAChecker | None:
+def get_tosa_checker(tflite_model_path: str | Path) -> TOSAChecker | None:
"""Return instance of the TOSA checker."""
try:
import tosa_checker as tc # pylint: disable=import-outside-toplevel
diff --git a/src/mlia/cli/commands.py b/src/mlia/cli/commands.py
index d2242ba..c17d571 100644
--- a/src/mlia/cli/commands.py
+++ b/src/mlia/cli/commands.py
@@ -7,10 +7,10 @@ functionality.
Before running them from scripts 'logging' module should
be configured. Function 'setup_logging' from module
-'mli.cli.logging' could be used for that, e.g.
+'mli.core.logging' could be used for that, e.g.
>>> from mlia.api import ExecutionContext
->>> from mlia.cli.logging import setup_logging
+>>> from mlia.core.logging import setup_logging
>>> setup_logging(verbose=True)
>>> import mlia.cli.commands as mlia
>>> mlia.check(ExecutionContext(), "ethos-u55-256",
@@ -27,7 +27,6 @@ from mlia.cli.command_validators import validate_backend
from mlia.cli.command_validators import validate_check_target_profile
from mlia.cli.config import get_installation_manager
from mlia.cli.options import parse_optimization_parameters
-from mlia.cli.options import parse_output_parameters
from mlia.utils.console import create_section_header
logger = logging.getLogger(__name__)
@@ -41,8 +40,6 @@ def check(
model: str | None = None,
compatibility: bool = False,
performance: bool = False,
- output: Path | None = None,
- json: bool = False,
backend: list[str] | None = None,
) -> None:
"""Generate a full report on the input model.
@@ -61,7 +58,6 @@ def check(
:param model: path to the Keras model
:param compatibility: flag that identifies whether to run compatibility checks
:param performance: flag that identifies whether to run performance checks
- :param output: path to the file where the report will be saved
:param backend: list of the backends to use for evaluation
Example:
@@ -69,18 +65,15 @@ def check(
and operator compatibility.
>>> from mlia.api import ExecutionContext
- >>> from mlia.cli.logging import setup_logging
+ >>> from mlia.core.logging import setup_logging
>>> setup_logging()
>>> from mlia.cli.commands import check
>>> check(ExecutionContext(), "ethos-u55-256",
- "model.h5", compatibility=True, performance=True,
- output="report.json")
+ "model.h5", compatibility=True, performance=True)
"""
if not model:
raise Exception("Model is not provided")
- formatted_output = parse_output_parameters(output, json)
-
# Set category based on checks to perform (i.e. "compatibility" and/or
# "performance").
# If no check type is specified, "compatibility" is the default category.
@@ -98,7 +91,6 @@ def check(
target_profile,
model,
category,
- output=formatted_output,
context=ctx,
backends=validated_backend,
)
@@ -113,8 +105,6 @@ def optimize( # pylint: disable=too-many-arguments
pruning_target: float | None,
clustering_target: int | None,
layers_to_optimize: list[str] | None = None,
- output: Path | None = None,
- json: bool = False,
backend: list[str] | None = None,
) -> None:
"""Show the performance improvements (if any) after applying the optimizations.
@@ -133,15 +123,13 @@ def optimize( # pylint: disable=too-many-arguments
:param pruning_target: pruning optimization target
:param layers_to_optimize: list of the layers of the model which should be
optimized, if None then all layers are used
- :param output: path to the file where the report will be saved
- :param json: set the output format to json
:param backend: list of the backends to use for evaluation
Example:
Run command for the target profile ethos-u55-256 and
the provided TensorFlow Lite model and print report on the standard output
- >>> from mlia.cli.logging import setup_logging
+ >>> from mlia.core.logging import setup_logging
>>> from mlia.api import ExecutionContext
>>> setup_logging()
>>> from mlia.cli.commands import optimize
@@ -161,7 +149,6 @@ def optimize( # pylint: disable=too-many-arguments
)
)
- formatted_output = parse_output_parameters(output, json)
validated_backend = validate_backend(target_profile, backend)
get_advice(
@@ -169,7 +156,6 @@ def optimize( # pylint: disable=too-many-arguments
model,
{"optimization"},
optimization_targets=opt_params,
- output=formatted_output,
context=ctx,
backends=validated_backend,
)
diff --git a/src/mlia/cli/main.py b/src/mlia/cli/main.py
index 1102d45..7ce7dc9 100644
--- a/src/mlia/cli/main.py
+++ b/src/mlia/cli/main.py
@@ -19,7 +19,6 @@ from mlia.cli.commands import check
from mlia.cli.commands import optimize
from mlia.cli.common import CommandInfo
from mlia.cli.helpers import CLIActionResolver
-from mlia.cli.logging import setup_logging
from mlia.cli.options import add_backend_install_options
from mlia.cli.options import add_backend_options
from mlia.cli.options import add_backend_uninstall_options
@@ -30,9 +29,11 @@ from mlia.cli.options import add_model_options
from mlia.cli.options import add_multi_optimization_options
from mlia.cli.options import add_output_options
from mlia.cli.options import add_target_options
+from mlia.cli.options import get_output_format
from mlia.core.context import ExecutionContext
from mlia.core.errors import ConfigurationError
from mlia.core.errors import InternalError
+from mlia.core.logging import setup_logging
from mlia.target.registry import registry as target_registry
@@ -162,10 +163,13 @@ def setup_context(
ctx = ExecutionContext(
verbose="debug" in args and args.debug,
action_resolver=CLIActionResolver(vars(args)),
+ output_format=get_output_format(args),
)
+ setup_logging(ctx.logs_path, ctx.verbose, ctx.output_format)
+
# these parameters should not be passed into command function
- skipped_params = ["func", "command", "debug"]
+ skipped_params = ["func", "command", "debug", "json"]
# pass these parameters only if command expects them
expected_params = [context_var_name]
@@ -186,7 +190,6 @@ def setup_context(
def run_command(args: argparse.Namespace) -> int:
"""Run command."""
ctx, func_args = setup_context(args)
- setup_logging(ctx.logs_path, ctx.verbose)
logger.debug(
"*** This is the beginning of the command '%s' execution ***", args.command
diff --git a/src/mlia/cli/options.py b/src/mlia/cli/options.py
index 5aca3b3..e01f107 100644
--- a/src/mlia/cli/options.py
+++ b/src/mlia/cli/options.py
@@ -12,7 +12,7 @@ from mlia.cli.config import DEFAULT_CLUSTERING_TARGET
from mlia.cli.config import DEFAULT_PRUNING_TARGET
from mlia.cli.config import get_available_backends
from mlia.cli.config import is_corstone_backend
-from mlia.core.common import FormattedFilePath
+from mlia.core.typing import OutputFormat
from mlia.utils.filesystem import get_supported_profile_names
@@ -89,16 +89,9 @@ def add_output_options(parser: argparse.ArgumentParser) -> None:
"""Add output specific options."""
output_group = parser.add_argument_group("output options")
output_group.add_argument(
- "-o",
- "--output",
- type=Path,
- help=("Name of the file where the report will be saved."),
- )
-
- output_group.add_argument(
"--json",
action="store_true",
- help=("Format to use for the output (requires --output argument to be set)."),
+ help=("Print the output in JSON format."),
)
@@ -209,22 +202,6 @@ def add_backend_options(
)
-def parse_output_parameters(path: Path | None, json: bool) -> FormattedFilePath | None:
- """Parse and return path and file format as FormattedFilePath."""
- if not path and json:
- raise argparse.ArgumentError(
- None,
- "To enable JSON output you need to specify the output path. "
- "(e.g. --output out.json --json)",
- )
- if not path:
- return None
- if json:
- return FormattedFilePath(path, "json")
-
- return FormattedFilePath(path, "plain_text")
-
-
def parse_optimization_parameters(
pruning: bool = False,
clustering: bool = False,
@@ -301,3 +278,11 @@ def get_target_profile_opts(device_args: dict | None) -> list[str]:
for name in non_default
for item in construct_param(params_name[name], device_args[name])
]
+
+
+def get_output_format(args: argparse.Namespace) -> OutputFormat:
+ """Return the OutputFormat depending on the CLI flags."""
+ output_format: OutputFormat = "plain_text"
+ if "json" in args and args.json:
+ output_format = "json"
+ return output_format
diff --git a/src/mlia/core/common.py b/src/mlia/core/common.py
index 53df001..baaed50 100644
--- a/src/mlia/core/common.py
+++ b/src/mlia/core/common.py
@@ -13,9 +13,6 @@ from enum import auto
from enum import Flag
from typing import Any
-from mlia.core.typing import OutputFormat
-from mlia.core.typing import PathOrFileLike
-
# This type is used as type alias for the items which are being passed around
# in advisor workflow. There are no restrictions on the type of the
# object. This alias used only to emphasize the nature of the input/output
@@ -23,36 +20,6 @@ from mlia.core.typing import PathOrFileLike
DataItem = Any
-class FormattedFilePath:
- """Class used to keep track of the format that a path points to."""
-
- def __init__(self, path: PathOrFileLike, fmt: OutputFormat = "plain_text") -> None:
- """Init FormattedFilePath."""
- self._path = path
- self._fmt = fmt
-
- @property
- def fmt(self) -> OutputFormat:
- """Return file format."""
- return self._fmt
-
- @property
- def path(self) -> PathOrFileLike:
- """Return file path."""
- return self._path
-
- def __eq__(self, other: object) -> bool:
- """Check for equality with other objects."""
- if isinstance(other, FormattedFilePath):
- return other.fmt == self.fmt and other.path == self.path
-
- return False
-
- def __repr__(self) -> str:
- """Represent object."""
- return f"FormattedFilePath {self.path=}, {self.fmt=}"
-
-
class AdviceCategory(Flag):
"""Advice category.
diff --git a/src/mlia/core/context.py b/src/mlia/core/context.py
index 94aa885..f8442a3 100644
--- a/src/mlia/core/context.py
+++ b/src/mlia/core/context.py
@@ -23,6 +23,7 @@ from mlia.core.events import EventHandler
from mlia.core.events import EventPublisher
from mlia.core.helpers import ActionResolver
from mlia.core.helpers import APIActionResolver
+from mlia.core.typing import OutputFormat
logger = logging.getLogger(__name__)
@@ -68,6 +69,11 @@ class Context(ABC):
def action_resolver(self) -> ActionResolver:
"""Return action resolver."""
+ @property
+ @abstractmethod
+ def output_format(self) -> OutputFormat:
+ """Return the output format."""
+
@abstractmethod
def update(
self,
@@ -106,6 +112,7 @@ class ExecutionContext(Context):
logs_dir: str = "logs",
models_dir: str = "models",
action_resolver: ActionResolver | None = None,
+ output_format: OutputFormat = "plain_text",
) -> None:
"""Init execution context.
@@ -139,6 +146,7 @@ class ExecutionContext(Context):
self.logs_dir = logs_dir
self.models_dir = models_dir
self._action_resolver = action_resolver or APIActionResolver()
+ self._output_format = output_format
@property
def working_dir(self) -> Path:
@@ -197,6 +205,11 @@ class ExecutionContext(Context):
"""Return path to the logs directory."""
return self._working_dir_path / self.logs_dir
+ @property
+ def output_format(self) -> OutputFormat:
+ """Return the output format."""
+ return self._output_format
+
def update(
self,
*,
@@ -221,7 +234,8 @@ class ExecutionContext(Context):
f"ExecutionContext: working_dir={self._working_dir_path}, "
f"advice_category={category}, "
f"config_parameters={self.config_parameters}, "
- f"verbose={self.verbose}"
+ f"verbose={self.verbose}, "
+ f"output_format={self.output_format}"
)
diff --git a/src/mlia/core/handlers.py b/src/mlia/core/handlers.py
index 6e50934..24d4881 100644
--- a/src/mlia/core/handlers.py
+++ b/src/mlia/core/handlers.py
@@ -9,7 +9,6 @@ from typing import Callable
from mlia.core.advice_generation import Advice
from mlia.core.advice_generation import AdviceEvent
-from mlia.core.common import FormattedFilePath
from mlia.core.events import ActionFinishedEvent
from mlia.core.events import ActionStartedEvent
from mlia.core.events import AdviceStageFinishedEvent
@@ -25,9 +24,11 @@ from mlia.core.events import EventDispatcher
from mlia.core.events import ExecutionFailedEvent
from mlia.core.events import ExecutionFinishedEvent
from mlia.core.events import ExecutionStartedEvent
+from mlia.core.mixins import ContextMixin
+from mlia.core.reporting import JSONReporter
from mlia.core.reporting import Report
from mlia.core.reporting import Reporter
-from mlia.core.typing import PathOrFileLike
+from mlia.core.reporting import TextReporter
from mlia.utils.console import create_section_header
@@ -92,26 +93,27 @@ _ADV_EXECUTION_STARTED = create_section_header("ML Inference Advisor started")
_MODEL_ANALYSIS_MSG = create_section_header("Model Analysis")
_MODEL_ANALYSIS_RESULTS_MSG = create_section_header("Model Analysis Results")
_ADV_GENERATION_MSG = create_section_header("Advice Generation")
-_REPORT_GENERATION_MSG = create_section_header("Report Generation")
-class WorkflowEventsHandler(SystemEventsHandler):
+class WorkflowEventsHandler(SystemEventsHandler, ContextMixin):
"""Event handler for the system events."""
+ reporter: Reporter
+
def __init__(
self,
formatter_resolver: Callable[[Any], Callable[[Any], Report]],
- output: FormattedFilePath | None = None,
) -> None:
"""Init event handler."""
- output_format = output.fmt if output else "plain_text"
- self.reporter = Reporter(formatter_resolver, output_format)
- self.output = output.path if output else None
-
+ self.formatter_resolver = formatter_resolver
self.advice: list[Advice] = []
def on_execution_started(self, event: ExecutionStartedEvent) -> None:
"""Handle ExecutionStarted event."""
+ if self.context.output_format == "json":
+ self.reporter = JSONReporter(self.formatter_resolver)
+ else:
+ self.reporter = TextReporter(self.formatter_resolver)
logger.info(_ADV_EXECUTION_STARTED)
def on_execution_failed(self, event: ExecutionFailedEvent) -> None:
@@ -132,12 +134,6 @@ class WorkflowEventsHandler(SystemEventsHandler):
"""Handle DataCollectorSkipped event."""
logger.info("Skipped: %s", event.reason)
- @staticmethod
- def report_generated(output: PathOrFileLike) -> None:
- """Log report generation."""
- logger.info(_REPORT_GENERATION_MSG)
- logger.info("Report(s) and advice list saved to: %s", output)
-
def on_data_analysis_stage_finished(
self, event: DataAnalysisStageFinishedEvent
) -> None:
@@ -160,7 +156,4 @@ class WorkflowEventsHandler(SystemEventsHandler):
table_style="no_borders",
)
- self.reporter.generate_report(self.output)
-
- if self.output is not None:
- self.report_generated(self.output)
+ self.reporter.generate_report()
diff --git a/src/mlia/cli/logging.py b/src/mlia/core/logging.py
index 5c5c4b8..686f8ab 100644
--- a/src/mlia/cli/logging.py
+++ b/src/mlia/core/logging.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""CLI logging configuration."""
from __future__ import annotations
@@ -8,18 +8,20 @@ import sys
from pathlib import Path
from typing import Iterable
+from mlia.core.typing import OutputFormat
from mlia.utils.logging import attach_handlers
from mlia.utils.logging import create_log_handler
-from mlia.utils.logging import LogFilter
+from mlia.utils.logging import NoASCIIFormatter
-_CONSOLE_DEBUG_FORMAT = "%(name)s - %(message)s"
+_CONSOLE_DEBUG_FORMAT = "%(name)s - %(levelname)s - %(message)s"
_FILE_DEBUG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
def setup_logging(
logs_dir: str | Path | None = None,
verbose: bool = False,
+ output_format: OutputFormat = "plain_text",
log_filename: str = "mlia.log",
) -> None:
"""Set up logging.
@@ -30,6 +32,8 @@ def setup_logging(
debug information. If the path is not provided then no log files will
be created during execution
:param verbose: enable extended logging for the tools loggers
+ :param output_format: specify the out format needed for setting up the right
+ logging system
:param log_filename: name of the log file in the logs directory
"""
mlia_logger, tensorflow_logger, py_warnings_logger = (
@@ -42,7 +46,7 @@ def setup_logging(
for logger in [mlia_logger, tensorflow_logger]:
logger.setLevel(logging.DEBUG)
- mlia_handlers = _get_mlia_handlers(logs_dir, log_filename, verbose)
+ mlia_handlers = _get_mlia_handlers(logs_dir, log_filename, verbose, output_format)
attach_handlers(mlia_handlers, [mlia_logger])
tools_handlers = _get_tools_handlers(logs_dir, log_filename, verbose)
@@ -53,28 +57,47 @@ def _get_mlia_handlers(
logs_dir: str | Path | None,
log_filename: str,
verbose: bool,
+ output_format: OutputFormat,
) -> Iterable[logging.Handler]:
"""Get handlers for the MLIA loggers."""
- yield create_log_handler(
- stream=sys.stdout,
- log_level=logging.INFO,
- )
-
- if verbose:
- mlia_verbose_handler = create_log_handler(
- stream=sys.stdout,
- log_level=logging.DEBUG,
- log_format=_CONSOLE_DEBUG_FORMAT,
- log_filter=LogFilter.equals(logging.DEBUG),
+ # MLIA needs output to standard output via the logging system only when the
+ # format is plain text. When the user specifies the "json" output format,
+ # MLIA disables completely the logging system for the console output and it
+ # relies on the print() function. This is needed because the output might
+ # be corrupted with spurious messages in the standard output.
+ if output_format == "plain_text":
+ if verbose:
+ log_level = logging.DEBUG
+ log_format = _CONSOLE_DEBUG_FORMAT
+ else:
+ log_level = logging.INFO
+ log_format = None
+
+ # Create log handler for stdout
+ yield create_log_handler(
+ stream=sys.stdout, log_level=log_level, log_format=log_format
+ )
+ else:
+ # In case of non plain text output, we need to inform the user if an
+ # error happens during execution.
+ yield create_log_handler(
+ stream=sys.stderr,
+ log_level=logging.ERROR,
)
- yield mlia_verbose_handler
+ # If the logs directory is specified, MLIA stores all output (according to
+ # the logging level) into the file and removing the colouring of the
+ # console output.
if logs_dir:
+ if verbose:
+ log_level = logging.DEBUG
+ else:
+ log_level = logging.INFO
+
yield create_log_handler(
file_path=_get_log_file(logs_dir, log_filename),
- log_level=logging.DEBUG,
- log_format=_FILE_DEBUG_FORMAT,
- log_filter=LogFilter.skip(logging.INFO),
+ log_level=log_level,
+ log_format=NoASCIIFormatter(fmt=_FILE_DEBUG_FORMAT),
delay=True,
)
diff --git a/src/mlia/core/reporting.py b/src/mlia/core/reporting.py
index ad63d62..7b9ce5c 100644
--- a/src/mlia/core/reporting.py
+++ b/src/mlia/core/reporting.py
@@ -8,30 +8,21 @@ import logging
from abc import ABC
from abc import abstractmethod
from collections import defaultdict
-from contextlib import contextmanager
-from contextlib import ExitStack
from dataclasses import dataclass
from enum import Enum
from functools import partial
-from io import TextIOWrapper
-from pathlib import Path
from textwrap import fill
from textwrap import indent
from typing import Any
from typing import Callable
-from typing import cast
from typing import Collection
-from typing import Generator
from typing import Iterable
import numpy as np
-from mlia.core.typing import FileLike
from mlia.core.typing import OutputFormat
-from mlia.core.typing import PathOrFileLike
from mlia.utils.console import apply_style
from mlia.utils.console import produce_table
-from mlia.utils.logging import LoggerWriter
from mlia.utils.types import is_list_of
logger = logging.getLogger(__name__)
@@ -505,76 +496,48 @@ class CustomJSONEncoder(json.JSONEncoder):
return json.JSONEncoder.default(self, o)
-def json_reporter(report: Report, output: FileLike, **kwargs: Any) -> None:
- """Produce report in json format."""
- json_str = json.dumps(report.to_json(**kwargs), indent=4, cls=CustomJSONEncoder)
- print(json_str, file=output)
-
-
-def text_reporter(report: Report, output: FileLike, **kwargs: Any) -> None:
- """Produce report in text format."""
- print(report.to_plain_text(**kwargs), file=output)
+class Reporter(ABC):
+ """Reporter class."""
+ def __init__(
+ self,
+ formatter_resolver: Callable[[Any], Callable[[Any], Report]],
+ ) -> None:
+ """Init reporter instance."""
+ self.formatter_resolver = formatter_resolver
+ self.data: list[tuple[Any, Callable[[Any], Report]]] = []
-def produce_report(
- data: Any,
- formatter: Callable[[Any], Report],
- fmt: OutputFormat = "plain_text",
- output: PathOrFileLike | None = None,
- **kwargs: Any,
-) -> None:
- """Produce report based on provided data."""
- # check if provided format value is supported
- formats = {"json": json_reporter, "plain_text": text_reporter}
- if fmt not in formats:
- raise Exception(f"Unknown format {fmt}")
+ @abstractmethod
+ def submit(self, data_item: Any, **kwargs: Any) -> None:
+ """Submit data for the report."""
- if output is None:
- output = cast(TextIOWrapper, LoggerWriter(logger, logging.INFO))
+ def print_delayed(self) -> None:
+ """Print delayed reports."""
- with ExitStack() as exit_stack:
- if isinstance(output, (str, Path)):
- # open file and add it to the ExitStack context manager
- # in that case it will be automatically closed
- stream = exit_stack.enter_context(open(output, "w", encoding="utf-8"))
- else:
- stream = cast(TextIOWrapper, output)
+ def generate_report(self) -> None:
+ """Generate report."""
- # convert data into serializable form
- formatted_data = formatter(data)
- # find handler for the format
- format_handler = formats[fmt]
- # produce report in requested format
- format_handler(formatted_data, stream, **kwargs)
+ @abstractmethod
+ def produce_report(
+ self, data: Any, formatter: Callable[[Any], Report], **kwargs: Any
+ ) -> None:
+ """Produce report based on provided data."""
-class Reporter:
+class TextReporter(Reporter):
"""Reporter class."""
def __init__(
self,
formatter_resolver: Callable[[Any], Callable[[Any], Report]],
- output_format: OutputFormat = "plain_text",
- print_as_submitted: bool = True,
) -> None:
"""Init reporter instance."""
- self.formatter_resolver = formatter_resolver
- self.output_format = output_format
- self.print_as_submitted = print_as_submitted
-
- self.data: list[tuple[Any, Callable[[Any], Report]]] = []
+ super().__init__(formatter_resolver)
self.delayed: list[tuple[Any, Callable[[Any], Report]]] = []
+ self.output_format: OutputFormat = "plain_text"
def submit(self, data_item: Any, delay_print: bool = False, **kwargs: Any) -> None:
"""Submit data for the report."""
- if self.print_as_submitted and not delay_print:
- produce_report(
- data_item,
- self.formatter_resolver(data_item),
- fmt="plain_text",
- **kwargs,
- )
-
formatter = _apply_format_parameters(
self.formatter_resolver(data_item), self.output_format, **kwargs
)
@@ -582,51 +545,70 @@ class Reporter:
if delay_print:
self.delayed.append((data_item, formatter))
+ else:
+ self.produce_report(
+ data_item,
+ self.formatter_resolver(data_item),
+ **kwargs,
+ )
def print_delayed(self) -> None:
"""Print delayed reports."""
- if not self.delayed:
- return
+ if self.delayed:
+ data, formatters = zip(*self.delayed)
+ self.produce_report(
+ data,
+ formatter=CompoundFormatter(formatters),
+ )
+ self.delayed = []
- data, formatters = zip(*self.delayed)
- produce_report(
- data,
- formatter=CompoundFormatter(formatters),
- fmt="plain_text",
+ def produce_report(
+ self, data: Any, formatter: Callable[[Any], Report], **kwargs: Any
+ ) -> None:
+ """Produce report based on provided data."""
+ formatted_data = formatter(data)
+ logger.info(formatted_data.to_plain_text(**kwargs))
+
+
+class JSONReporter(Reporter):
+ """Reporter class."""
+
+ def __init__(
+ self,
+ formatter_resolver: Callable[[Any], Callable[[Any], Report]],
+ ) -> None:
+ """Init reporter instance."""
+ super().__init__(formatter_resolver)
+ self.output_format: OutputFormat = "json"
+
+ def submit(self, data_item: Any, **kwargs: Any) -> None:
+ """Submit data for the report."""
+ formatter = _apply_format_parameters(
+ self.formatter_resolver(data_item), self.output_format, **kwargs
)
- self.delayed = []
+ self.data.append((data_item, formatter))
- def generate_report(self, output: PathOrFileLike | None) -> None:
+ def generate_report(self) -> None:
"""Generate report."""
- already_printed = (
- self.print_as_submitted
- and self.output_format == "plain_text"
- and output is None
- )
- if not self.data or already_printed:
+ if not self.data:
return
data, formatters = zip(*self.data)
- produce_report(
+ self.produce_report(
data,
formatter=CompoundFormatter(formatters),
- fmt=self.output_format,
- output=output,
)
-
-@contextmanager
-def get_reporter(
- output_format: OutputFormat,
- output: PathOrFileLike | None,
- formatter_resolver: Callable[[Any], Callable[[Any], Report]],
-) -> Generator[Reporter, None, None]:
- """Get reporter and generate report."""
- reporter = Reporter(formatter_resolver, output_format)
-
- yield reporter
-
- reporter.generate_report(output)
+ def produce_report(
+ self, data: Any, formatter: Callable[[Any], Report], **kwargs: Any
+ ) -> None:
+ """Produce report based on provided data."""
+ formatted_data = formatter(data)
+ print(
+ json.dumps(
+ formatted_data.to_json(**kwargs), indent=4, cls=CustomJSONEncoder
+ ),
+ )
def _apply_format_parameters(
diff --git a/src/mlia/core/typing.py b/src/mlia/core/typing.py
index 8244ff5..ea334c9 100644
--- a/src/mlia/core/typing.py
+++ b/src/mlia/core/typing.py
@@ -1,12 +1,7 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for custom type hints."""
-from pathlib import Path
from typing import Literal
-from typing import TextIO
-from typing import Union
-FileLike = TextIO
-PathOrFileLike = Union[str, Path, FileLike]
OutputFormat = Literal["plain_text", "json"]
diff --git a/src/mlia/core/workflow.py b/src/mlia/core/workflow.py
index d862a86..9f8ac83 100644
--- a/src/mlia/core/workflow.py
+++ b/src/mlia/core/workflow.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for executors.
@@ -198,6 +198,7 @@ class DefaultWorkflowExecutor(WorkflowExecutor):
self.collectors,
self.analyzers,
self.producers,
+ self.context.event_handlers or [],
)
if isinstance(comp, ContextMixin)
)
diff --git a/src/mlia/target/cortex_a/advisor.py b/src/mlia/target/cortex_a/advisor.py
index b649f0d..1249d93 100644
--- a/src/mlia/target/cortex_a/advisor.py
+++ b/src/mlia/target/cortex_a/advisor.py
@@ -10,7 +10,6 @@ from mlia.core.advice_generation import AdviceProducer
from mlia.core.advisor import DefaultInferenceAdvisor
from mlia.core.advisor import InferenceAdvisor
from mlia.core.common import AdviceCategory
-from mlia.core.common import FormattedFilePath
from mlia.core.context import Context
from mlia.core.context import ExecutionContext
from mlia.core.data_analysis import DataAnalyzer
@@ -67,12 +66,11 @@ def configure_and_get_cortexa_advisor(
context: ExecutionContext,
target_profile: str,
model: str | Path,
- output: FormattedFilePath | None = None,
**_extra_args: Any,
) -> InferenceAdvisor:
"""Create and configure Cortex-A advisor."""
if context.event_handlers is None:
- context.event_handlers = [CortexAEventHandler(output)]
+ context.event_handlers = [CortexAEventHandler()]
if context.config_parameters is None:
context.config_parameters = _get_config_parameters(model, target_profile)
diff --git a/src/mlia/target/cortex_a/handlers.py b/src/mlia/target/cortex_a/handlers.py
index d6acde5..1a74da7 100644
--- a/src/mlia/target/cortex_a/handlers.py
+++ b/src/mlia/target/cortex_a/handlers.py
@@ -5,7 +5,6 @@ from __future__ import annotations
import logging
-from mlia.core.common import FormattedFilePath
from mlia.core.events import CollectedDataEvent
from mlia.core.handlers import WorkflowEventsHandler
from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
@@ -20,9 +19,9 @@ logger = logging.getLogger(__name__)
class CortexAEventHandler(WorkflowEventsHandler, CortexAAdvisorEventHandler):
"""CLI event handler."""
- def __init__(self, output: FormattedFilePath | None = None) -> None:
+ def __init__(self) -> None:
"""Init event handler."""
- super().__init__(cortex_a_formatters, output)
+ super().__init__(cortex_a_formatters)
def on_collected_data(self, event: CollectedDataEvent) -> None:
"""Handle CollectedDataEvent event."""
diff --git a/src/mlia/target/ethos_u/advisor.py b/src/mlia/target/ethos_u/advisor.py
index 640c3e1..ce4e0fc 100644
--- a/src/mlia/target/ethos_u/advisor.py
+++ b/src/mlia/target/ethos_u/advisor.py
@@ -10,7 +10,6 @@ from mlia.core.advice_generation import AdviceProducer
from mlia.core.advisor import DefaultInferenceAdvisor
from mlia.core.advisor import InferenceAdvisor
from mlia.core.common import AdviceCategory
-from mlia.core.common import FormattedFilePath
from mlia.core.context import Context
from mlia.core.context import ExecutionContext
from mlia.core.data_analysis import DataAnalyzer
@@ -126,12 +125,11 @@ def configure_and_get_ethosu_advisor(
context: ExecutionContext,
target_profile: str,
model: str | Path,
- output: FormattedFilePath | None = None,
**extra_args: Any,
) -> InferenceAdvisor:
"""Create and configure Ethos-U advisor."""
if context.event_handlers is None:
- context.event_handlers = [EthosUEventHandler(output)]
+ context.event_handlers = [EthosUEventHandler()]
if context.config_parameters is None:
context.config_parameters = _get_config_parameters(
diff --git a/src/mlia/target/ethos_u/handlers.py b/src/mlia/target/ethos_u/handlers.py
index 91f6015..9873014 100644
--- a/src/mlia/target/ethos_u/handlers.py
+++ b/src/mlia/target/ethos_u/handlers.py
@@ -6,7 +6,6 @@ from __future__ import annotations
import logging
from mlia.backend.vela.compat import Operators
-from mlia.core.common import FormattedFilePath
from mlia.core.events import CollectedDataEvent
from mlia.core.handlers import WorkflowEventsHandler
from mlia.target.ethos_u.events import EthosUAdvisorEventHandler
@@ -21,9 +20,9 @@ logger = logging.getLogger(__name__)
class EthosUEventHandler(WorkflowEventsHandler, EthosUAdvisorEventHandler):
"""CLI event handler."""
- def __init__(self, output: FormattedFilePath | None = None) -> None:
+ def __init__(self) -> None:
"""Init event handler."""
- super().__init__(ethos_u_formatters, output)
+ super().__init__(ethos_u_formatters)
def on_collected_data(self, event: CollectedDataEvent) -> None:
"""Handle CollectedDataEvent event."""
diff --git a/src/mlia/target/tosa/advisor.py b/src/mlia/target/tosa/advisor.py
index 0da44db..b60e824 100644
--- a/src/mlia/target/tosa/advisor.py
+++ b/src/mlia/target/tosa/advisor.py
@@ -10,7 +10,6 @@ from mlia.core.advice_generation import AdviceCategory
from mlia.core.advice_generation import AdviceProducer
from mlia.core.advisor import DefaultInferenceAdvisor
from mlia.core.advisor import InferenceAdvisor
-from mlia.core.common import FormattedFilePath
from mlia.core.context import Context
from mlia.core.context import ExecutionContext
from mlia.core.data_analysis import DataAnalyzer
@@ -81,12 +80,11 @@ def configure_and_get_tosa_advisor(
context: ExecutionContext,
target_profile: str,
model: str | Path,
- output: FormattedFilePath | None = None,
**_extra_args: Any,
) -> InferenceAdvisor:
"""Create and configure TOSA advisor."""
if context.event_handlers is None:
- context.event_handlers = [TOSAEventHandler(output)]
+ context.event_handlers = [TOSAEventHandler()]
if context.config_parameters is None:
context.config_parameters = _get_config_parameters(model, target_profile)
diff --git a/src/mlia/target/tosa/handlers.py b/src/mlia/target/tosa/handlers.py
index 7f80f77..f222823 100644
--- a/src/mlia/target/tosa/handlers.py
+++ b/src/mlia/target/tosa/handlers.py
@@ -7,7 +7,6 @@ from __future__ import annotations
import logging
from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo
-from mlia.core.common import FormattedFilePath
from mlia.core.events import CollectedDataEvent
from mlia.core.handlers import WorkflowEventsHandler
from mlia.target.tosa.events import TOSAAdvisorEventHandler
@@ -20,9 +19,9 @@ logger = logging.getLogger(__name__)
class TOSAEventHandler(WorkflowEventsHandler, TOSAAdvisorEventHandler):
"""Event handler for TOSA advisor."""
- def __init__(self, output: FormattedFilePath | None = None) -> None:
+ def __init__(self) -> None:
"""Init event handler."""
- super().__init__(tosa_formatters, output)
+ super().__init__(tosa_formatters)
def on_tosa_advisor_started(self, event: TOSAAdvisorStartedEvent) -> None:
"""Handle TOSAAdvisorStartedEvent event."""
diff --git a/src/mlia/utils/logging.py b/src/mlia/utils/logging.py
index 0659dcf..17f6cae 100644
--- a/src/mlia/utils/logging.py
+++ b/src/mlia/utils/logging.py
@@ -18,6 +18,8 @@ from typing import Generator
from typing import Iterable
from typing import TextIO
+from mlia.utils.console import remove_ascii_codes
+
class LoggerWriter:
"""Redirect printed messages to the logger."""
@@ -152,12 +154,21 @@ class LogFilter(logging.Filter):
return cls(skip_by_level)
+class NoASCIIFormatter(logging.Formatter):
+ """Custom Formatter for logging into file."""
+
+ def format(self, record: logging.LogRecord) -> str:
+ """Overwrite format method to remove ascii codes from record."""
+ result = super().format(record)
+ return remove_ascii_codes(result)
+
+
def create_log_handler(
*,
file_path: Path | None = None,
stream: Any | None = None,
log_level: int | None = None,
- log_format: str | None = None,
+ log_format: str | logging.Formatter | None = None,
log_filter: logging.Filter | None = None,
delay: bool = True,
) -> logging.Handler:
@@ -176,7 +187,9 @@ def create_log_handler(
handler.setLevel(log_level)
if log_format:
- handler.setFormatter(logging.Formatter(log_format))
+ if isinstance(log_format, str):
+ log_format = logging.Formatter(log_format)
+ handler.setFormatter(log_format)
if log_filter:
handler.addFilter(log_filter)
diff --git a/tests/test_api.py b/tests/test_api.py
index 0bbc3ae..251d5ac 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -20,7 +20,11 @@ from mlia.target.tosa.advisor import TOSAInferenceAdvisor
def test_get_advice_no_target_provided(test_keras_model: Path) -> None:
"""Test getting advice when no target provided."""
with pytest.raises(Exception, match="Target profile is not provided"):
- get_advice(None, test_keras_model, {"compatibility"}) # type: ignore
+ get_advice(
+ None, # type:ignore
+ test_keras_model,
+ {"compatibility"},
+ )
def test_get_advice_wrong_category(test_keras_model: Path) -> None:
diff --git a/tests/test_backend_tosa_compat.py b/tests/test_backend_tosa_compat.py
index 5a80b4b..0b6eaf5 100644
--- a/tests/test_backend_tosa_compat.py
+++ b/tests/test_backend_tosa_compat.py
@@ -27,7 +27,7 @@ def replace_get_tosa_checker_with_mock(
def test_compatibility_check_should_fail_if_checker_not_available(
- monkeypatch: pytest.MonkeyPatch, test_tflite_model: Path
+ monkeypatch: pytest.MonkeyPatch, test_tflite_model: str | Path
) -> None:
"""Test that compatibility check should fail if TOSA checker is not available."""
replace_get_tosa_checker_with_mock(monkeypatch, None)
@@ -71,7 +71,7 @@ def test_compatibility_check_should_fail_if_checker_not_available(
)
def test_get_tosa_compatibility_info(
monkeypatch: pytest.MonkeyPatch,
- test_tflite_model: Path,
+ test_tflite_model: str | Path,
is_tosa_compatible: bool,
operators: Any,
exception: Exception | None,
diff --git a/tests/test_cli_main.py b/tests/test_cli_main.py
index 5a9c0c9..9db5341 100644
--- a/tests/test_cli_main.py
+++ b/tests/test_cli_main.py
@@ -5,7 +5,6 @@ from __future__ import annotations
import argparse
from functools import wraps
-from pathlib import Path
from typing import Any
from typing import Callable
from unittest.mock import ANY
@@ -122,8 +121,6 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
model="sample_model.tflite",
compatibility=False,
performance=False,
- output=None,
- json=False,
backend=None,
),
],
@@ -135,8 +132,6 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
model="sample_model.tflite",
compatibility=False,
performance=False,
- output=None,
- json=False,
backend=None,
),
],
@@ -153,8 +148,6 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
ctx=ANY,
target_profile="ethos-u55-256",
model="sample_model.h5",
- output=None,
- json=False,
compatibility=True,
performance=True,
backend=None,
@@ -167,9 +160,6 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
"--performance",
"--target-profile",
"ethos-u55-256",
- "--output",
- "result.json",
- "--json",
],
call(
ctx=ANY,
@@ -177,8 +167,6 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
model="sample_model.h5",
performance=True,
compatibility=False,
- output=Path("result.json"),
- json=True,
backend=None,
),
],
@@ -196,8 +184,6 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
model="sample_model.h5",
compatibility=False,
performance=True,
- output=None,
- json=False,
backend=None,
),
],
@@ -218,8 +204,6 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
clustering=True,
pruning_target=None,
clustering_target=None,
- output=None,
- json=False,
backend=None,
),
],
@@ -244,8 +228,6 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
clustering=True,
pruning_target=0.5,
clustering_target=32,
- output=None,
- json=False,
backend=None,
),
],
@@ -267,8 +249,6 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
clustering=False,
pruning_target=None,
clustering_target=None,
- output=None,
- json=False,
backend=["some_backend"],
),
],
@@ -286,8 +266,6 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
model="sample_model.h5",
compatibility=True,
performance=False,
- output=None,
- json=False,
backend=None,
),
],
diff --git a/tests/test_cli_options.py b/tests/test_cli_options.py
index a889a93..94c3111 100644
--- a/tests/test_cli_options.py
+++ b/tests/test_cli_options.py
@@ -5,16 +5,14 @@ from __future__ import annotations
import argparse
from contextlib import ExitStack as does_not_raise
-from pathlib import Path
from typing import Any
import pytest
-from mlia.cli.options import add_output_options
+from mlia.cli.options import get_output_format
from mlia.cli.options import get_target_profile_opts
from mlia.cli.options import parse_optimization_parameters
-from mlia.cli.options import parse_output_parameters
-from mlia.core.common import FormattedFilePath
+from mlia.core.typing import OutputFormat
@pytest.mark.parametrize(
@@ -164,54 +162,24 @@ def test_get_target_opts(args: dict | None, expected_opts: list[str]) -> None:
@pytest.mark.parametrize(
- "output_parameters, expected_path",
- [
- [["--output", "report.json"], "report.json"],
- [["--output", "REPORT.JSON"], "REPORT.JSON"],
- [["--output", "some_folder/report.json"], "some_folder/report.json"],
- ],
-)
-def test_output_options(output_parameters: list[str], expected_path: str) -> None:
- """Test output options resolving."""
- parser = argparse.ArgumentParser()
- add_output_options(parser)
-
- args = parser.parse_args(output_parameters)
- assert str(args.output) == expected_path
-
-
-@pytest.mark.parametrize(
- "path, json, expected_error, output",
+ "args, expected_output_format",
[
[
- None,
- True,
- pytest.raises(
- argparse.ArgumentError,
- match=r"To enable JSON output you need to specify the output path. "
- r"\(e.g. --output out.json --json\)",
- ),
- None,
+ {},
+ "plain_text",
],
- [None, False, does_not_raise(), None],
[
- Path("test_path"),
- False,
- does_not_raise(),
- FormattedFilePath(Path("test_path"), "plain_text"),
+ {"json": True},
+ "json",
],
[
- Path("test_path"),
- True,
- does_not_raise(),
- FormattedFilePath(Path("test_path"), "json"),
+ {"json": False},
+ "plain_text",
],
],
)
-def test_parse_output_parameters(
- path: Path | None, json: bool, expected_error: Any, output: FormattedFilePath | None
-) -> None:
- """Test parsing for output parameters."""
- with expected_error:
- formatted_output = parse_output_parameters(path, json)
- assert formatted_output == output
+def test_get_output_format(args: dict, expected_output_format: OutputFormat) -> None:
+ """Test get_output_format function."""
+ arguments = argparse.Namespace(**args)
+ output_format = get_output_format(arguments)
+ assert output_format == expected_output_format
diff --git a/tests/test_core_context.py b/tests/test_core_context.py
index dcdbef3..0e7145f 100644
--- a/tests/test_core_context.py
+++ b/tests/test_core_context.py
@@ -58,6 +58,7 @@ def test_execution_context(tmpdir: str) -> None:
verbose=True,
logs_dir="logs_directory",
models_dir="models_directory",
+ output_format="json",
)
assert context.advice_category == category
@@ -68,12 +69,14 @@ def test_execution_context(tmpdir: str) -> None:
expected_model_path = Path(tmpdir) / "models_directory/sample.model"
assert context.get_model_path("sample.model") == expected_model_path
assert context.verbose is True
+ assert context.output_format == "json"
assert str(context) == (
f"ExecutionContext: "
f"working_dir={tmpdir}, "
"advice_category={'COMPATIBILITY'}, "
"config_parameters={'param': 'value'}, "
- "verbose=True"
+ "verbose=True, "
+ "output_format=json"
)
context_with_default_params = ExecutionContext(working_dir=tmpdir)
@@ -88,11 +91,13 @@ def test_execution_context(tmpdir: str) -> None:
default_model_path = context_with_default_params.get_model_path("sample.model")
expected_default_model_path = Path(tmpdir) / "models/sample.model"
assert default_model_path == expected_default_model_path
+ assert context_with_default_params.output_format == "plain_text"
expected_str = (
f"ExecutionContext: working_dir={tmpdir}, "
"advice_category={'COMPATIBILITY'}, "
"config_parameters=None, "
- "verbose=False"
+ "verbose=False, "
+ "output_format=plain_text"
)
assert str(context_with_default_params) == expected_str
diff --git a/tests/test_cli_logging.py b/tests/test_core_logging.py
index 1e2cc85..e021e26 100644
--- a/tests/test_cli_logging.py
+++ b/tests/test_core_logging.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the module cli.logging."""
from __future__ import annotations
@@ -8,7 +8,7 @@ from pathlib import Path
import pytest
-from mlia.cli.logging import setup_logging
+from mlia.core.logging import setup_logging
from tests.utils.logging import clear_loggers
@@ -33,20 +33,21 @@ def teardown_function() -> None:
(
None,
True,
- """mlia.backend.manager - backends debug
-cli info
-mlia.cli - cli debug
+ """mlia.backend.manager - DEBUG - backends debug
+mlia.cli - INFO - cli info
+mlia.cli - DEBUG - cli debug
""",
None,
),
(
"logs",
True,
- """mlia.backend.manager - backends debug
-cli info
-mlia.cli - cli debug
+ """mlia.backend.manager - DEBUG - backends debug
+mlia.cli - INFO - cli info
+mlia.cli - DEBUG - cli debug
""",
"""mlia.backend.manager - DEBUG - backends debug
+mlia.cli - INFO - cli info
mlia.cli - DEBUG - cli debug
""",
),
diff --git a/tests/test_core_reporting.py b/tests/test_core_reporting.py
index 71eaf85..7a68d4b 100644
--- a/tests/test_core_reporting.py
+++ b/tests/test_core_reporting.py
@@ -3,9 +3,13 @@
"""Tests for reporting module."""
from __future__ import annotations
-import io
import json
from enum import Enum
+from unittest.mock import ANY
+from unittest.mock import call
+from unittest.mock import MagicMock
+from unittest.mock import Mock
+from unittest.mock import patch
import numpy as np
import pytest
@@ -14,13 +18,15 @@ from mlia.core.reporting import BytesCell
from mlia.core.reporting import Cell
from mlia.core.reporting import ClockCell
from mlia.core.reporting import Column
+from mlia.core.reporting import CustomJSONEncoder
from mlia.core.reporting import CyclesCell
from mlia.core.reporting import Format
-from mlia.core.reporting import json_reporter
+from mlia.core.reporting import JSONReporter
from mlia.core.reporting import NestedReport
from mlia.core.reporting import ReportItem
from mlia.core.reporting import SingleRow
from mlia.core.reporting import Table
+from mlia.core.reporting import TextReporter
from mlia.utils.console import remove_ascii_codes
@@ -364,10 +370,9 @@ def test_custom_json_serialization() -> None:
alias="sample_table",
)
- output = io.StringIO()
- json_reporter(table, output)
+ output = json.dumps(table.to_json(), indent=4, cls=CustomJSONEncoder)
- assert json.loads(output.getvalue()) == {
+ assert json.loads(output) == {
"sample_table": [
{"column1": "value1"},
{"column1": 10.0},
@@ -375,3 +380,93 @@ def test_custom_json_serialization() -> None:
{"column1": 10},
]
}
+
+
+class TestTextReporter:
+ """Test TextReporter methods."""
+
+ def test_text_reporter(self) -> None:
+ """Test TextReporter."""
+ format_resolver = MagicMock()
+ reporter = TextReporter(format_resolver)
+ assert reporter.output_format == "plain_text"
+
+ def test_submit(self) -> None:
+ """Test TextReporter submit."""
+ format_resolver = MagicMock()
+ reporter = TextReporter(format_resolver)
+ reporter.submit("test")
+ assert reporter.data == [("test", ANY)]
+
+ reporter.submit("test2", delay_print=True)
+ assert reporter.data == [("test", ANY), ("test2", ANY)]
+ assert reporter.delayed == [("test2", ANY)]
+
+ def test_print_delayed(self) -> None:
+ """Test TextReporter print_delayed."""
+ with patch(
+ "mlia.core.reporting.TextReporter.produce_report"
+ ) as mock_produce_report:
+ format_resolver = MagicMock()
+ reporter = TextReporter(format_resolver)
+ reporter.submit("test", delay_print=True)
+ reporter.print_delayed()
+ assert reporter.data == [("test", ANY)]
+ assert not reporter.delayed
+ mock_produce_report.assert_called()
+
+ def test_produce_report(self) -> None:
+ """Test TextReporter produce_report."""
+ format_resolver = MagicMock()
+ reporter = TextReporter(format_resolver)
+
+ with patch("mlia.core.reporting.logger") as mock_logger:
+ mock_formatter = MagicMock()
+ reporter.produce_report("test", mock_formatter)
+ mock_formatter.assert_has_calls([call("test"), call().to_plain_text()])
+ mock_logger.info.assert_called()
+
+
+class TestJSONReporter:
+ """Test JSONReporter methods."""
+
+ def test_text_reporter(self) -> None:
+ """Test JSONReporter."""
+ format_resolver = MagicMock()
+ reporter = JSONReporter(format_resolver)
+ assert reporter.output_format == "json"
+
+ def test_submit(self) -> None:
+ """Test JSONReporter submit."""
+ format_resolver = MagicMock()
+ reporter = JSONReporter(format_resolver)
+ reporter.submit("test")
+ assert reporter.data == [("test", ANY)]
+
+ reporter.submit("test2")
+ assert reporter.data == [("test", ANY), ("test2", ANY)]
+
+ def test_generate_report(self) -> None:
+ """Test JSONReporter generate_report."""
+ format_resolver = MagicMock()
+ reporter = JSONReporter(format_resolver)
+ reporter.submit("test")
+
+ with patch(
+ "mlia.core.reporting.JSONReporter.produce_report"
+ ) as mock_produce_report:
+ reporter.generate_report()
+ mock_produce_report.assert_called()
+
+ @patch("builtins.print")
+ def test_produce_report(self, mock_print: Mock) -> None:
+ """Test JSONReporter produce_report."""
+ format_resolver = MagicMock()
+ reporter = JSONReporter(format_resolver)
+
+ with patch("json.dumps") as mock_dumps:
+ mock_formatter = MagicMock()
+ reporter.produce_report("test", mock_formatter)
+ mock_formatter.assert_has_calls([call("test"), call().to_json()])
+ mock_dumps.assert_called()
+ mock_print.assert_called()
diff --git a/tests/test_target_ethos_u_reporters.py b/tests/test_target_ethos_u_reporters.py
index 7f372bf..ee7ea52 100644
--- a/tests/test_target_ethos_u_reporters.py
+++ b/tests/test_target_ethos_u_reporters.py
@@ -1,106 +1,21 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for reports module."""
from __future__ import annotations
-import json
-import sys
-from contextlib import ExitStack as doesnt_raise
-from pathlib import Path
-from typing import Any
-from typing import Callable
-from typing import Literal
-
import pytest
from mlia.backend.vela.compat import NpuSupported
from mlia.backend.vela.compat import Operator
-from mlia.backend.vela.compat import Operators
-from mlia.core.reporting import get_reporter
-from mlia.core.reporting import produce_report
from mlia.core.reporting import Report
-from mlia.core.reporting import Reporter
from mlia.core.reporting import Table
from mlia.target.ethos_u.config import EthosUConfiguration
-from mlia.target.ethos_u.performance import MemoryUsage
-from mlia.target.ethos_u.performance import NPUCycles
-from mlia.target.ethos_u.performance import PerformanceMetrics
-from mlia.target.ethos_u.reporters import ethos_u_formatters
from mlia.target.ethos_u.reporters import report_device_details
from mlia.target.ethos_u.reporters import report_operators
-from mlia.target.ethos_u.reporters import report_perf_metrics
from mlia.utils.console import remove_ascii_codes
@pytest.mark.parametrize(
- "data, formatters",
- [
- (
- [Operator("test_operator", "test_type", NpuSupported(False, []))],
- [report_operators],
- ),
- (
- PerformanceMetrics(
- EthosUConfiguration("ethos-u55-256"),
- NPUCycles(0, 0, 0, 0, 0, 0),
- MemoryUsage(0, 0, 0, 0, 0),
- ),
- [report_perf_metrics],
- ),
- ],
-)
-@pytest.mark.parametrize(
- "fmt, output, expected_error",
- [
- [
- "unknown_format",
- sys.stdout,
- pytest.raises(Exception, match="Unknown format unknown_format"),
- ],
- [
- "plain_text",
- sys.stdout,
- doesnt_raise(),
- ],
- [
- "json",
- sys.stdout,
- doesnt_raise(),
- ],
- [
- "plain_text",
- "report.txt",
- doesnt_raise(),
- ],
- [
- "json",
- "report.json",
- doesnt_raise(),
- ],
- ],
-)
-def test_report(
- data: Any,
- formatters: list[Callable],
- fmt: Literal["plain_text", "json"],
- output: Any,
- expected_error: Any,
- tmp_path: Path,
-) -> None:
- """Test report function."""
- if is_file := isinstance(output, str):
- output = tmp_path / output
-
- for formatter in formatters:
- with expected_error:
- produce_report(data, formatter, fmt, output)
-
- if is_file:
- assert output.is_file()
- assert output.stat().st_size > 0
-
-
-@pytest.mark.parametrize(
"ops, expected_plain_text, expected_json_dict",
[
(
@@ -314,40 +229,3 @@ def test_report_device_details(
json_dict = report.to_json()
assert json_dict == expected_json_dict
-
-
-def test_get_reporter(tmp_path: Path) -> None:
- """Test reporter functionality."""
- ops = Operators(
- [
- Operator(
- "npu_supported",
- "op_type",
- NpuSupported(True, []),
- ),
- ]
- )
-
- output = tmp_path / "output.json"
- with get_reporter("json", output, ethos_u_formatters) as reporter:
- assert isinstance(reporter, Reporter)
-
- with pytest.raises(
- Exception, match="Unable to find appropriate formatter for some_data"
- ):
- reporter.submit("some_data")
-
- reporter.submit(ops)
-
- with open(output, encoding="utf-8") as file:
- json_data = json.load(file)
-
- assert json_data == {
- "operators_stats": [
- {
- "npu_unsupported_ratio": 0.0,
- "num_of_npu_supported_operators": 1,
- "num_of_operators": 1,
- }
- ]
- }
diff --git a/tests_e2e/test_e2e.py b/tests_e2e/test_e2e.py
index 35fd707..beddaed 100644
--- a/tests_e2e/test_e2e.py
+++ b/tests_e2e/test_e2e.py
@@ -16,6 +16,7 @@ from pathlib import Path
from typing import Any
from typing import Generator
from typing import Iterable
+from typing import Sequence
import pytest
@@ -75,30 +76,48 @@ class ExecutionConfiguration:
)
-def launch_and_wait(cmd: list[str], stdin: Any | None = None) -> None:
+def launch_and_wait(
+ cmd: list[str],
+ output_file: Path | None = None,
+ print_output: bool = True,
+ stdin: Any | None = None,
+) -> None:
"""Launch command and wait for the completion."""
with subprocess.Popen( # nosec
cmd,
stdin=stdin,
stdout=subprocess.PIPE,
- stderr=subprocess.STDOUT, # redirect command stderr to stdout
+ stderr=subprocess.PIPE,
universal_newlines=True,
bufsize=1,
) as process:
- if process.stdout is None:
+ # Process stdout
+ if process.stdout:
+ # Store the output in a variable
+ output = process.stdout.read()
+ # Save the output into a file
+ if output_file:
+ output_file.write_text(output)
+ print(f"Output saved to {output_file}")
+ # Show the output to stdout
+ if print_output:
+ print(output)
+ else:
raise Exception("Unable to get process output")
- # redirect output of the process into current process stdout
- for line in process.stdout:
- print(line, end="")
-
+ # Wait for the process to terminate
process.wait()
if (exit_code := process.poll()) != 0:
raise Exception(f"Command failed with exit_code {exit_code}")
-def run_command(cmd: list[str], cmd_input: str | None = None) -> None:
+def run_command(
+ cmd: list[str],
+ output_file: Path | None = None,
+ print_output: bool = True,
+ cmd_input: str | None = None,
+) -> None:
"""Run command."""
print(f"Run command: {' '.join(cmd)}")
@@ -118,7 +137,7 @@ def run_command(cmd: list[str], cmd_input: str | None = None) -> None:
cmd_input_file.write(cmd_input)
cmd_input_file.seek(0)
- launch_and_wait(cmd, cmd_input_file)
+ launch_and_wait(cmd, output_file, print_output, cmd_input_file)
def get_config_file() -> Path:
@@ -209,10 +228,15 @@ def get_config_content(config_file: Path) -> Any:
executions = json_data.get("executions", [])
assert is_list_of(executions, dict), "List of the dictionaries expected"
- return executions
+ settings = json_data.get("settings", {})
+ assert isinstance(settings, dict)
+
+ return settings, executions
-def get_all_commands_combinations(executions: Any) -> Generator[list[str], None, None]:
+def get_all_commands_combinations(
+ executions: Any,
+) -> Generator[dict[str, Sequence[str]], None, None]:
"""Return all commands combinations."""
exec_configs = (
ExecutionConfiguration.from_dict(exec_info) for exec_info in executions
@@ -221,13 +245,12 @@ def get_all_commands_combinations(executions: Any) -> Generator[list[str], None,
parser = get_args_parser()
for exec_config in exec_configs:
for command_combination in exec_config.all_combinations:
- for idx, param in enumerate(command_combination):
- if "{model_name}" in param:
- args = parser.parse_args(command_combination)
- model_name = Path(args.model).stem
- param = param.replace("{model_name}", model_name)
- command_combination[idx] = param
- yield command_combination
+ args = parser.parse_args(command_combination)
+ model_name = Path(args.model).stem
+ yield {
+ "model_name": model_name,
+ "command_combination": command_combination,
+ }
def check_args(args: list[str], no_skip: bool) -> None:
@@ -249,21 +272,31 @@ def check_args(args: list[str], no_skip: bool) -> None:
pytest.skip(f"Missing backend(s): {','.join(missing_backends)}")
-def get_execution_definitions() -> Generator[list[str], None, None]:
+def get_execution_definitions(
+ executions: dict,
+) -> Generator[dict[str, Sequence[str]], None, None]:
"""Collect all execution definitions from configuration file."""
- config_file = get_config_file()
- executions = get_config_content(config_file)
- executions = resolve_parameters(executions)
-
- return get_all_commands_combinations(executions)
+ resolved_executions = resolve_parameters(executions)
+ return get_all_commands_combinations(resolved_executions)
class TestEndToEnd:
"""End to end command tests."""
- @pytest.mark.parametrize("command", get_execution_definitions(), ids=str)
- def test_e2e(self, command: list[str], no_skip: bool) -> None:
+ configuration_file = get_config_file()
+ settings, executions = get_config_content(configuration_file)
+
+ @pytest.mark.parametrize(
+ "command_data", get_execution_definitions(executions), ids=str
+ )
+ def test_e2e(self, command_data: dict[str, list[str]], no_skip: bool) -> None:
"""Test MLIA command with the provided parameters."""
+ command = command_data["command_combination"]
+ model_name = command_data["model_name"]
check_args(command, no_skip)
mlia_command = ["mlia", *command]
- run_command(mlia_command)
+ print_output = self.settings.get("print_output", True)
+ output_file = self.settings.get("output_file", None)
+ if output_file:
+ output_file = Path(output_file.replace("{model_name}", model_name))
+ run_command(mlia_command, output_file, print_output)