aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/mlia/api.py13
-rw-r--r--src/mlia/backend/tosa_checker/compat.py6
-rw-r--r--src/mlia/cli/commands.py24
-rw-r--r--src/mlia/cli/main.py9
-rw-r--r--src/mlia/cli/options.py35
-rw-r--r--src/mlia/core/common.py33
-rw-r--r--src/mlia/core/context.py16
-rw-r--r--src/mlia/core/handlers.py31
-rw-r--r--src/mlia/core/logging.py (renamed from src/mlia/cli/logging.py)61
-rw-r--r--src/mlia/core/reporting.py166
-rw-r--r--src/mlia/core/typing.py7
-rw-r--r--src/mlia/core/workflow.py3
-rw-r--r--src/mlia/target/cortex_a/advisor.py4
-rw-r--r--src/mlia/target/cortex_a/handlers.py5
-rw-r--r--src/mlia/target/ethos_u/advisor.py4
-rw-r--r--src/mlia/target/ethos_u/handlers.py5
-rw-r--r--src/mlia/target/tosa/advisor.py4
-rw-r--r--src/mlia/target/tosa/handlers.py5
-rw-r--r--src/mlia/utils/logging.py17
-rw-r--r--tests/test_api.py6
-rw-r--r--tests/test_backend_tosa_compat.py4
-rw-r--r--tests/test_cli_main.py22
-rw-r--r--tests/test_cli_options.py60
-rw-r--r--tests/test_core_context.py9
-rw-r--r--tests/test_core_logging.py (renamed from tests/test_cli_logging.py)17
-rw-r--r--tests/test_core_reporting.py105
-rw-r--r--tests/test_target_ethos_u_reporters.py124
-rw-r--r--tests_e2e/test_e2e.py87
28 files changed, 394 insertions, 488 deletions
diff --git a/src/mlia/api.py b/src/mlia/api.py
index 2cabf37..8105276 100644
--- a/src/mlia/api.py
+++ b/src/mlia/api.py
@@ -9,7 +9,6 @@ from typing import Any
from mlia.core.advisor import InferenceAdvisor
from mlia.core.common import AdviceCategory
-from mlia.core.common import FormattedFilePath
from mlia.core.context import ExecutionContext
from mlia.target.cortex_a.advisor import configure_and_get_cortexa_advisor
from mlia.target.ethos_u.advisor import configure_and_get_ethosu_advisor
@@ -24,7 +23,6 @@ def get_advice(
model: str | Path,
category: set[str],
optimization_targets: list[dict[str, Any]] | None = None,
- output: FormattedFilePath | None = None,
context: ExecutionContext | None = None,
backends: list[str] | None = None,
) -> None:
@@ -42,8 +40,6 @@ def get_advice(
category "compatibility" is used by default.
:param optimization_targets: optional model optimization targets that
could be used for generating advice in "optimization" category.
- :param output: path to the report file. If provided, MLIA will save
- report in this location.
:param context: optional parameter which represents execution context,
could be used for advanced use cases
:param backends: A list of backends that should be used for the given
@@ -57,11 +53,9 @@ def get_advice(
>>> get_advice("ethos-u55-256", "path/to/the/model",
{"optimization", "compatibility"})
- Getting the advice for the category "performance" and save result report in file
- "report.json"
+ Getting the advice for the category "performance".
- >>> get_advice("ethos-u55-256", "path/to/the/model", {"performance"},
- output=FormattedFilePath("report.json")
+ >>> get_advice("ethos-u55-256", "path/to/the/model", {"performance"})
"""
advice_category = AdviceCategory.from_string(category)
@@ -76,7 +70,6 @@ def get_advice(
context,
target_profile,
model,
- output,
optimization_targets=optimization_targets,
backends=backends,
)
@@ -88,7 +81,6 @@ def get_advisor(
context: ExecutionContext,
target_profile: str,
model: str | Path,
- output: FormattedFilePath | None = None,
**extra_args: Any,
) -> InferenceAdvisor:
"""Find appropriate advisor for the target."""
@@ -109,6 +101,5 @@ def get_advisor(
context,
target_profile,
model,
- output,
**extra_args,
)
diff --git a/src/mlia/backend/tosa_checker/compat.py b/src/mlia/backend/tosa_checker/compat.py
index 81f3015..1c410d3 100644
--- a/src/mlia/backend/tosa_checker/compat.py
+++ b/src/mlia/backend/tosa_checker/compat.py
@@ -5,12 +5,12 @@ from __future__ import annotations
import sys
from dataclasses import dataclass
+from pathlib import Path
from typing import Any
from typing import cast
from typing import Protocol
from mlia.backend.errors import BackendUnavailableError
-from mlia.core.typing import PathOrFileLike
from mlia.utils.logging import capture_raw_output
@@ -45,7 +45,7 @@ class TOSACompatibilityInfo:
def get_tosa_compatibility_info(
- tflite_model_path: PathOrFileLike,
+ tflite_model_path: str | Path,
) -> TOSACompatibilityInfo:
"""Return list of the operators."""
# Capture the possible exception in running get_tosa_checker
@@ -100,7 +100,7 @@ def get_tosa_compatibility_info(
)
-def get_tosa_checker(tflite_model_path: PathOrFileLike) -> TOSAChecker | None:
+def get_tosa_checker(tflite_model_path: str | Path) -> TOSAChecker | None:
"""Return instance of the TOSA checker."""
try:
import tosa_checker as tc # pylint: disable=import-outside-toplevel
diff --git a/src/mlia/cli/commands.py b/src/mlia/cli/commands.py
index d2242ba..c17d571 100644
--- a/src/mlia/cli/commands.py
+++ b/src/mlia/cli/commands.py
@@ -7,10 +7,10 @@ functionality.
Before running them from scripts 'logging' module should
be configured. Function 'setup_logging' from module
-'mli.cli.logging' could be used for that, e.g.
+'mli.core.logging' could be used for that, e.g.
>>> from mlia.api import ExecutionContext
->>> from mlia.cli.logging import setup_logging
+>>> from mlia.core.logging import setup_logging
>>> setup_logging(verbose=True)
>>> import mlia.cli.commands as mlia
>>> mlia.check(ExecutionContext(), "ethos-u55-256",
@@ -27,7 +27,6 @@ from mlia.cli.command_validators import validate_backend
from mlia.cli.command_validators import validate_check_target_profile
from mlia.cli.config import get_installation_manager
from mlia.cli.options import parse_optimization_parameters
-from mlia.cli.options import parse_output_parameters
from mlia.utils.console import create_section_header
logger = logging.getLogger(__name__)
@@ -41,8 +40,6 @@ def check(
model: str | None = None,
compatibility: bool = False,
performance: bool = False,
- output: Path | None = None,
- json: bool = False,
backend: list[str] | None = None,
) -> None:
"""Generate a full report on the input model.
@@ -61,7 +58,6 @@ def check(
:param model: path to the Keras model
:param compatibility: flag that identifies whether to run compatibility checks
:param performance: flag that identifies whether to run performance checks
- :param output: path to the file where the report will be saved
:param backend: list of the backends to use for evaluation
Example:
@@ -69,18 +65,15 @@ def check(
and operator compatibility.
>>> from mlia.api import ExecutionContext
- >>> from mlia.cli.logging import setup_logging
+ >>> from mlia.core.logging import setup_logging
>>> setup_logging()
>>> from mlia.cli.commands import check
>>> check(ExecutionContext(), "ethos-u55-256",
- "model.h5", compatibility=True, performance=True,
- output="report.json")
+ "model.h5", compatibility=True, performance=True)
"""
if not model:
raise Exception("Model is not provided")
- formatted_output = parse_output_parameters(output, json)
-
# Set category based on checks to perform (i.e. "compatibility" and/or
# "performance").
# If no check type is specified, "compatibility" is the default category.
@@ -98,7 +91,6 @@ def check(
target_profile,
model,
category,
- output=formatted_output,
context=ctx,
backends=validated_backend,
)
@@ -113,8 +105,6 @@ def optimize( # pylint: disable=too-many-arguments
pruning_target: float | None,
clustering_target: int | None,
layers_to_optimize: list[str] | None = None,
- output: Path | None = None,
- json: bool = False,
backend: list[str] | None = None,
) -> None:
"""Show the performance improvements (if any) after applying the optimizations.
@@ -133,15 +123,13 @@ def optimize( # pylint: disable=too-many-arguments
:param pruning_target: pruning optimization target
:param layers_to_optimize: list of the layers of the model which should be
optimized, if None then all layers are used
- :param output: path to the file where the report will be saved
- :param json: set the output format to json
:param backend: list of the backends to use for evaluation
Example:
Run command for the target profile ethos-u55-256 and
the provided TensorFlow Lite model and print report on the standard output
- >>> from mlia.cli.logging import setup_logging
+ >>> from mlia.core.logging import setup_logging
>>> from mlia.api import ExecutionContext
>>> setup_logging()
>>> from mlia.cli.commands import optimize
@@ -161,7 +149,6 @@ def optimize( # pylint: disable=too-many-arguments
)
)
- formatted_output = parse_output_parameters(output, json)
validated_backend = validate_backend(target_profile, backend)
get_advice(
@@ -169,7 +156,6 @@ def optimize( # pylint: disable=too-many-arguments
model,
{"optimization"},
optimization_targets=opt_params,
- output=formatted_output,
context=ctx,
backends=validated_backend,
)
diff --git a/src/mlia/cli/main.py b/src/mlia/cli/main.py
index 1102d45..7ce7dc9 100644
--- a/src/mlia/cli/main.py
+++ b/src/mlia/cli/main.py
@@ -19,7 +19,6 @@ from mlia.cli.commands import check
from mlia.cli.commands import optimize
from mlia.cli.common import CommandInfo
from mlia.cli.helpers import CLIActionResolver
-from mlia.cli.logging import setup_logging
from mlia.cli.options import add_backend_install_options
from mlia.cli.options import add_backend_options
from mlia.cli.options import add_backend_uninstall_options
@@ -30,9 +29,11 @@ from mlia.cli.options import add_model_options
from mlia.cli.options import add_multi_optimization_options
from mlia.cli.options import add_output_options
from mlia.cli.options import add_target_options
+from mlia.cli.options import get_output_format
from mlia.core.context import ExecutionContext
from mlia.core.errors import ConfigurationError
from mlia.core.errors import InternalError
+from mlia.core.logging import setup_logging
from mlia.target.registry import registry as target_registry
@@ -162,10 +163,13 @@ def setup_context(
ctx = ExecutionContext(
verbose="debug" in args and args.debug,
action_resolver=CLIActionResolver(vars(args)),
+ output_format=get_output_format(args),
)
+ setup_logging(ctx.logs_path, ctx.verbose, ctx.output_format)
+
# these parameters should not be passed into command function
- skipped_params = ["func", "command", "debug"]
+ skipped_params = ["func", "command", "debug", "json"]
# pass these parameters only if command expects them
expected_params = [context_var_name]
@@ -186,7 +190,6 @@ def setup_context(
def run_command(args: argparse.Namespace) -> int:
"""Run command."""
ctx, func_args = setup_context(args)
- setup_logging(ctx.logs_path, ctx.verbose)
logger.debug(
"*** This is the beginning of the command '%s' execution ***", args.command
diff --git a/src/mlia/cli/options.py b/src/mlia/cli/options.py
index 5aca3b3..e01f107 100644
--- a/src/mlia/cli/options.py
+++ b/src/mlia/cli/options.py
@@ -12,7 +12,7 @@ from mlia.cli.config import DEFAULT_CLUSTERING_TARGET
from mlia.cli.config import DEFAULT_PRUNING_TARGET
from mlia.cli.config import get_available_backends
from mlia.cli.config import is_corstone_backend
-from mlia.core.common import FormattedFilePath
+from mlia.core.typing import OutputFormat
from mlia.utils.filesystem import get_supported_profile_names
@@ -89,16 +89,9 @@ def add_output_options(parser: argparse.ArgumentParser) -> None:
"""Add output specific options."""
output_group = parser.add_argument_group("output options")
output_group.add_argument(
- "-o",
- "--output",
- type=Path,
- help=("Name of the file where the report will be saved."),
- )
-
- output_group.add_argument(
"--json",
action="store_true",
- help=("Format to use for the output (requires --output argument to be set)."),
+ help=("Print the output in JSON format."),
)
@@ -209,22 +202,6 @@ def add_backend_options(
)
-def parse_output_parameters(path: Path | None, json: bool) -> FormattedFilePath | None:
- """Parse and return path and file format as FormattedFilePath."""
- if not path and json:
- raise argparse.ArgumentError(
- None,
- "To enable JSON output you need to specify the output path. "
- "(e.g. --output out.json --json)",
- )
- if not path:
- return None
- if json:
- return FormattedFilePath(path, "json")
-
- return FormattedFilePath(path, "plain_text")
-
-
def parse_optimization_parameters(
pruning: bool = False,
clustering: bool = False,
@@ -301,3 +278,11 @@ def get_target_profile_opts(device_args: dict | None) -> list[str]:
for name in non_default
for item in construct_param(params_name[name], device_args[name])
]
+
+
+def get_output_format(args: argparse.Namespace) -> OutputFormat:
+ """Return the OutputFormat depending on the CLI flags."""
+ output_format: OutputFormat = "plain_text"
+ if "json" in args and args.json:
+ output_format = "json"
+ return output_format
diff --git a/src/mlia/core/common.py b/src/mlia/core/common.py
index 53df001..baaed50 100644
--- a/src/mlia/core/common.py
+++ b/src/mlia/core/common.py
@@ -13,9 +13,6 @@ from enum import auto
from enum import Flag
from typing import Any
-from mlia.core.typing import OutputFormat
-from mlia.core.typing import PathOrFileLike
-
# This type is used as type alias for the items which are being passed around
# in advisor workflow. There are no restrictions on the type of the
# object. This alias used only to emphasize the nature of the input/output
@@ -23,36 +20,6 @@ from mlia.core.typing import PathOrFileLike
DataItem = Any
-class FormattedFilePath:
- """Class used to keep track of the format that a path points to."""
-
- def __init__(self, path: PathOrFileLike, fmt: OutputFormat = "plain_text") -> None:
- """Init FormattedFilePath."""
- self._path = path
- self._fmt = fmt
-
- @property
- def fmt(self) -> OutputFormat:
- """Return file format."""
- return self._fmt
-
- @property
- def path(self) -> PathOrFileLike:
- """Return file path."""
- return self._path
-
- def __eq__(self, other: object) -> bool:
- """Check for equality with other objects."""
- if isinstance(other, FormattedFilePath):
- return other.fmt == self.fmt and other.path == self.path
-
- return False
-
- def __repr__(self) -> str:
- """Represent object."""
- return f"FormattedFilePath {self.path=}, {self.fmt=}"
-
-
class AdviceCategory(Flag):
"""Advice category.
diff --git a/src/mlia/core/context.py b/src/mlia/core/context.py
index 94aa885..f8442a3 100644
--- a/src/mlia/core/context.py
+++ b/src/mlia/core/context.py
@@ -23,6 +23,7 @@ from mlia.core.events import EventHandler
from mlia.core.events import EventPublisher
from mlia.core.helpers import ActionResolver
from mlia.core.helpers import APIActionResolver
+from mlia.core.typing import OutputFormat
logger = logging.getLogger(__name__)
@@ -68,6 +69,11 @@ class Context(ABC):
def action_resolver(self) -> ActionResolver:
"""Return action resolver."""
+ @property
+ @abstractmethod
+ def output_format(self) -> OutputFormat:
+ """Return the output format."""
+
@abstractmethod
def update(
self,
@@ -106,6 +112,7 @@ class ExecutionContext(Context):
logs_dir: str = "logs",
models_dir: str = "models",
action_resolver: ActionResolver | None = None,
+ output_format: OutputFormat = "plain_text",
) -> None:
"""Init execution context.
@@ -139,6 +146,7 @@ class ExecutionContext(Context):
self.logs_dir = logs_dir
self.models_dir = models_dir
self._action_resolver = action_resolver or APIActionResolver()
+ self._output_format = output_format
@property
def working_dir(self) -> Path:
@@ -197,6 +205,11 @@ class ExecutionContext(Context):
"""Return path to the logs directory."""
return self._working_dir_path / self.logs_dir
+ @property
+ def output_format(self) -> OutputFormat:
+ """Return the output format."""
+ return self._output_format
+
def update(
self,
*,
@@ -221,7 +234,8 @@ class ExecutionContext(Context):
f"ExecutionContext: working_dir={self._working_dir_path}, "
f"advice_category={category}, "
f"config_parameters={self.config_parameters}, "
- f"verbose={self.verbose}"
+ f"verbose={self.verbose}, "
+ f"output_format={self.output_format}"
)
diff --git a/src/mlia/core/handlers.py b/src/mlia/core/handlers.py
index 6e50934..24d4881 100644
--- a/src/mlia/core/handlers.py
+++ b/src/mlia/core/handlers.py
@@ -9,7 +9,6 @@ from typing import Callable
from mlia.core.advice_generation import Advice
from mlia.core.advice_generation import AdviceEvent
-from mlia.core.common import FormattedFilePath
from mlia.core.events import ActionFinishedEvent
from mlia.core.events import ActionStartedEvent
from mlia.core.events import AdviceStageFinishedEvent
@@ -25,9 +24,11 @@ from mlia.core.events import EventDispatcher
from mlia.core.events import ExecutionFailedEvent
from mlia.core.events import ExecutionFinishedEvent
from mlia.core.events import ExecutionStartedEvent
+from mlia.core.mixins import ContextMixin
+from mlia.core.reporting import JSONReporter
from mlia.core.reporting import Report
from mlia.core.reporting import Reporter
-from mlia.core.typing import PathOrFileLike
+from mlia.core.reporting import TextReporter
from mlia.utils.console import create_section_header
@@ -92,26 +93,27 @@ _ADV_EXECUTION_STARTED = create_section_header("ML Inference Advisor started")
_MODEL_ANALYSIS_MSG = create_section_header("Model Analysis")
_MODEL_ANALYSIS_RESULTS_MSG = create_section_header("Model Analysis Results")
_ADV_GENERATION_MSG = create_section_header("Advice Generation")
-_REPORT_GENERATION_MSG = create_section_header("Report Generation")
-class WorkflowEventsHandler(SystemEventsHandler):
+class WorkflowEventsHandler(SystemEventsHandler, ContextMixin):
"""Event handler for the system events."""
+ reporter: Reporter
+
def __init__(
self,
formatter_resolver: Callable[[Any], Callable[[Any], Report]],
- output: FormattedFilePath | None = None,
) -> None:
"""Init event handler."""
- output_format = output.fmt if output else "plain_text"
- self.reporter = Reporter(formatter_resolver, output_format)
- self.output = output.path if output else None
-
+ self.formatter_resolver = formatter_resolver
self.advice: list[Advice] = []
def on_execution_started(self, event: ExecutionStartedEvent) -> None:
"""Handle ExecutionStarted event."""
+ if self.context.output_format == "json":
+ self.reporter = JSONReporter(self.formatter_resolver)
+ else:
+ self.reporter = TextReporter(self.formatter_resolver)
logger.info(_ADV_EXECUTION_STARTED)
def on_execution_failed(self, event: ExecutionFailedEvent) -> None:
@@ -132,12 +134,6 @@ class WorkflowEventsHandler(SystemEventsHandler):
"""Handle DataCollectorSkipped event."""
logger.info("Skipped: %s", event.reason)
- @staticmethod
- def report_generated(output: PathOrFileLike) -> None:
- """Log report generation."""
- logger.info(_REPORT_GENERATION_MSG)
- logger.info("Report(s) and advice list saved to: %s", output)
-
def on_data_analysis_stage_finished(
self, event: DataAnalysisStageFinishedEvent
) -> None:
@@ -160,7 +156,4 @@ class WorkflowEventsHandler(SystemEventsHandler):
table_style="no_borders",
)
- self.reporter.generate_report(self.output)
-
- if self.output is not None:
- self.report_generated(self.output)
+ self.reporter.generate_report()
diff --git a/src/mlia/cli/logging.py b/src/mlia/core/logging.py
index 5c5c4b8..686f8ab 100644
--- a/src/mlia/cli/logging.py
+++ b/src/mlia/core/logging.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""CLI logging configuration."""
from __future__ import annotations
@@ -8,18 +8,20 @@ import sys
from pathlib import Path
from typing import Iterable
+from mlia.core.typing import OutputFormat
from mlia.utils.logging import attach_handlers
from mlia.utils.logging import create_log_handler
-from mlia.utils.logging import LogFilter
+from mlia.utils.logging import NoASCIIFormatter
-_CONSOLE_DEBUG_FORMAT = "%(name)s - %(message)s"
+_CONSOLE_DEBUG_FORMAT = "%(name)s - %(levelname)s - %(message)s"
_FILE_DEBUG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
def setup_logging(
logs_dir: str | Path | None = None,
verbose: bool = False,
+ output_format: OutputFormat = "plain_text",
log_filename: str = "mlia.log",
) -> None:
"""Set up logging.
@@ -30,6 +32,8 @@ def setup_logging(
debug information. If the path is not provided then no log files will
be created during execution
:param verbose: enable extended logging for the tools loggers
+ :param output_format: specify the out format needed for setting up the right
+ logging system
:param log_filename: name of the log file in the logs directory
"""
mlia_logger, tensorflow_logger, py_warnings_logger = (
@@ -42,7 +46,7 @@ def setup_logging(
for logger in [mlia_logger, tensorflow_logger]:
logger.setLevel(logging.DEBUG)
- mlia_handlers = _get_mlia_handlers(logs_dir, log_filename, verbose)
+ mlia_handlers = _get_mlia_handlers(logs_dir, log_filename, verbose, output_format)
attach_handlers(mlia_handlers, [mlia_logger])
tools_handlers = _get_tools_handlers(logs_dir, log_filename, verbose)
@@ -53,28 +57,47 @@ def _get_mlia_handlers(
logs_dir: str | Path | None,
log_filename: str,
verbose: bool,
+ output_format: OutputFormat,
) -> Iterable[logging.Handler]:
"""Get handlers for the MLIA loggers."""
- yield create_log_handler(
- stream=sys.stdout,
- log_level=logging.INFO,
- )
-
- if verbose:
- mlia_verbose_handler = create_log_handler(
- stream=sys.stdout,
- log_level=logging.DEBUG,
- log_format=_CONSOLE_DEBUG_FORMAT,
- log_filter=LogFilter.equals(logging.DEBUG),
+ # MLIA needs output to standard output via the logging system only when the
+ # format is plain text. When the user specifies the "json" output format,
+ # MLIA disables completely the logging system for the console output and it
+ # relies on the print() function. This is needed because the output might
+ # be corrupted with spurious messages in the standard output.
+ if output_format == "plain_text":
+ if verbose:
+ log_level = logging.DEBUG
+ log_format = _CONSOLE_DEBUG_FORMAT
+ else:
+ log_level = logging.INFO
+ log_format = None
+
+ # Create log handler for stdout
+ yield create_log_handler(
+ stream=sys.stdout, log_level=log_level, log_format=log_format
+ )
+ else:
+ # In case of non plain text output, we need to inform the user if an
+ # error happens during execution.
+ yield create_log_handler(
+ stream=sys.stderr,
+ log_level=logging.ERROR,
)
- yield mlia_verbose_handler
+ # If the logs directory is specified, MLIA stores all output (according to
+ # the logging level) into the file and removing the colouring of the
+ # console output.
if logs_dir:
+ if verbose:
+ log_level = logging.DEBUG
+ else:
+ log_level = logging.INFO
+
yield create_log_handler(
file_path=_get_log_file(logs_dir, log_filename),
- log_level=logging.DEBUG,
- log_format=_FILE_DEBUG_FORMAT,
- log_filter=LogFilter.skip(logging.INFO),
+ log_level=log_level,
+ log_format=NoASCIIFormatter(fmt=_FILE_DEBUG_FORMAT),
delay=True,
)
diff --git a/src/mlia/core/reporting.py b/src/mlia/core/reporting.py
index ad63d62..7b9ce5c 100644
--- a/src/mlia/core/reporting.py
+++ b/src/mlia/core/reporting.py
@@ -8,30 +8,21 @@ import logging
from abc import ABC
from abc import abstractmethod
from collections import defaultdict
-from contextlib import contextmanager
-from contextlib import ExitStack
from dataclasses import dataclass
from enum import Enum
from functools import partial
-from io import TextIOWrapper
-from pathlib import Path
from textwrap import fill
from textwrap import indent
from typing import Any
from typing import Callable
-from typing import cast
from typing import Collection
-from typing import Generator
from typing import Iterable
import numpy as np
-from mlia.core.typing import FileLike
from mlia.core.typing import OutputFormat
-from mlia.core.typing import PathOrFileLike
from mlia.utils.console import apply_style
from mlia.utils.console import produce_table
-from mlia.utils.logging import LoggerWriter
from mlia.utils.types import is_list_of
logger = logging.getLogger(__name__)
@@ -505,76 +496,48 @@ class CustomJSONEncoder(json.JSONEncoder):
return json.JSONEncoder.default(self, o)
-def json_reporter(report: Report, output: FileLike, **kwargs: Any) -> None:
- """Produce report in json format."""
- json_str = json.dumps(report.to_json(**kwargs), indent=4, cls=CustomJSONEncoder)
- print(json_str, file=output)
-
-
-def text_reporter(report: Report, output: FileLike, **kwargs: Any) -> None:
- """Produce report in text format."""
- print(report.to_plain_text(**kwargs), file=output)
+class Reporter(ABC):
+ """Reporter class."""
+ def __init__(
+ self,
+ formatter_resolver: Callable[[Any], Callable[[Any], Report]],
+ ) -> None:
+ """Init reporter instance."""
+ self.formatter_resolver = formatter_resolver
+ self.data: list[tuple[Any, Callable[[Any], Report]]] = []
-def produce_report(
- data: Any,
- formatter: Callable[[Any], Report],
- fmt: OutputFormat = "plain_text",
- output: PathOrFileLike | None = None,
- **kwargs: Any,
-) -> None:
- """Produce report based on provided data."""
- # check if provided format value is supported
- formats = {"json": json_reporter, "plain_text": text_reporter}
- if fmt not in formats:
- raise Exception(f"Unknown format {fmt}")
+ @abstractmethod
+ def submit(self, data_item: Any, **kwargs: Any) -> None:
+ """Submit data for the report."""
- if output is None:
- output = cast(TextIOWrapper, LoggerWriter(logger, logging.INFO))
+ def print_delayed(self) -> None:
+ """Print delayed reports."""
- with ExitStack() as exit_stack:
- if isinstance(output, (str, Path)):
- # open file and add it to the ExitStack context manager
- # in that case it will be automatically closed
- stream = exit_stack.enter_context(open(output, "w", encoding="utf-8"))
- else:
- stream = cast(TextIOWrapper, output)
+ def generate_report(self) -> None:
+ """Generate report."""
- # convert data into serializable form
- formatted_data = formatter(data)
- # find handler for the format
- format_handler = formats[fmt]
- # produce report in requested format
- format_handler(formatted_data, stream, **kwargs)
+ @abstractmethod
+ def produce_report(
+ self, data: Any, formatter: Callable[[Any], Report], **kwargs: Any
+ ) -> None:
+ """Produce report based on provided data."""
-class Reporter:
+class TextReporter(Reporter):
"""Reporter class."""
def __init__(
self,
formatter_resolver: Callable[[Any], Callable[[Any], Report]],
- output_format: OutputFormat = "plain_text",
- print_as_submitted: bool = True,
) -> None:
"""Init reporter instance."""
- self.formatter_resolver = formatter_resolver
- self.output_format = output_format
- self.print_as_submitted = print_as_submitted
-
- self.data: list[tuple[Any, Callable[[Any], Report]]] = []
+ super().__init__(formatter_resolver)
self.delayed: list[tuple[Any, Callable[[Any], Report]]] = []
+ self.output_format: OutputFormat = "plain_text"
def submit(self, data_item: Any, delay_print: bool = False, **kwargs: Any) -> None:
"""Submit data for the report."""
- if self.print_as_submitted and not delay_print:
- produce_report(
- data_item,
- self.formatter_resolver(data_item),
- fmt="plain_text",
- **kwargs,
- )
-
formatter = _apply_format_parameters(
self.formatter_resolver(data_item), self.output_format, **kwargs
)
@@ -582,51 +545,70 @@ class Reporter:
if delay_print:
self.delayed.append((data_item, formatter))
+ else:
+ self.produce_report(
+ data_item,
+ self.formatter_resolver(data_item),
+ **kwargs,
+ )
def print_delayed(self) -> None:
"""Print delayed reports."""
- if not self.delayed:
- return
+ if self.delayed:
+ data, formatters = zip(*self.delayed)
+ self.produce_report(
+ data,
+ formatter=CompoundFormatter(formatters),
+ )
+ self.delayed = []
- data, formatters = zip(*self.delayed)
- produce_report(
- data,
- formatter=CompoundFormatter(formatters),
- fmt="plain_text",
+ def produce_report(
+ self, data: Any, formatter: Callable[[Any], Report], **kwargs: Any
+ ) -> None:
+ """Produce report based on provided data."""
+ formatted_data = formatter(data)
+ logger.info(formatted_data.to_plain_text(**kwargs))
+
+
+class JSONReporter(Reporter):
+ """Reporter class."""
+
+ def __init__(
+ self,
+ formatter_resolver: Callable[[Any], Callable[[Any], Report]],
+ ) -> None:
+ """Init reporter instance."""
+ super().__init__(formatter_resolver)
+ self.output_format: OutputFormat = "json"
+
+ def submit(self, data_item: Any, **kwargs: Any) -> None:
+ """Submit data for the report."""
+ formatter = _apply_format_parameters(
+ self.formatter_resolver(data_item), self.output_format, **kwargs
)
- self.delayed = []
+ self.data.append((data_item, formatter))
- def generate_report(self, output: PathOrFileLike | None) -> None:
+ def generate_report(self) -> None:
"""Generate report."""
- already_printed = (
- self.print_as_submitted
- and self.output_format == "plain_text"
- and output is None
- )
- if not self.data or already_printed:
+ if not self.data:
return
data, formatters = zip(*self.data)
- produce_report(
+ self.produce_report(
data,
formatter=CompoundFormatter(formatters),
- fmt=self.output_format,
- output=output,
)
-
-@contextmanager
-def get_reporter(
- output_format: OutputFormat,
- output: PathOrFileLike | None,
- formatter_resolver: Callable[[Any], Callable[[Any], Report]],
-) -> Generator[Reporter, None, None]:
- """Get reporter and generate report."""
- reporter = Reporter(formatter_resolver, output_format)
-
- yield reporter
-
- reporter.generate_report(output)
+ def produce_report(
+ self, data: Any, formatter: Callable[[Any], Report], **kwargs: Any
+ ) -> None:
+ """Produce report based on provided data."""
+ formatted_data = formatter(data)
+ print(
+ json.dumps(
+ formatted_data.to_json(**kwargs), indent=4, cls=CustomJSONEncoder
+ ),
+ )
def _apply_format_parameters(
diff --git a/src/mlia/core/typing.py b/src/mlia/core/typing.py
index 8244ff5..ea334c9 100644
--- a/src/mlia/core/typing.py
+++ b/src/mlia/core/typing.py
@@ -1,12 +1,7 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for custom type hints."""
-from pathlib import Path
from typing import Literal
-from typing import TextIO
-from typing import Union
-FileLike = TextIO
-PathOrFileLike = Union[str, Path, FileLike]
OutputFormat = Literal["plain_text", "json"]
diff --git a/src/mlia/core/workflow.py b/src/mlia/core/workflow.py
index d862a86..9f8ac83 100644
--- a/src/mlia/core/workflow.py
+++ b/src/mlia/core/workflow.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for executors.
@@ -198,6 +198,7 @@ class DefaultWorkflowExecutor(WorkflowExecutor):
self.collectors,
self.analyzers,
self.producers,
+ self.context.event_handlers or [],
)
if isinstance(comp, ContextMixin)
)
diff --git a/src/mlia/target/cortex_a/advisor.py b/src/mlia/target/cortex_a/advisor.py
index b649f0d..1249d93 100644
--- a/src/mlia/target/cortex_a/advisor.py
+++ b/src/mlia/target/cortex_a/advisor.py
@@ -10,7 +10,6 @@ from mlia.core.advice_generation import AdviceProducer
from mlia.core.advisor import DefaultInferenceAdvisor
from mlia.core.advisor import InferenceAdvisor
from mlia.core.common import AdviceCategory
-from mlia.core.common import FormattedFilePath
from mlia.core.context import Context
from mlia.core.context import ExecutionContext
from mlia.core.data_analysis import DataAnalyzer
@@ -67,12 +66,11 @@ def configure_and_get_cortexa_advisor(
context: ExecutionContext,
target_profile: str,
model: str | Path,
- output: FormattedFilePath | None = None,
**_extra_args: Any,
) -> InferenceAdvisor:
"""Create and configure Cortex-A advisor."""
if context.event_handlers is None:
- context.event_handlers = [CortexAEventHandler(output)]
+ context.event_handlers = [CortexAEventHandler()]
if context.config_parameters is None:
context.config_parameters = _get_config_parameters(model, target_profile)
diff --git a/src/mlia/target/cortex_a/handlers.py b/src/mlia/target/cortex_a/handlers.py
index d6acde5..1a74da7 100644
--- a/src/mlia/target/cortex_a/handlers.py
+++ b/src/mlia/target/cortex_a/handlers.py
@@ -5,7 +5,6 @@ from __future__ import annotations
import logging
-from mlia.core.common import FormattedFilePath
from mlia.core.events import CollectedDataEvent
from mlia.core.handlers import WorkflowEventsHandler
from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
@@ -20,9 +19,9 @@ logger = logging.getLogger(__name__)
class CortexAEventHandler(WorkflowEventsHandler, CortexAAdvisorEventHandler):
"""CLI event handler."""
- def __init__(self, output: FormattedFilePath | None = None) -> None:
+ def __init__(self) -> None:
"""Init event handler."""
- super().__init__(cortex_a_formatters, output)
+ super().__init__(cortex_a_formatters)
def on_collected_data(self, event: CollectedDataEvent) -> None:
"""Handle CollectedDataEvent event."""
diff --git a/src/mlia/target/ethos_u/advisor.py b/src/mlia/target/ethos_u/advisor.py
index 640c3e1..ce4e0fc 100644
--- a/src/mlia/target/ethos_u/advisor.py
+++ b/src/mlia/target/ethos_u/advisor.py
@@ -10,7 +10,6 @@ from mlia.core.advice_generation import AdviceProducer
from mlia.core.advisor import DefaultInferenceAdvisor
from mlia.core.advisor import InferenceAdvisor
from mlia.core.common import AdviceCategory
-from mlia.core.common import FormattedFilePath
from mlia.core.context import Context
from mlia.core.context import ExecutionContext
from mlia.core.data_analysis import DataAnalyzer
@@ -126,12 +125,11 @@ def configure_and_get_ethosu_advisor(
context: ExecutionContext,
target_profile: str,
model: str | Path,
- output: FormattedFilePath | None = None,
**extra_args: Any,
) -> InferenceAdvisor:
"""Create and configure Ethos-U advisor."""
if context.event_handlers is None:
- context.event_handlers = [EthosUEventHandler(output)]
+ context.event_handlers = [EthosUEventHandler()]
if context.config_parameters is None:
context.config_parameters = _get_config_parameters(
diff --git a/src/mlia/target/ethos_u/handlers.py b/src/mlia/target/ethos_u/handlers.py
index 91f6015..9873014 100644
--- a/src/mlia/target/ethos_u/handlers.py
+++ b/src/mlia/target/ethos_u/handlers.py
@@ -6,7 +6,6 @@ from __future__ import annotations
import logging
from mlia.backend.vela.compat import Operators
-from mlia.core.common import FormattedFilePath
from mlia.core.events import CollectedDataEvent
from mlia.core.handlers import WorkflowEventsHandler
from mlia.target.ethos_u.events import EthosUAdvisorEventHandler
@@ -21,9 +20,9 @@ logger = logging.getLogger(__name__)
class EthosUEventHandler(WorkflowEventsHandler, EthosUAdvisorEventHandler):
"""CLI event handler."""
- def __init__(self, output: FormattedFilePath | None = None) -> None:
+ def __init__(self) -> None:
"""Init event handler."""
- super().__init__(ethos_u_formatters, output)
+ super().__init__(ethos_u_formatters)
def on_collected_data(self, event: CollectedDataEvent) -> None:
"""Handle CollectedDataEvent event."""
diff --git a/src/mlia/target/tosa/advisor.py b/src/mlia/target/tosa/advisor.py
index 0da44db..b60e824 100644
--- a/src/mlia/target/tosa/advisor.py
+++ b/src/mlia/target/tosa/advisor.py
@@ -10,7 +10,6 @@ from mlia.core.advice_generation import AdviceCategory
from mlia.core.advice_generation import AdviceProducer
from mlia.core.advisor import DefaultInferenceAdvisor
from mlia.core.advisor import InferenceAdvisor
-from mlia.core.common import FormattedFilePath
from mlia.core.context import Context
from mlia.core.context import ExecutionContext
from mlia.core.data_analysis import DataAnalyzer
@@ -81,12 +80,11 @@ def configure_and_get_tosa_advisor(
context: ExecutionContext,
target_profile: str,
model: str | Path,
- output: FormattedFilePath | None = None,
**_extra_args: Any,
) -> InferenceAdvisor:
"""Create and configure TOSA advisor."""
if context.event_handlers is None:
- context.event_handlers = [TOSAEventHandler(output)]
+ context.event_handlers = [TOSAEventHandler()]
if context.config_parameters is None:
context.config_parameters = _get_config_parameters(model, target_profile)
diff --git a/src/mlia/target/tosa/handlers.py b/src/mlia/target/tosa/handlers.py
index 7f80f77..f222823 100644
--- a/src/mlia/target/tosa/handlers.py
+++ b/src/mlia/target/tosa/handlers.py
@@ -7,7 +7,6 @@ from __future__ import annotations
import logging
from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo
-from mlia.core.common import FormattedFilePath
from mlia.core.events import CollectedDataEvent
from mlia.core.handlers import WorkflowEventsHandler
from mlia.target.tosa.events import TOSAAdvisorEventHandler
@@ -20,9 +19,9 @@ logger = logging.getLogger(__name__)
class TOSAEventHandler(WorkflowEventsHandler, TOSAAdvisorEventHandler):
"""Event handler for TOSA advisor."""
- def __init__(self, output: FormattedFilePath | None = None) -> None:
+ def __init__(self) -> None:
"""Init event handler."""
- super().__init__(tosa_formatters, output)
+ super().__init__(tosa_formatters)
def on_tosa_advisor_started(self, event: TOSAAdvisorStartedEvent) -> None:
"""Handle TOSAAdvisorStartedEvent event."""
diff --git a/src/mlia/utils/logging.py b/src/mlia/utils/logging.py
index 0659dcf..17f6cae 100644
--- a/src/mlia/utils/logging.py
+++ b/src/mlia/utils/logging.py
@@ -18,6 +18,8 @@ from typing import Generator
from typing import Iterable
from typing import TextIO
+from mlia.utils.console import remove_ascii_codes
+
class LoggerWriter:
"""Redirect printed messages to the logger."""
@@ -152,12 +154,21 @@ class LogFilter(logging.Filter):
return cls(skip_by_level)
+class NoASCIIFormatter(logging.Formatter):
+ """Custom Formatter for logging into file."""
+
+ def format(self, record: logging.LogRecord) -> str:
+ """Overwrite format method to remove ascii codes from record."""
+ result = super().format(record)
+ return remove_ascii_codes(result)
+
+
def create_log_handler(
*,
file_path: Path | None = None,
stream: Any | None = None,
log_level: int | None = None,
- log_format: str | None = None,
+ log_format: str | logging.Formatter | None = None,
log_filter: logging.Filter | None = None,
delay: bool = True,
) -> logging.Handler:
@@ -176,7 +187,9 @@ def create_log_handler(
handler.setLevel(log_level)
if log_format:
- handler.setFormatter(logging.Formatter(log_format))
+ if isinstance(log_format, str):
+ log_format = logging.Formatter(log_format)
+ handler.setFormatter(log_format)
if log_filter:
handler.addFilter(log_filter)
diff --git a/tests/test_api.py b/tests/test_api.py
index 0bbc3ae..251d5ac 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -20,7 +20,11 @@ from mlia.target.tosa.advisor import TOSAInferenceAdvisor
def test_get_advice_no_target_provided(test_keras_model: Path) -> None:
"""Test getting advice when no target provided."""
with pytest.raises(Exception, match="Target profile is not provided"):
- get_advice(None, test_keras_model, {"compatibility"}) # type: ignore
+ get_advice(
+ None, # type:ignore
+ test_keras_model,
+ {"compatibility"},
+ )
def test_get_advice_wrong_category(test_keras_model: Path) -> None:
diff --git a/tests/test_backend_tosa_compat.py b/tests/test_backend_tosa_compat.py
index 5a80b4b..0b6eaf5 100644
--- a/tests/test_backend_tosa_compat.py
+++ b/tests/test_backend_tosa_compat.py
@@ -27,7 +27,7 @@ def replace_get_tosa_checker_with_mock(
def test_compatibility_check_should_fail_if_checker_not_available(
- monkeypatch: pytest.MonkeyPatch, test_tflite_model: Path
+ monkeypatch: pytest.MonkeyPatch, test_tflite_model: str | Path
) -> None:
"""Test that compatibility check should fail if TOSA checker is not available."""
replace_get_tosa_checker_with_mock(monkeypatch, None)
@@ -71,7 +71,7 @@ def test_compatibility_check_should_fail_if_checker_not_available(
)
def test_get_tosa_compatibility_info(
monkeypatch: pytest.MonkeyPatch,
- test_tflite_model: Path,
+ test_tflite_model: str | Path,
is_tosa_compatible: bool,
operators: Any,
exception: Exception | None,
diff --git a/tests/test_cli_main.py b/tests/test_cli_main.py
index 5a9c0c9..9db5341 100644
--- a/tests/test_cli_main.py
+++ b/tests/test_cli_main.py
@@ -5,7 +5,6 @@ from __future__ import annotations
import argparse
from functools import wraps
-from pathlib import Path
from typing import Any
from typing import Callable
from unittest.mock import ANY
@@ -122,8 +121,6 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
model="sample_model.tflite",
compatibility=False,
performance=False,
- output=None,
- json=False,
backend=None,
),
],
@@ -135,8 +132,6 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
model="sample_model.tflite",
compatibility=False,
performance=False,
- output=None,
- json=False,
backend=None,
),
],
@@ -153,8 +148,6 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
ctx=ANY,
target_profile="ethos-u55-256",
model="sample_model.h5",
- output=None,
- json=False,
compatibility=True,
performance=True,
backend=None,
@@ -167,9 +160,6 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
"--performance",
"--target-profile",
"ethos-u55-256",
- "--output",
- "result.json",
- "--json",
],
call(
ctx=ANY,
@@ -177,8 +167,6 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
model="sample_model.h5",
performance=True,
compatibility=False,
- output=Path("result.json"),
- json=True,
backend=None,
),
],
@@ -196,8 +184,6 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
model="sample_model.h5",
compatibility=False,
performance=True,
- output=None,
- json=False,
backend=None,
),
],
@@ -218,8 +204,6 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
clustering=True,
pruning_target=None,
clustering_target=None,
- output=None,
- json=False,
backend=None,
),
],
@@ -244,8 +228,6 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
clustering=True,
pruning_target=0.5,
clustering_target=32,
- output=None,
- json=False,
backend=None,
),
],
@@ -267,8 +249,6 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
clustering=False,
pruning_target=None,
clustering_target=None,
- output=None,
- json=False,
backend=["some_backend"],
),
],
@@ -286,8 +266,6 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
model="sample_model.h5",
compatibility=True,
performance=False,
- output=None,
- json=False,
backend=None,
),
],
diff --git a/tests/test_cli_options.py b/tests/test_cli_options.py
index a889a93..94c3111 100644
--- a/tests/test_cli_options.py
+++ b/tests/test_cli_options.py
@@ -5,16 +5,14 @@ from __future__ import annotations
import argparse
from contextlib import ExitStack as does_not_raise
-from pathlib import Path
from typing import Any
import pytest
-from mlia.cli.options import add_output_options
+from mlia.cli.options import get_output_format
from mlia.cli.options import get_target_profile_opts
from mlia.cli.options import parse_optimization_parameters
-from mlia.cli.options import parse_output_parameters
-from mlia.core.common import FormattedFilePath
+from mlia.core.typing import OutputFormat
@pytest.mark.parametrize(
@@ -164,54 +162,24 @@ def test_get_target_opts(args: dict | None, expected_opts: list[str]) -> None:
@pytest.mark.parametrize(
- "output_parameters, expected_path",
- [
- [["--output", "report.json"], "report.json"],
- [["--output", "REPORT.JSON"], "REPORT.JSON"],
- [["--output", "some_folder/report.json"], "some_folder/report.json"],
- ],
-)
-def test_output_options(output_parameters: list[str], expected_path: str) -> None:
- """Test output options resolving."""
- parser = argparse.ArgumentParser()
- add_output_options(parser)
-
- args = parser.parse_args(output_parameters)
- assert str(args.output) == expected_path
-
-
-@pytest.mark.parametrize(
- "path, json, expected_error, output",
+ "args, expected_output_format",
[
[
- None,
- True,
- pytest.raises(
- argparse.ArgumentError,
- match=r"To enable JSON output you need to specify the output path. "
- r"\(e.g. --output out.json --json\)",
- ),
- None,
+ {},
+ "plain_text",
],
- [None, False, does_not_raise(), None],
[
- Path("test_path"),
- False,
- does_not_raise(),
- FormattedFilePath(Path("test_path"), "plain_text"),
+ {"json": True},
+ "json",
],
[
- Path("test_path"),
- True,
- does_not_raise(),
- FormattedFilePath(Path("test_path"), "json"),
+ {"json": False},
+ "plain_text",
],
],
)
-def test_parse_output_parameters(
- path: Path | None, json: bool, expected_error: Any, output: FormattedFilePath | None
-) -> None:
- """Test parsing for output parameters."""
- with expected_error:
- formatted_output = parse_output_parameters(path, json)
- assert formatted_output == output
+def test_get_output_format(args: dict, expected_output_format: OutputFormat) -> None:
+ """Test get_output_format function."""
+ arguments = argparse.Namespace(**args)
+ output_format = get_output_format(arguments)
+ assert output_format == expected_output_format
diff --git a/tests/test_core_context.py b/tests/test_core_context.py
index dcdbef3..0e7145f 100644
--- a/tests/test_core_context.py
+++ b/tests/test_core_context.py
@@ -58,6 +58,7 @@ def test_execution_context(tmpdir: str) -> None:
verbose=True,
logs_dir="logs_directory",
models_dir="models_directory",
+ output_format="json",
)
assert context.advice_category == category
@@ -68,12 +69,14 @@ def test_execution_context(tmpdir: str) -> None:
expected_model_path = Path(tmpdir) / "models_directory/sample.model"
assert context.get_model_path("sample.model") == expected_model_path
assert context.verbose is True
+ assert context.output_format == "json"
assert str(context) == (
f"ExecutionContext: "
f"working_dir={tmpdir}, "
"advice_category={'COMPATIBILITY'}, "
"config_parameters={'param': 'value'}, "
- "verbose=True"
+ "verbose=True, "
+ "output_format=json"
)
context_with_default_params = ExecutionContext(working_dir=tmpdir)
@@ -88,11 +91,13 @@ def test_execution_context(tmpdir: str) -> None:
default_model_path = context_with_default_params.get_model_path("sample.model")
expected_default_model_path = Path(tmpdir) / "models/sample.model"
assert default_model_path == expected_default_model_path
+ assert context_with_default_params.output_format == "plain_text"
expected_str = (
f"ExecutionContext: working_dir={tmpdir}, "
"advice_category={'COMPATIBILITY'}, "
"config_parameters=None, "
- "verbose=False"
+ "verbose=False, "
+ "output_format=plain_text"
)
assert str(context_with_default_params) == expected_str
diff --git a/tests/test_cli_logging.py b/tests/test_core_logging.py
index 1e2cc85..e021e26 100644
--- a/tests/test_cli_logging.py
+++ b/tests/test_core_logging.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the module cli.logging."""
from __future__ import annotations
@@ -8,7 +8,7 @@ from pathlib import Path
import pytest
-from mlia.cli.logging import setup_logging
+from mlia.core.logging import setup_logging
from tests.utils.logging import clear_loggers
@@ -33,20 +33,21 @@ def teardown_function() -> None:
(
None,
True,
- """mlia.backend.manager - backends debug
-cli info
-mlia.cli - cli debug
+ """mlia.backend.manager - DEBUG - backends debug
+mlia.cli - INFO - cli info
+mlia.cli - DEBUG - cli debug
""",
None,
),
(
"logs",
True,
- """mlia.backend.manager - backends debug
-cli info
-mlia.cli - cli debug
+ """mlia.backend.manager - DEBUG - backends debug
+mlia.cli - INFO - cli info
+mlia.cli - DEBUG - cli debug
""",
"""mlia.backend.manager - DEBUG - backends debug
+mlia.cli - INFO - cli info
mlia.cli - DEBUG - cli debug
""",
),
diff --git a/tests/test_core_reporting.py b/tests/test_core_reporting.py
index 71eaf85..7a68d4b 100644
--- a/tests/test_core_reporting.py
+++ b/tests/test_core_reporting.py
@@ -3,9 +3,13 @@
"""Tests for reporting module."""
from __future__ import annotations
-import io
import json
from enum import Enum
+from unittest.mock import ANY
+from unittest.mock import call
+from unittest.mock import MagicMock
+from unittest.mock import Mock
+from unittest.mock import patch
import numpy as np
import pytest
@@ -14,13 +18,15 @@ from mlia.core.reporting import BytesCell
from mlia.core.reporting import Cell
from mlia.core.reporting import ClockCell
from mlia.core.reporting import Column
+from mlia.core.reporting import CustomJSONEncoder
from mlia.core.reporting import CyclesCell
from mlia.core.reporting import Format
-from mlia.core.reporting import json_reporter
+from mlia.core.reporting import JSONReporter
from mlia.core.reporting import NestedReport
from mlia.core.reporting import ReportItem
from mlia.core.reporting import SingleRow
from mlia.core.reporting import Table
+from mlia.core.reporting import TextReporter
from mlia.utils.console import remove_ascii_codes
@@ -364,10 +370,9 @@ def test_custom_json_serialization() -> None:
alias="sample_table",
)
- output = io.StringIO()
- json_reporter(table, output)
+ output = json.dumps(table.to_json(), indent=4, cls=CustomJSONEncoder)
- assert json.loads(output.getvalue()) == {
+ assert json.loads(output) == {
"sample_table": [
{"column1": "value1"},
{"column1": 10.0},
@@ -375,3 +380,93 @@ def test_custom_json_serialization() -> None:
{"column1": 10},
]
}
+
+
+class TestTextReporter:
+ """Test TextReporter methods."""
+
+ def test_text_reporter(self) -> None:
+ """Test TextReporter."""
+ format_resolver = MagicMock()
+ reporter = TextReporter(format_resolver)
+ assert reporter.output_format == "plain_text"
+
+ def test_submit(self) -> None:
+ """Test TextReporter submit."""
+ format_resolver = MagicMock()
+ reporter = TextReporter(format_resolver)
+ reporter.submit("test")
+ assert reporter.data == [("test", ANY)]
+
+ reporter.submit("test2", delay_print=True)
+ assert reporter.data == [("test", ANY), ("test2", ANY)]
+ assert reporter.delayed == [("test2", ANY)]
+
+ def test_print_delayed(self) -> None:
+ """Test TextReporter print_delayed."""
+ with patch(
+ "mlia.core.reporting.TextReporter.produce_report"
+ ) as mock_produce_report:
+ format_resolver = MagicMock()
+ reporter = TextReporter(format_resolver)
+ reporter.submit("test", delay_print=True)
+ reporter.print_delayed()
+ assert reporter.data == [("test", ANY)]
+ assert not reporter.delayed
+ mock_produce_report.assert_called()
+
+ def test_produce_report(self) -> None:
+ """Test TextReporter produce_report."""
+ format_resolver = MagicMock()
+ reporter = TextReporter(format_resolver)
+
+ with patch("mlia.core.reporting.logger") as mock_logger:
+ mock_formatter = MagicMock()
+ reporter.produce_report("test", mock_formatter)
+ mock_formatter.assert_has_calls([call("test"), call().to_plain_text()])
+ mock_logger.info.assert_called()
+
+
+class TestJSONReporter:
+ """Test JSONReporter methods."""
+
+ def test_text_reporter(self) -> None:
+ """Test JSONReporter."""
+ format_resolver = MagicMock()
+ reporter = JSONReporter(format_resolver)
+ assert reporter.output_format == "json"
+
+ def test_submit(self) -> None:
+ """Test JSONReporter submit."""
+ format_resolver = MagicMock()
+ reporter = JSONReporter(format_resolver)
+ reporter.submit("test")
+ assert reporter.data == [("test", ANY)]
+
+ reporter.submit("test2")
+ assert reporter.data == [("test", ANY), ("test2", ANY)]
+
+ def test_generate_report(self) -> None:
+ """Test JSONReporter generate_report."""
+ format_resolver = MagicMock()
+ reporter = JSONReporter(format_resolver)
+ reporter.submit("test")
+
+ with patch(
+ "mlia.core.reporting.JSONReporter.produce_report"
+ ) as mock_produce_report:
+ reporter.generate_report()
+ mock_produce_report.assert_called()
+
+ @patch("builtins.print")
+ def test_produce_report(self, mock_print: Mock) -> None:
+ """Test JSONReporter produce_report."""
+ format_resolver = MagicMock()
+ reporter = JSONReporter(format_resolver)
+
+ with patch("json.dumps") as mock_dumps:
+ mock_formatter = MagicMock()
+ reporter.produce_report("test", mock_formatter)
+ mock_formatter.assert_has_calls([call("test"), call().to_json()])
+ mock_dumps.assert_called()
+ mock_print.assert_called()
diff --git a/tests/test_target_ethos_u_reporters.py b/tests/test_target_ethos_u_reporters.py
index 7f372bf..ee7ea52 100644
--- a/tests/test_target_ethos_u_reporters.py
+++ b/tests/test_target_ethos_u_reporters.py
@@ -1,106 +1,21 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for reports module."""
from __future__ import annotations
-import json
-import sys
-from contextlib import ExitStack as doesnt_raise
-from pathlib import Path
-from typing import Any
-from typing import Callable
-from typing import Literal
-
import pytest
from mlia.backend.vela.compat import NpuSupported
from mlia.backend.vela.compat import Operator
-from mlia.backend.vela.compat import Operators
-from mlia.core.reporting import get_reporter
-from mlia.core.reporting import produce_report
from mlia.core.reporting import Report
-from mlia.core.reporting import Reporter
from mlia.core.reporting import Table
from mlia.target.ethos_u.config import EthosUConfiguration
-from mlia.target.ethos_u.performance import MemoryUsage
-from mlia.target.ethos_u.performance import NPUCycles
-from mlia.target.ethos_u.performance import PerformanceMetrics
-from mlia.target.ethos_u.reporters import ethos_u_formatters
from mlia.target.ethos_u.reporters import report_device_details
from mlia.target.ethos_u.reporters import report_operators
-from mlia.target.ethos_u.reporters import report_perf_metrics
from mlia.utils.console import remove_ascii_codes
@pytest.mark.parametrize(
- "data, formatters",
- [
- (
- [Operator("test_operator", "test_type", NpuSupported(False, []))],
- [report_operators],
- ),
- (
- PerformanceMetrics(
- EthosUConfiguration("ethos-u55-256"),
- NPUCycles(0, 0, 0, 0, 0, 0),
- MemoryUsage(0, 0, 0, 0, 0),
- ),
- [report_perf_metrics],
- ),
- ],
-)
-@pytest.mark.parametrize(
- "fmt, output, expected_error",
- [
- [
- "unknown_format",
- sys.stdout,
- pytest.raises(Exception, match="Unknown format unknown_format"),
- ],
- [
- "plain_text",
- sys.stdout,
- doesnt_raise(),
- ],
- [
- "json",
- sys.stdout,
- doesnt_raise(),
- ],
- [
- "plain_text",
- "report.txt",
- doesnt_raise(),
- ],
- [
- "json",
- "report.json",
- doesnt_raise(),
- ],
- ],
-)
-def test_report(
- data: Any,
- formatters: list[Callable],
- fmt: Literal["plain_text", "json"],
- output: Any,
- expected_error: Any,
- tmp_path: Path,
-) -> None:
- """Test report function."""
- if is_file := isinstance(output, str):
- output = tmp_path / output
-
- for formatter in formatters:
- with expected_error:
- produce_report(data, formatter, fmt, output)
-
- if is_file:
- assert output.is_file()
- assert output.stat().st_size > 0
-
-
-@pytest.mark.parametrize(
"ops, expected_plain_text, expected_json_dict",
[
(
@@ -314,40 +229,3 @@ def test_report_device_details(
json_dict = report.to_json()
assert json_dict == expected_json_dict
-
-
-def test_get_reporter(tmp_path: Path) -> None:
- """Test reporter functionality."""
- ops = Operators(
- [
- Operator(
- "npu_supported",
- "op_type",
- NpuSupported(True, []),
- ),
- ]
- )
-
- output = tmp_path / "output.json"
- with get_reporter("json", output, ethos_u_formatters) as reporter:
- assert isinstance(reporter, Reporter)
-
- with pytest.raises(
- Exception, match="Unable to find appropriate formatter for some_data"
- ):
- reporter.submit("some_data")
-
- reporter.submit(ops)
-
- with open(output, encoding="utf-8") as file:
- json_data = json.load(file)
-
- assert json_data == {
- "operators_stats": [
- {
- "npu_unsupported_ratio": 0.0,
- "num_of_npu_supported_operators": 1,
- "num_of_operators": 1,
- }
- ]
- }
diff --git a/tests_e2e/test_e2e.py b/tests_e2e/test_e2e.py
index 35fd707..beddaed 100644
--- a/tests_e2e/test_e2e.py
+++ b/tests_e2e/test_e2e.py
@@ -16,6 +16,7 @@ from pathlib import Path
from typing import Any
from typing import Generator
from typing import Iterable
+from typing import Sequence
import pytest
@@ -75,30 +76,48 @@ class ExecutionConfiguration:
)
-def launch_and_wait(cmd: list[str], stdin: Any | None = None) -> None:
+def launch_and_wait(
+ cmd: list[str],
+ output_file: Path | None = None,
+ print_output: bool = True,
+ stdin: Any | None = None,
+) -> None:
"""Launch command and wait for the completion."""
with subprocess.Popen( # nosec
cmd,
stdin=stdin,
stdout=subprocess.PIPE,
- stderr=subprocess.STDOUT, # redirect command stderr to stdout
+ stderr=subprocess.PIPE,
universal_newlines=True,
bufsize=1,
) as process:
- if process.stdout is None:
+ # Process stdout
+ if process.stdout:
+ # Store the output in a variable
+ output = process.stdout.read()
+ # Save the output into a file
+ if output_file:
+ output_file.write_text(output)
+ print(f"Output saved to {output_file}")
+ # Show the output to stdout
+ if print_output:
+ print(output)
+ else:
raise Exception("Unable to get process output")
- # redirect output of the process into current process stdout
- for line in process.stdout:
- print(line, end="")
-
+ # Wait for the process to terminate
process.wait()
if (exit_code := process.poll()) != 0:
raise Exception(f"Command failed with exit_code {exit_code}")
-def run_command(cmd: list[str], cmd_input: str | None = None) -> None:
+def run_command(
+ cmd: list[str],
+ output_file: Path | None = None,
+ print_output: bool = True,
+ cmd_input: str | None = None,
+) -> None:
"""Run command."""
print(f"Run command: {' '.join(cmd)}")
@@ -118,7 +137,7 @@ def run_command(cmd: list[str], cmd_input: str | None = None) -> None:
cmd_input_file.write(cmd_input)
cmd_input_file.seek(0)
- launch_and_wait(cmd, cmd_input_file)
+ launch_and_wait(cmd, output_file, print_output, cmd_input_file)
def get_config_file() -> Path:
@@ -209,10 +228,15 @@ def get_config_content(config_file: Path) -> Any:
executions = json_data.get("executions", [])
assert is_list_of(executions, dict), "List of the dictionaries expected"
- return executions
+ settings = json_data.get("settings", {})
+ assert isinstance(settings, dict)
+
+ return settings, executions
-def get_all_commands_combinations(executions: Any) -> Generator[list[str], None, None]:
+def get_all_commands_combinations(
+ executions: Any,
+) -> Generator[dict[str, Sequence[str]], None, None]:
"""Return all commands combinations."""
exec_configs = (
ExecutionConfiguration.from_dict(exec_info) for exec_info in executions
@@ -221,13 +245,12 @@ def get_all_commands_combinations(executions: Any) -> Generator[list[str], None,
parser = get_args_parser()
for exec_config in exec_configs:
for command_combination in exec_config.all_combinations:
- for idx, param in enumerate(command_combination):
- if "{model_name}" in param:
- args = parser.parse_args(command_combination)
- model_name = Path(args.model).stem
- param = param.replace("{model_name}", model_name)
- command_combination[idx] = param
- yield command_combination
+ args = parser.parse_args(command_combination)
+ model_name = Path(args.model).stem
+ yield {
+ "model_name": model_name,
+ "command_combination": command_combination,
+ }
def check_args(args: list[str], no_skip: bool) -> None:
@@ -249,21 +272,31 @@ def check_args(args: list[str], no_skip: bool) -> None:
pytest.skip(f"Missing backend(s): {','.join(missing_backends)}")
-def get_execution_definitions() -> Generator[list[str], None, None]:
+def get_execution_definitions(
+ executions: dict,
+) -> Generator[dict[str, Sequence[str]], None, None]:
"""Collect all execution definitions from configuration file."""
- config_file = get_config_file()
- executions = get_config_content(config_file)
- executions = resolve_parameters(executions)
-
- return get_all_commands_combinations(executions)
+ resolved_executions = resolve_parameters(executions)
+ return get_all_commands_combinations(resolved_executions)
class TestEndToEnd:
"""End to end command tests."""
- @pytest.mark.parametrize("command", get_execution_definitions(), ids=str)
- def test_e2e(self, command: list[str], no_skip: bool) -> None:
+ configuration_file = get_config_file()
+ settings, executions = get_config_content(configuration_file)
+
+ @pytest.mark.parametrize(
+ "command_data", get_execution_definitions(executions), ids=str
+ )
+ def test_e2e(self, command_data: dict[str, list[str]], no_skip: bool) -> None:
"""Test MLIA command with the provided parameters."""
+ command = command_data["command_combination"]
+ model_name = command_data["model_name"]
check_args(command, no_skip)
mlia_command = ["mlia", *command]
- run_command(mlia_command)
+ print_output = self.settings.get("print_output", True)
+ output_file = self.settings.get("output_file", None)
+ if output_file:
+ output_file = Path(output_file.replace("{model_name}", model_name))
+ run_command(mlia_command, output_file, print_output)