From 5800fc990ed1e36ce7d06670f911fbb12a0ec771 Mon Sep 17 00:00:00 2001 From: Raul Farkas Date: Tue, 29 Nov 2022 13:29:04 +0000 Subject: MLIA-650 Implement new CLI changes Breaking change in the CLI and API: Sub-commands "optimization", "operators", and "performance" were replaced by "check", which incorporates compatibility and performance checks, and "optimize" which is used for optimization. "get_advice" API was adapted to these CLI changes. API changes: * Remove previous advice category "all" that would perform all three operations (when possible). Replace them with the ability to pass a set of the advice categories. * Update api.get_advice method docstring to reflect new changes. * Set default advice category to COMPATIBILITY * Update core.common.AdviceCategory by changing the "OPERATORS" advice category to "COMPATIBILITY" and removing "ALL" enum type. Update all subsequent methods that previously used "OPERATORS" to use "COMPATIBILITY". * Update core.context.ExecutionContext to have "COMPATIBILITY" as default advice_category instead of "ALL". * Remove api.generate_supported_operators_report and all related functions from cli.commands, cli.helpers, cli.main, cli.options, core.helpers * Update tests to reflect new API changes. CLI changes: * Update README.md to contain information on the new CLI * Remove the ability to generate supported operators support from MLIA CLI * Replace `mlia ops` and `mlia perf` with the new `mlia check` command that can be used to perform both operations. * Replace `mlia opt` with the new `mlia optimize` command. * Replace `--evaluate-on` flag with `--backend` flag * Replace `--verbose` flag with `--debug` flag (no behaviour change). * Remove the ability for the user to select MLIA working directory. Create and use a temporary directory in /temp instead. * Change behaviour of `--output` flag to not format the content automatically based on file extension anymore. Instead it will simply redirect to a file. * Add the `--json` flag to specfy that the format of the output should be json. * Add command validators that are used to validate inter-dependent flags (e.g. backend validation based on target_profile). * Add support for selecting built-in backends for both `check` and `optimize` commands. * Add new unit tests and update old ones to test the new CLI changes. * Update RELEASES.md * Update copyright notice Change-Id: Ia6340797c7bee3acbbd26601950e5a16ad5602db --- README.md | 74 +++---- RELEASES.md | 8 +- src/mlia/api.py | 57 ++---- src/mlia/backend/armnn_tflite_delegate/__init__.py | 4 +- src/mlia/backend/tosa_checker/__init__.py | 4 +- src/mlia/backend/vela/__init__.py | 4 +- src/mlia/cli/command_validators.py | 113 ++++++++++ src/mlia/cli/commands.py | 208 +++++++------------ src/mlia/cli/config.py | 38 +++- src/mlia/cli/helpers.py | 36 ++-- src/mlia/cli/main.py | 81 +++----- src/mlia/cli/options.py | 199 ++++++++++-------- src/mlia/core/common.py | 53 +++-- src/mlia/core/context.py | 42 ++-- src/mlia/core/handlers.py | 10 +- src/mlia/core/helpers.py | 6 +- src/mlia/core/reporting.py | 13 +- src/mlia/target/cortex_a/advice_generation.py | 12 +- src/mlia/target/cortex_a/advisor.py | 8 +- src/mlia/target/cortex_a/handlers.py | 6 +- src/mlia/target/ethos_u/advice_generation.py | 34 +-- src/mlia/target/ethos_u/advisor.py | 14 +- src/mlia/target/ethos_u/handlers.py | 6 +- src/mlia/target/tosa/advice_generation.py | 6 +- src/mlia/target/tosa/advisor.py | 8 +- src/mlia/target/tosa/handlers.py | 6 +- src/mlia/utils/types.py | 4 +- tests/test_api.py | 98 +++------ tests/test_backend_config.py | 12 +- tests/test_backend_registry.py | 8 +- tests/test_cli_command_validators.py | 167 +++++++++++++++ tests/test_cli_commands.py | 97 +++------ tests/test_cli_config.py | 8 +- tests/test_cli_helpers.py | 62 ++---- tests/test_cli_main.py | 228 +++++++++++---------- tests/test_cli_options.py | 179 +++++++++------- tests/test_core_advice_generation.py | 10 +- tests/test_core_context.py | 46 ++++- tests/test_core_helpers.py | 3 +- tests/test_core_mixins.py | 6 +- tests/test_core_reporting.py | 22 +- tests/test_target_config.py | 6 +- tests/test_target_cortex_a_advice_generation.py | 18 +- tests/test_target_ethos_u_advice_generation.py | 70 +++---- tests/test_target_registry.py | 12 +- tests/test_target_tosa_advice_generation.py | 8 +- tests_e2e/test_e2e.py | 21 +- 47 files changed, 1155 insertions(+), 980 deletions(-) create mode 100644 src/mlia/cli/command_validators.py create mode 100644 tests/test_cli_command_validators.py diff --git a/README.md b/README.md index 501c8c5..d163728 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ # ML Inference Advisor - Introduction @@ -96,10 +96,8 @@ mlia [sub-command] [arguments] Where the following sub-commands are available: -* ["operators"](#operators-ops): show the model's operator list -* ["optimization"](#model-optimization-opt): run the specified optimizations -* ["performance"](#performance-perf): measure the performance of inference on hardware -* ["all_tests"](#all-tests-all): have a full report +* ["check"](#check): perform compatibility or performance checks on the model +* ["optimize"](#optimize): apply specified optimizations Detailed help about the different sub-commands can be shown like this: @@ -113,25 +111,27 @@ The following sections go into further detail regarding the usage of MLIA. This section gives an overview of the available sub-commands for MLIA. -## **operators** (ops) +## **check** -Lists the model's operators with information about their compatibility with the -specified target. +### compatibility + +Default check that MLIA runs. It lists the model's operators with information +about their compatibility with the specified target. *Examples:* ```bash # List operator compatibility with Ethos-U55 with 256 MAC -mlia operators --target-profile ethos-u55-256 ~/models/mobilenet_v1_1.0_224_quant.tflite +mlia check ~/models/mobilenet_v1_1.0_224_quant.tflite --target-profile ethos-u55-256 # List operator compatibility with Cortex-A -mlia ops --target-profile cortex-a ~/models/mobilenet_v1_1.0_224_quant.tflite +mlia check ~/models/mobilenet_v1_1.0_224_quant.tflite --target-profile cortex-a # Get help and further information -mlia ops --help +mlia check --help ``` -## **performance** (perf) +### performance Estimate the model's performance on the specified target and print out statistics. @@ -140,18 +140,21 @@ statistics. ```bash # Use default parameters -mlia performance ~/models/mobilenet_v1_1.0_224_quant.tflite +mlia check ~/models/mobilenet_v1_1.0_224_quant.tflite \ + --target-profile ethos-u55-256 \ + --performance -# Explicitly specify the target profile and backend(s) to use with --evaluate-on -mlia perf ~/models/ds_cnn_large_fully_quantized_int8.tflite \ - --evaluate-on "Vela" "Corstone-310" \ - --target-profile ethos-u65-512 +# Explicitly specify the target profile and backend(s) to use with --backend +mlia check ~/models/ds_cnn_large_fully_quantized_int8.tflite \ + --target-profile ethos-u65-512 \ + --performance \ + --backend "Vela" "Corstone-310" # Get help and further information -mlia perf --help +mlia check --help ``` -## **optimization** (opt) +## **optimize** This sub-command applies optimizations to a Keras model (.h5 or SavedModel) and shows the performance improvements compared to the original unoptimized model. @@ -175,35 +178,20 @@ supported. ```bash # Custom optimization parameters: pruning=0.6, clustering=16 -mlia optimization \ - --optimization-type pruning,clustering \ - --optimization-target 0.6,16 \ - ~/models/ds_cnn_l.h5 - -# Get help and further information -mlia opt --help -``` - -## **all_tests** (all) - -Combine sub-commands described above to generate a full report of the input -model with all information available for the specified target. E.g. for Ethos-U -this combines sub-commands *operators* and *optimization*. Therefore most -command line arguments are shared with other sub-commands. - -*Examples:* - -```bash -# Create full report and save it as JSON file -mlia all_tests --output ./report.json ~/models/ds_cnn_l.h5 +mlia optimize ~/models/ds_cnn_l.h5 \ + --target-profile ethos-u55-256 \ + --pruning \ + --pruning-target 0.6 \ + --clustering \ + --clustering-target 16 # Get help and further information -mlia all --help +mlia optimize --help ``` # Target profiles -Most sub-commands accept the name of a target profile as input parameter. The +All sub-commands require the name of a target profile as input parameter. The profiles currently available are described in the following sections. The support of the above sub-commands for different targets is provided via @@ -232,7 +220,7 @@ attributes: Example: ```bash -mlia perf --target-profile ethos-u65-512 ~/model.tflite +mlia check ~/model.tflite --target-profile ethos-u65-512 --performance ``` Ethos-U is supported by these backends: diff --git a/RELEASES.md b/RELEASES.md index 7f4c752..4d04c89 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -1,5 +1,5 @@ # MLIA Releases @@ -16,6 +16,12 @@ scheme. of Arm® Limited (or its subsidiaries) in the U.S. and/or elsewhere. * TensorFlow™ is a trademark of Google® LLC. +## Release 0.6.0 + +### Interface changes + +* **Breaking change:** Implement new CLI changes (MLIA-650) + ## Release 0.5.0 ### Feature changes diff --git a/src/mlia/api.py b/src/mlia/api.py index c7be9ec..2cabf37 100644 --- a/src/mlia/api.py +++ b/src/mlia/api.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Module for the API functions.""" from __future__ import annotations @@ -6,18 +6,14 @@ from __future__ import annotations import logging from pathlib import Path from typing import Any -from typing import Literal from mlia.core.advisor import InferenceAdvisor from mlia.core.common import AdviceCategory +from mlia.core.common import FormattedFilePath from mlia.core.context import ExecutionContext -from mlia.core.typing import PathOrFileLike from mlia.target.cortex_a.advisor import configure_and_get_cortexa_advisor -from mlia.target.cortex_a.operators import report as cortex_a_report from mlia.target.ethos_u.advisor import configure_and_get_ethosu_advisor -from mlia.target.ethos_u.operators import report as ethos_u_report from mlia.target.tosa.advisor import configure_and_get_tosa_advisor -from mlia.target.tosa.operators import report as tosa_report from mlia.utils.filesystem import get_target logger = logging.getLogger(__name__) @@ -26,10 +22,9 @@ logger = logging.getLogger(__name__) def get_advice( target_profile: str, model: str | Path, - category: Literal["all", "operators", "performance", "optimization"] = "all", + category: set[str], optimization_targets: list[dict[str, Any]] | None = None, - working_dir: str | Path = "mlia_output", - output: PathOrFileLike | None = None, + output: FormattedFilePath | None = None, context: ExecutionContext | None = None, backends: list[str] | None = None, ) -> None: @@ -42,17 +37,13 @@ def get_advice( :param target_profile: target profile identifier :param model: path to the NN model - :param category: category of the advice. MLIA supports four categories: - "all", "operators", "performance", "optimization". If not provided - category "all" is used by default. + :param category: set of categories of the advice. MLIA supports three categories: + "compatibility", "performance", "optimization". If not provided + category "compatibility" is used by default. :param optimization_targets: optional model optimization targets that - could be used for generating advice in categories - "all" and "optimization." - :param working_dir: path to the directory that will be used for storing - intermediate files during execution (e.g. converted models) - :param output: path to the report file. If provided MLIA will save - report in this location. Format of the report automatically - detected based on file extension. + could be used for generating advice in "optimization" category. + :param output: path to the report file. If provided, MLIA will save + report in this location. :param context: optional parameter which represents execution context, could be used for advanced use cases :param backends: A list of backends that should be used for the given @@ -63,13 +54,14 @@ def get_advice( Getting the advice for the provided target profile and the model - >>> get_advice("ethos-u55-256", "path/to/the/model") + >>> get_advice("ethos-u55-256", "path/to/the/model", + {"optimization", "compatibility"}) Getting the advice for the category "performance" and save result report in file "report.json" - >>> get_advice("ethos-u55-256", "path/to/the/model", "performance", - output="report.json") + >>> get_advice("ethos-u55-256", "path/to/the/model", {"performance"}, + output=FormattedFilePath("report.json") """ advice_category = AdviceCategory.from_string(category) @@ -78,10 +70,7 @@ def get_advice( context.advice_category = advice_category if context is None: - context = ExecutionContext( - advice_category=advice_category, - working_dir=working_dir, - ) + context = ExecutionContext(advice_category=advice_category) advisor = get_advisor( context, @@ -99,7 +88,7 @@ def get_advisor( context: ExecutionContext, target_profile: str, model: str | Path, - output: PathOrFileLike | None = None, + output: FormattedFilePath | None = None, **extra_args: Any, ) -> InferenceAdvisor: """Find appropriate advisor for the target.""" @@ -123,17 +112,3 @@ def get_advisor( output, **extra_args, ) - - -def generate_supported_operators_report(target_profile: str) -> None: - """Generate a supported operators report based on given target profile.""" - generators_map = { - "ethos-u55": ethos_u_report, - "ethos-u65": ethos_u_report, - "cortex-a": cortex_a_report, - "tosa": tosa_report, - } - - target = get_target(target_profile) - - generators_map[target]() diff --git a/src/mlia/backend/armnn_tflite_delegate/__init__.py b/src/mlia/backend/armnn_tflite_delegate/__init__.py index 6d5af42..ccb7e38 100644 --- a/src/mlia/backend/armnn_tflite_delegate/__init__.py +++ b/src/mlia/backend/armnn_tflite_delegate/__init__.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Arm NN TensorFlow Lite delegate backend module.""" from mlia.backend.config import BackendConfiguration @@ -9,7 +9,7 @@ from mlia.core.common import AdviceCategory registry.register( "ArmNNTFLiteDelegate", BackendConfiguration( - supported_advice=[AdviceCategory.OPERATORS], + supported_advice=[AdviceCategory.COMPATIBILITY], supported_systems=None, backend_type=BackendType.BUILTIN, ), diff --git a/src/mlia/backend/tosa_checker/__init__.py b/src/mlia/backend/tosa_checker/__init__.py index 19fc8be..c06a122 100644 --- a/src/mlia/backend/tosa_checker/__init__.py +++ b/src/mlia/backend/tosa_checker/__init__.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """TOSA checker backend module.""" from mlia.backend.config import BackendConfiguration @@ -10,7 +10,7 @@ from mlia.core.common import AdviceCategory registry.register( "TOSA-Checker", BackendConfiguration( - supported_advice=[AdviceCategory.OPERATORS], + supported_advice=[AdviceCategory.COMPATIBILITY], supported_systems=[System.LINUX_AMD64], backend_type=BackendType.WHEEL, ), diff --git a/src/mlia/backend/vela/__init__.py b/src/mlia/backend/vela/__init__.py index 38a623e..68fbcba 100644 --- a/src/mlia/backend/vela/__init__.py +++ b/src/mlia/backend/vela/__init__.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Vela backend module.""" from mlia.backend.config import BackendConfiguration @@ -11,7 +11,7 @@ registry.register( "Vela", BackendConfiguration( supported_advice=[ - AdviceCategory.OPERATORS, + AdviceCategory.COMPATIBILITY, AdviceCategory.PERFORMANCE, AdviceCategory.OPTIMIZATION, ], diff --git a/src/mlia/cli/command_validators.py b/src/mlia/cli/command_validators.py new file mode 100644 index 0000000..1974a1d --- /dev/null +++ b/src/mlia/cli/command_validators.py @@ -0,0 +1,113 @@ +# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""CLI command validators module.""" +from __future__ import annotations + +import argparse +import logging +import sys + +from mlia.cli.config import get_default_backends +from mlia.target.registry import supported_backends +from mlia.utils.filesystem import get_target + +logger = logging.getLogger(__name__) + + +def validate_backend( + target_profile: str, backend: list[str] | None +) -> list[str] | None: + """Validate backend with given target profile. + + This validator checks whether the given target-profile and backend are + compatible with each other. + It assumes that prior checks where made on the validity of the target-profile. + """ + target_map = { + "ethos-u55": "Ethos-U55", + "ethos-u65": "Ethos-U65", + "cortex-a": "Cortex-A", + "tosa": "TOSA", + } + target = get_target(target_profile) + + if not backend: + return get_default_backends()[target] + + compatible_backends = supported_backends(target_map[target]) + + nor_backend = list(map(normalize_string, backend)) + nor_compat_backend = list(map(normalize_string, compatible_backends)) + + incompatible_backends = [ + backend[i] for i, x in enumerate(nor_backend) if x not in nor_compat_backend + ] + # Throw an error if any unsupported backends are used + if incompatible_backends: + raise argparse.ArgumentError( + None, + f"{', '.join(incompatible_backends)} backend not supported " + f"with target-profile {target_profile}.", + ) + return backend + + +def validate_check_target_profile(target_profile: str, category: set[str]) -> None: + """Validate whether advice category is compatible with the provided target_profile. + + This validator function raises warnings if any desired advice category is not + compatible with the selected target profile. If no operation can be + performed as a result of the validation, MLIA exits with error code 0. + """ + incompatible_targets_performance: list[str] = ["tosa", "cortex-a"] + incompatible_targets_compatibility: list[str] = [] + + # Check which check operation should be performed + try_performance = "performance" in category + try_compatibility = "compatibility" in category + + # Cross check which of the desired operations can be performed on given + # target-profile + do_performance = ( + try_performance and target_profile not in incompatible_targets_performance + ) + do_compatibility = ( + try_compatibility and target_profile not in incompatible_targets_compatibility + ) + + # Case: desired operations can be performed with given target profile + if (try_performance == do_performance) and (try_compatibility == do_compatibility): + return + + warning_message = "\nWARNING: " + # Case: performance operation to be skipped + if try_performance and not do_performance: + warning_message += ( + "Performance checks skipped as they cannot be " + f"performed with target profile {target_profile}." + ) + + # Case: compatibility operation to be skipped + if try_compatibility and not do_compatibility: + warning_message += ( + "Compatibility checks skipped as they cannot be " + f"performed with target profile {target_profile}." + ) + + # Case: at least one operation will be performed + if do_compatibility or do_performance: + logger.warning(warning_message) + return + + # Case: no operation will be performed + warning_message += " No operation was performed." + logger.warning(warning_message) + sys.exit(0) + + +def normalize_string(value: str) -> str: + """Given a string return the normalized version. + + E.g. Given "ToSa-cHecker" -> "tosachecker" + """ + return value.lower().replace("-", "") diff --git a/src/mlia/cli/commands.py b/src/mlia/cli/commands.py index 09fe9de..d2242ba 100644 --- a/src/mlia/cli/commands.py +++ b/src/mlia/cli/commands.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """CLI commands module. @@ -13,7 +13,7 @@ be configured. Function 'setup_logging' from module >>> from mlia.cli.logging import setup_logging >>> setup_logging(verbose=True) >>> import mlia.cli.commands as mlia ->>> mlia.all_tests(ExecutionContext(working_dir="mlia_output"), "ethos-u55-256", +>>> mlia.check(ExecutionContext(), "ethos-u55-256", "path/to/model") """ from __future__ import annotations @@ -22,11 +22,12 @@ import logging from pathlib import Path from mlia.api import ExecutionContext -from mlia.api import generate_supported_operators_report from mlia.api import get_advice -from mlia.api import PathOrFileLike +from mlia.cli.command_validators import validate_backend +from mlia.cli.command_validators import validate_check_target_profile from mlia.cli.config import get_installation_manager from mlia.cli.options import parse_optimization_parameters +from mlia.cli.options import parse_output_parameters from mlia.utils.console import create_section_header logger = logging.getLogger(__name__) @@ -34,14 +35,15 @@ logger = logging.getLogger(__name__) CONFIG = create_section_header("ML Inference Advisor configuration") -def all_tests( +def check( ctx: ExecutionContext, target_profile: str, - model: str, - optimization_type: str = "pruning,clustering", - optimization_target: str = "0.5,32", - output: PathOrFileLike | None = None, - evaluate_on: list[str] | None = None, + model: str | None = None, + compatibility: bool = False, + performance: bool = False, + output: Path | None = None, + json: bool = False, + backend: list[str] | None = None, ) -> None: """Generate a full report on the input model. @@ -50,8 +52,6 @@ def all_tests( - converts the input Keras model into TensorFlow Lite format - checks the model for operator compatibility on the specified device - - applies optimizations to the model and estimates the resulting performance - on both the original and the optimized models - generates a final report on the steps above - provides advice on how to (possibly) improve the inference performance @@ -59,140 +59,63 @@ def all_tests( :param target_profile: target profile identifier. Will load appropriate parameters from the profile.json file based on this argument. :param model: path to the Keras model - :param optimization_type: list of the optimization techniques separated - by comma, e.g. 'pruning,clustering' - :param optimization_target: list of the corresponding targets for - the provided optimization techniques, e.g. '0.5,32' + :param compatibility: flag that identifies whether to run compatibility checks + :param performance: flag that identifies whether to run performance checks :param output: path to the file where the report will be saved - :param evaluate_on: list of the backends to use for evaluation + :param backend: list of the backends to use for evaluation Example: - Run command for the target profile ethos-u55-256 with two model optimizations - and save report in json format locally in the file report.json + Run command for the target profile ethos-u55-256 to verify both performance + and operator compatibility. >>> from mlia.api import ExecutionContext >>> from mlia.cli.logging import setup_logging >>> setup_logging() - >>> from mlia.cli.commands import all_tests - >>> all_tests(ExecutionContext(working_dir="mlia_output"), "ethos-u55-256", - "model.h5", "pruning,clustering", "0.5,32", + >>> from mlia.cli.commands import check + >>> check(ExecutionContext(), "ethos-u55-256", + "model.h5", compatibility=True, performance=True, output="report.json") """ - opt_params = parse_optimization_parameters( - optimization_type, - optimization_target, - ) - - get_advice( - target_profile, - model, - "all", - optimization_targets=opt_params, - output=output, - context=ctx, - backends=evaluate_on, - ) - - -def operators( - ctx: ExecutionContext, - target_profile: str, - model: str | None = None, - output: PathOrFileLike | None = None, - supported_ops_report: bool = False, -) -> None: - """Print the model's operator list. - - This command checks the operator compatibility of the input model with - the specific target profile. Generates a report of the operator placement - (NPU or CPU fallback) and advice on how to improve it (if necessary). - - :param ctx: execution context - :param target_profile: target profile identifier. Will load appropriate parameters - from the profile.json file based on this argument. - :param model: path to the model, which can be TensorFlow Lite or Keras - :param output: path to the file where the report will be saved - :param supported_ops_report: if True then generates supported operators - report in current directory and exits - - Example: - Run command for the target profile ethos-u55-256 and the provided - TensorFlow Lite model and print report on the standard output - - >>> from mlia.api import ExecutionContext - >>> from mlia.cli.logging import setup_logging - >>> setup_logging() - >>> from mlia.cli.commands import operators - >>> operators(ExecutionContext(working_dir="mlia_output"), "ethos-u55-256", - "model.tflite") - """ - if supported_ops_report: - generate_supported_operators_report(target_profile) - logger.info("Report saved into SUPPORTED_OPS.md") - return - if not model: raise Exception("Model is not provided") - get_advice( - target_profile, - model, - "operators", - output=output, - context=ctx, - ) - + formatted_output = parse_output_parameters(output, json) -def performance( - ctx: ExecutionContext, - target_profile: str, - model: str, - output: PathOrFileLike | None = None, - evaluate_on: list[str] | None = None, -) -> None: - """Print the model's performance stats. - - This command estimates the inference performance of the input model - on the specified target profile, and generates a report with advice on how - to improve it. - - :param ctx: execution context - :param target_profile: target profile identifier. Will load appropriate parameters - from the profile.json file based on this argument. - :param model: path to the model, which can be TensorFlow Lite or Keras - :param output: path to the file where the report will be saved - :param evaluate_on: list of the backends to use for evaluation + # Set category based on checks to perform (i.e. "compatibility" and/or + # "performance"). + # If no check type is specified, "compatibility" is the default category. + if compatibility and performance: + category = {"compatibility", "performance"} + elif performance: + category = {"performance"} + else: + category = {"compatibility"} - Example: - Run command for the target profile ethos-u55-256 and - the provided TensorFlow Lite model and print report on the standard output + validate_check_target_profile(target_profile, category) + validated_backend = validate_backend(target_profile, backend) - >>> from mlia.api import ExecutionContext - >>> from mlia.cli.logging import setup_logging - >>> setup_logging() - >>> from mlia.cli.commands import performance - >>> performance(ExecutionContext(working_dir="mlia_output"), "ethos-u55-256", - "model.tflite") - """ get_advice( target_profile, model, - "performance", - output=output, + category, + output=formatted_output, context=ctx, - backends=evaluate_on, + backends=validated_backend, ) -def optimization( +def optimize( # pylint: disable=too-many-arguments ctx: ExecutionContext, target_profile: str, model: str, - optimization_type: str, - optimization_target: str, + pruning: bool, + clustering: bool, + pruning_target: float | None, + clustering_target: int | None, layers_to_optimize: list[str] | None = None, - output: PathOrFileLike | None = None, - evaluate_on: list[str] | None = None, + output: Path | None = None, + json: bool = False, + backend: list[str] | None = None, ) -> None: """Show the performance improvements (if any) after applying the optimizations. @@ -201,43 +124,54 @@ def optimization( the inference performance (if possible). :param ctx: execution context - :param target: target profile identifier. Will load appropriate parameters + :param target_profile: target profile identifier. Will load appropriate parameters from the profile.json file based on this argument. :param model: path to the TensorFlow Lite model - :param optimization_type: list of the optimization techniques separated - by comma, e.g. 'pruning,clustering' - :param optimization_target: list of the corresponding targets for - the provided optimization techniques, e.g. '0.5,32' + :param pruning: perform pruning optimization (default if no option specified) + :param clustering: perform clustering optimization + :param clustering_target: clustering optimization target + :param pruning_target: pruning optimization target :param layers_to_optimize: list of the layers of the model which should be optimized, if None then all layers are used :param output: path to the file where the report will be saved - :param evaluate_on: list of the backends to use for evaluation + :param json: set the output format to json + :param backend: list of the backends to use for evaluation Example: Run command for the target profile ethos-u55-256 and the provided TensorFlow Lite model and print report on the standard output >>> from mlia.cli.logging import setup_logging + >>> from mlia.api import ExecutionContext >>> setup_logging() - >>> from mlia.cli.commands import optimization - >>> optimization(ExecutionContext(working_dir="mlia_output"), - target="ethos-u55-256", - "model.tflite", "pruning", "0.5") + >>> from mlia.cli.commands import optimize + >>> optimize(ExecutionContext(), + target_profile="ethos-u55-256", + model="model.tflite", pruning=True, + clustering=False, pruning_target=0.5, + clustering_target=None) """ - opt_params = parse_optimization_parameters( - optimization_type, - optimization_target, - layers_to_optimize=layers_to_optimize, + opt_params = ( + parse_optimization_parameters( # pylint: disable=too-many-function-args + pruning, + clustering, + pruning_target, + clustering_target, + layers_to_optimize, + ) ) + formatted_output = parse_output_parameters(output, json) + validated_backend = validate_backend(target_profile, backend) + get_advice( target_profile, model, - "optimization", + {"optimization"}, optimization_targets=opt_params, - output=output, + output=formatted_output, context=ctx, - backends=evaluate_on, + backends=validated_backend, ) diff --git a/src/mlia/cli/config.py b/src/mlia/cli/config.py index 2d694dc..680b4b6 100644 --- a/src/mlia/cli/config.py +++ b/src/mlia/cli/config.py @@ -1,10 +1,13 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Environment configuration functions.""" from __future__ import annotations import logging from functools import lru_cache +from typing import List +from typing import Optional +from typing import TypedDict from mlia.backend.corstone.install import get_corstone_installations from mlia.backend.install import supported_backends @@ -14,6 +17,9 @@ from mlia.backend.tosa_checker.install import get_tosa_backend_installation logger = logging.getLogger(__name__) +DEFAULT_PRUNING_TARGET = 0.5 +DEFAULT_CLUSTERING_TARGET = 32 + def get_installation_manager(noninteractive: bool = False) -> InstallationManager: """Return installation manager.""" @@ -26,7 +32,7 @@ def get_installation_manager(noninteractive: bool = False) -> InstallationManage @lru_cache def get_available_backends() -> list[str]: """Return list of the available backends.""" - available_backends = ["Vela"] + available_backends = ["Vela", "tosa-checker", "armnn-tflitedelegate"] # Add backends using backend manager manager = get_installation_manager() @@ -41,9 +47,10 @@ def get_available_backends() -> list[str]: # List of mutually exclusive Corstone backends ordered by priority _CORSTONE_EXCLUSIVE_PRIORITY = ("Corstone-310", "Corstone-300") +_NON_ETHOS_U_BACKENDS = ("tosa-checker", "armnn-tflitedelegate") -def get_default_backends() -> list[str]: +def get_ethos_u_default_backends() -> list[str]: """Get default backends for evaluation.""" backends = get_available_backends() @@ -57,9 +64,34 @@ def get_default_backends() -> list[str]: ] break + # Filter out non ethos-u backends + backends = [x for x in backends if x not in _NON_ETHOS_U_BACKENDS] return backends def is_corstone_backend(backend: str) -> bool: """Check if the given backend is a Corstone backend.""" return backend in _CORSTONE_EXCLUSIVE_PRIORITY + + +BackendCompatibility = TypedDict( + "BackendCompatibility", + { + "partial-match": bool, + "backends": List[str], + "default-return": Optional[List[str]], + "use-custom-return": bool, + "custom-return": Optional[List[str]], + }, +) + + +def get_default_backends() -> dict[str, list[str]]: + """Return default backends for all targets.""" + ethos_u_defaults = get_ethos_u_default_backends() + return { + "ethos-u55": ethos_u_defaults, + "ethos-u65": ethos_u_defaults, + "tosa": ["tosa-checker"], + "cortex-a": ["armnn-tflitedelegate"], + } diff --git a/src/mlia/cli/helpers.py b/src/mlia/cli/helpers.py index acec837..ac64581 100644 --- a/src/mlia/cli/helpers.py +++ b/src/mlia/cli/helpers.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Module for various helper classes.""" from __future__ import annotations @@ -29,9 +29,9 @@ class CLIActionResolver(ActionResolver): return [ *keras_note, - "For example: mlia optimization --optimization-type " - f"pruning,clustering --optimization-target 0.5,32 {model_path}", - "For more info: mlia optimization --help", + f"For example: mlia optimize {model_path} --pruning --clustering " + "--pruning-target 0.5 --clustering-target 32", + "For more info: mlia optimize --help", ] @staticmethod @@ -41,14 +41,17 @@ class CLIActionResolver(ActionResolver): opt_settings: list[OptimizationSettings], ) -> list[str]: """Return specific optimization command description.""" - opt_types = ",".join(opt.optimization_type for opt in opt_settings) - opt_targs = ",".join(str(opt.optimization_target) for opt in opt_settings) + opt_types = " ".join("--" + opt.optimization_type for opt in opt_settings) + opt_targs_strings = ["--pruning-target", "--clustering-target"] + opt_targs = ",".join( + f"{opt_targs_strings[i]} {opt.optimization_target}" + for i, opt in enumerate(opt_settings) + ) return [ - "For more info: mlia optimization --help", + "For more info: mlia optimize --help", "Optimization command: " - f"mlia optimization --optimization-type {opt_types} " - f"--optimization-target {opt_targs}{device_opts} {model_path}", + f"mlia optimize {model_path}{device_opts} {opt_types} {opt_targs}", ] def apply_optimizations(self, **kwargs: Any) -> list[str]: @@ -65,13 +68,6 @@ class CLIActionResolver(ActionResolver): return [] - def supported_operators_info(self) -> list[str]: - """Return command details for generating supported ops report.""" - return [ - "For guidance on supported operators, run: mlia operators " - "--supported-ops-report", - ] - def check_performance(self) -> list[str]: """Return command details for checking performance.""" model_path, device_opts = self._get_model_and_device_opts() @@ -80,7 +76,7 @@ class CLIActionResolver(ActionResolver): return [ "Check the estimated performance by running the following command: ", - f"mlia performance{device_opts} {model_path}", + f"mlia check {model_path}{device_opts} --performance", ] def check_operator_compatibility(self) -> list[str]: @@ -91,16 +87,16 @@ class CLIActionResolver(ActionResolver): return [ "Try running the following command to verify that:", - f"mlia operators{device_opts} {model_path}", + f"mlia check {model_path}{device_opts}", ] def operator_compatibility_details(self) -> list[str]: """Return command details for op compatibility.""" - return ["For more details, run: mlia operators --help"] + return ["For more details, run: mlia check --help"] def optimization_details(self) -> list[str]: """Return command details for optimization.""" - return ["For more info, see: mlia optimization --help"] + return ["For more info, see: mlia optimize --help"] def _get_model_and_device_opts( self, separate_device_opts: bool = True diff --git a/src/mlia/cli/main.py b/src/mlia/cli/main.py index ac60308..1102d45 100644 --- a/src/mlia/cli/main.py +++ b/src/mlia/cli/main.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """CLI main entry point.""" from __future__ import annotations @@ -8,32 +8,28 @@ import logging import sys from functools import partial from inspect import signature -from pathlib import Path from mlia import __version__ from mlia.backend.errors import BackendUnavailableError from mlia.backend.registry import registry as backend_registry -from mlia.cli.commands import all_tests from mlia.cli.commands import backend_install from mlia.cli.commands import backend_list from mlia.cli.commands import backend_uninstall -from mlia.cli.commands import operators -from mlia.cli.commands import optimization -from mlia.cli.commands import performance +from mlia.cli.commands import check +from mlia.cli.commands import optimize from mlia.cli.common import CommandInfo from mlia.cli.helpers import CLIActionResolver from mlia.cli.logging import setup_logging from mlia.cli.options import add_backend_install_options +from mlia.cli.options import add_backend_options from mlia.cli.options import add_backend_uninstall_options -from mlia.cli.options import add_custom_supported_operators_options +from mlia.cli.options import add_check_category_options from mlia.cli.options import add_debug_options -from mlia.cli.options import add_evaluation_options from mlia.cli.options import add_keras_model_options +from mlia.cli.options import add_model_options from mlia.cli.options import add_multi_optimization_options -from mlia.cli.options import add_optional_tflite_model_options from mlia.cli.options import add_output_options from mlia.cli.options import add_target_options -from mlia.cli.options import add_tflite_model_options from mlia.core.context import ExecutionContext from mlia.core.errors import ConfigurationError from mlia.core.errors import InternalError @@ -60,50 +56,30 @@ def get_commands() -> list[CommandInfo]: """Return commands configuration.""" return [ CommandInfo( - all_tests, - ["all"], - [ - add_target_options, - add_keras_model_options, - add_multi_optimization_options, - add_output_options, - add_debug_options, - add_evaluation_options, - ], - True, - ), - CommandInfo( - operators, - ["ops"], + check, + [], [ + add_model_options, add_target_options, - add_optional_tflite_model_options, + add_backend_options, + add_check_category_options, add_output_options, - add_custom_supported_operators_options, add_debug_options, ], ), CommandInfo( - performance, - ["perf"], - [ - partial(add_target_options, profiles_to_skip=["tosa", "cortex-a"]), - add_tflite_model_options, - add_output_options, - add_debug_options, - add_evaluation_options, - ], - ), - CommandInfo( - optimization, - ["opt"], + optimize, + [], [ - partial(add_target_options, profiles_to_skip=["tosa", "cortex-a"]), add_keras_model_options, + partial(add_target_options, profiles_to_skip=["tosa", "cortex-a"]), + partial( + add_backend_options, + backends_to_skip=["tosa-checker", "armnn-tflitedelegate"], + ), add_multi_optimization_options, add_output_options, add_debug_options, - add_evaluation_options, ], ), ] @@ -184,13 +160,12 @@ def setup_context( ) -> tuple[ExecutionContext, dict]: """Set up context and resolve function parameters.""" ctx = ExecutionContext( - working_dir=args.working_dir, - verbose="verbose" in args and args.verbose, + verbose="debug" in args and args.debug, action_resolver=CLIActionResolver(vars(args)), ) # these parameters should not be passed into command function - skipped_params = ["func", "command", "working_dir", "verbose"] + skipped_params = ["func", "command", "debug"] # pass these parameters only if command expects them expected_params = [context_var_name] @@ -219,6 +194,9 @@ def run_command(args: argparse.Namespace) -> int: try: logger.info(INFO_MESSAGE) + logger.info( + "\nThis execution of MLIA uses working directory: %s", ctx.working_dir + ) args.func(**func_args) return 0 except KeyboardInterrupt: @@ -246,22 +224,19 @@ def run_command(args: argparse.Namespace) -> int: f"Please check the log files in the {ctx.logs_path} for more details" ) if not ctx.verbose: - err_advice_message += ", or enable verbose mode (--verbose)" + err_advice_message += ", or enable debug mode (--debug)" logger.error(err_advice_message) - + finally: + logger.info( + "This execution of MLIA used working directory: %s", ctx.working_dir + ) return 1 def init_common_parser() -> argparse.ArgumentParser: """Init common parser.""" parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False) - parser.add_argument( - "--working-dir", - default=f"{Path.cwd() / 'mlia_output'}", - help="Path to the directory where MLIA will store logs, " - "models, etc. (default: %(default)s)", - ) return parser diff --git a/src/mlia/cli/options.py b/src/mlia/cli/options.py index 8ea4250..bae6219 100644 --- a/src/mlia/cli/options.py +++ b/src/mlia/cli/options.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Module for the CLI options.""" from __future__ import annotations @@ -8,37 +8,48 @@ from pathlib import Path from typing import Any from typing import Callable +from mlia.cli.config import DEFAULT_CLUSTERING_TARGET +from mlia.cli.config import DEFAULT_PRUNING_TARGET from mlia.cli.config import get_available_backends -from mlia.cli.config import get_default_backends from mlia.cli.config import is_corstone_backend -from mlia.core.reporting import OUTPUT_FORMATS +from mlia.core.common import FormattedFilePath from mlia.utils.filesystem import get_supported_profile_names -from mlia.utils.types import is_number + + +def add_check_category_options(parser: argparse.ArgumentParser) -> None: + """Add check category type options.""" + parser.add_argument( + "--performance", action="store_true", help="Perform performance checks." + ) + + parser.add_argument( + "--compatibility", + action="store_true", + help="Perform compatibility checks. (default)", + ) def add_target_options( - parser: argparse.ArgumentParser, profiles_to_skip: list[str] | None = None + parser: argparse.ArgumentParser, + profiles_to_skip: list[str] | None = None, + required: bool = True, ) -> None: """Add target specific options.""" target_profiles = get_supported_profile_names() if profiles_to_skip: target_profiles = [tp for tp in target_profiles if tp not in profiles_to_skip] - default_target_profile = None - default_help = "" - if target_profiles: - default_target_profile = target_profiles[0] - default_help = " (default: %(default)s)" - target_group = parser.add_argument_group("target options") target_group.add_argument( + "-t", "--target-profile", choices=target_profiles, - default=default_target_profile, + required=required, + default="", help="Target profile that will set the target options " "such as target, mac value, memory mode, etc. " - f"For the values associated with each target profile " - f" please refer to the documentation {default_help}.", + "For the values associated with each target profile " + "please refer to the documentation.", ) @@ -47,59 +58,47 @@ def add_multi_optimization_options(parser: argparse.ArgumentParser) -> None: multi_optimization_group = parser.add_argument_group("optimization options") multi_optimization_group.add_argument( - "--optimization-type", - default="pruning,clustering", - help="List of the optimization types separated by comma (default: %(default)s)", + "--pruning", action="store_true", help="Apply pruning optimization." ) + multi_optimization_group.add_argument( - "--optimization-target", - default="0.5,32", - help="""List of the optimization targets separated by comma, - (for pruning this is sparsity between (0,1), - for clustering this is the number of clusters (positive integer)) - (default: %(default)s)""", + "--clustering", action="store_true", help="Apply clustering optimization." ) + multi_optimization_group.add_argument( + "--pruning-target", + type=float, + help="Sparsity to be reached during optimization " + f"(default: {DEFAULT_PRUNING_TARGET})", + ) -def add_optional_tflite_model_options(parser: argparse.ArgumentParser) -> None: - """Add optional model specific options.""" - model_group = parser.add_argument_group("TensorFlow Lite model options") - # make model parameter optional - model_group.add_argument( - "model", nargs="?", help="TensorFlow Lite model (optional)" + multi_optimization_group.add_argument( + "--clustering-target", + type=int, + help="Number of clusters to reach during optimization " + f"(default: {DEFAULT_CLUSTERING_TARGET})", ) -def add_tflite_model_options(parser: argparse.ArgumentParser) -> None: +def add_model_options(parser: argparse.ArgumentParser) -> None: """Add model specific options.""" - model_group = parser.add_argument_group("TensorFlow Lite model options") - model_group.add_argument("model", help="TensorFlow Lite model") + parser.add_argument("model", help="TensorFlow Lite model or Keras model") def add_output_options(parser: argparse.ArgumentParser) -> None: """Add output specific options.""" - valid_extensions = OUTPUT_FORMATS - - def check_extension(filename: str) -> str: - """Check extension of the provided file.""" - suffix = Path(filename).suffix - if suffix.startswith("."): - suffix = suffix[1:] - - if suffix.lower() not in valid_extensions: - parser.error(f"Unsupported format '{suffix}'") - - return filename - output_group = parser.add_argument_group("output options") output_group.add_argument( + "-o", "--output", - type=check_extension, - help=( - "Name of the file where report will be saved. " - "Report format is automatically detected based on the file extension. " - f"Supported formats are: {', '.join(valid_extensions)}" - ), + type=Path, + help=("Name of the file where the report will be saved."), + ) + + output_group.add_argument( + "--json", + action="store_true", + help=("Format to use for the output (requires --output argument to be set)."), ) @@ -107,7 +106,11 @@ def add_debug_options(parser: argparse.ArgumentParser) -> None: """Add debug options.""" debug_group = parser.add_argument_group("debug options") debug_group.add_argument( - "--verbose", default=False, action="store_true", help="Produce verbose output" + "-d", + "--debug", + default=False, + action="store_true", + help="Produce verbose output", ) @@ -117,20 +120,6 @@ def add_keras_model_options(parser: argparse.ArgumentParser) -> None: model_group.add_argument("model", help="Keras model") -def add_custom_supported_operators_options(parser: argparse.ArgumentParser) -> None: - """Add custom options for the command 'operators'.""" - parser.add_argument( - "--supported-ops-report", - action="store_true", - default=False, - help=( - "Generate the SUPPORTED_OPS.md file in the " - "current working directory and exit " - "(Ethos-U target profiles only)" - ), - ) - - def add_backend_install_options(parser: argparse.ArgumentParser) -> None: """Add options for the backends configuration.""" @@ -176,10 +165,11 @@ def add_backend_uninstall_options(parser: argparse.ArgumentParser) -> None: ) -def add_evaluation_options(parser: argparse.ArgumentParser) -> None: +def add_backend_options( + parser: argparse.ArgumentParser, backends_to_skip: list[str] | None = None +) -> None: """Add evaluation options.""" available_backends = get_available_backends() - default_backends = get_default_backends() def only_one_corstone_checker() -> Callable: """ @@ -202,41 +192,70 @@ def add_evaluation_options(parser: argparse.ArgumentParser) -> None: return check - evaluation_group = parser.add_argument_group("evaluation options") + # Remove backends to skip + if backends_to_skip: + available_backends = [ + x for x in available_backends if x not in backends_to_skip + ] + + evaluation_group = parser.add_argument_group("backend options") evaluation_group.add_argument( - "--evaluate-on", - help="Backends to use for evaluation (default: %(default)s)", - nargs="*", + "-b", + "--backend", + help="Backends to use for evaluation.", + nargs="+", choices=available_backends, - default=default_backends, type=only_one_corstone_checker(), ) +def parse_output_parameters(path: Path | None, json: bool) -> FormattedFilePath | None: + """Parse and return path and file format as FormattedFilePath.""" + if not path and json: + raise argparse.ArgumentError( + None, + "To enable JSON output you need to specify the output path. " + "(e.g. --output out.json --json)", + ) + if not path: + return None + if json: + return FormattedFilePath(path, "json") + + return FormattedFilePath(path, "plain_text") + + def parse_optimization_parameters( - optimization_type: str, - optimization_target: str, - sep: str = ",", + pruning: bool = False, + clustering: bool = False, + pruning_target: float | None = None, + clustering_target: int | None = None, layers_to_optimize: list[str] | None = None, ) -> list[dict[str, Any]]: """Parse provided optimization parameters.""" - if not optimization_type: - raise Exception("Optimization type is not provided") + opt_types = [] + opt_targets = [] - if not optimization_target: - raise Exception("Optimization target is not provided") + if clustering_target and not clustering: + raise argparse.ArgumentError( + None, + "To enable clustering optimization you need to include the " + "`--clustering` flag in your command.", + ) - opt_types = optimization_type.split(sep) - opt_targets = optimization_target.split(sep) + if not pruning_target: + pruning_target = DEFAULT_PRUNING_TARGET - if len(opt_types) != len(opt_targets): - raise Exception("Wrong number of optimization targets and types") + if not clustering_target: + clustering_target = DEFAULT_CLUSTERING_TARGET - non_numeric_targets = [ - opt_target for opt_target in opt_targets if not is_number(opt_target) - ] - if len(non_numeric_targets) > 0: - raise Exception("Non numeric value for the optimization target") + if (pruning is False and clustering is False) or pruning: + opt_types.append("pruning") + opt_targets.append(pruning_target) + + if clustering: + opt_types.append("clustering") + opt_targets.append(clustering_target) optimizer_params = [ { @@ -256,7 +275,7 @@ def get_target_profile_opts(device_args: dict | None) -> list[str]: return [] parser = argparse.ArgumentParser() - add_target_options(parser) + add_target_options(parser, required=False) args = parser.parse_args([]) params_name = { diff --git a/src/mlia/core/common.py b/src/mlia/core/common.py index 6c9dde1..53df001 100644 --- a/src/mlia/core/common.py +++ b/src/mlia/core/common.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Common module. @@ -13,6 +13,9 @@ from enum import auto from enum import Flag from typing import Any +from mlia.core.typing import OutputFormat +from mlia.core.typing import PathOrFileLike + # This type is used as type alias for the items which are being passed around # in advisor workflow. There are no restrictions on the type of the # object. This alias used only to emphasize the nature of the input/output @@ -20,31 +23,55 @@ from typing import Any DataItem = Any +class FormattedFilePath: + """Class used to keep track of the format that a path points to.""" + + def __init__(self, path: PathOrFileLike, fmt: OutputFormat = "plain_text") -> None: + """Init FormattedFilePath.""" + self._path = path + self._fmt = fmt + + @property + def fmt(self) -> OutputFormat: + """Return file format.""" + return self._fmt + + @property + def path(self) -> PathOrFileLike: + """Return file path.""" + return self._path + + def __eq__(self, other: object) -> bool: + """Check for equality with other objects.""" + if isinstance(other, FormattedFilePath): + return other.fmt == self.fmt and other.path == self.path + + return False + + def __repr__(self) -> str: + """Represent object.""" + return f"FormattedFilePath {self.path=}, {self.fmt=}" + + class AdviceCategory(Flag): """Advice category. Enumeration of advice categories supported by ML Inference Advisor. """ - OPERATORS = auto() + COMPATIBILITY = auto() PERFORMANCE = auto() OPTIMIZATION = auto() - ALL = ( - # pylint: disable=unsupported-binary-operation - OPERATORS - | PERFORMANCE - | OPTIMIZATION - # pylint: enable=unsupported-binary-operation - ) @classmethod - def from_string(cls, value: str) -> AdviceCategory: + def from_string(cls, values: set[str]) -> set[AdviceCategory]: """Resolve enum value from string value.""" category_names = [item.name for item in AdviceCategory] - if not value or value.upper() not in category_names: - raise Exception(f"Invalid advice category {value}") + for advice_value in values: + if advice_value.upper() not in category_names: + raise Exception(f"Invalid advice category {advice_value}") - return AdviceCategory[value.upper()] + return {AdviceCategory[value.upper()] for value in values} class NamedEntity(ABC): diff --git a/src/mlia/core/context.py b/src/mlia/core/context.py index a4737bb..94aa885 100644 --- a/src/mlia/core/context.py +++ b/src/mlia/core/context.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Context module. @@ -10,6 +10,7 @@ parameters). from __future__ import annotations import logging +import tempfile from abc import ABC from abc import abstractmethod from pathlib import Path @@ -54,7 +55,7 @@ class Context(ABC): @property @abstractmethod - def advice_category(self) -> AdviceCategory: + def advice_category(self) -> set[AdviceCategory]: """Return advice category.""" @property @@ -71,7 +72,7 @@ class Context(ABC): def update( self, *, - advice_category: AdviceCategory, + advice_category: set[AdviceCategory], event_handlers: list[EventHandler], config_parameters: Mapping[str, Any], ) -> None: @@ -79,11 +80,11 @@ class Context(ABC): def category_enabled(self, category: AdviceCategory) -> bool: """Check if category enabled.""" - return category == self.advice_category + return category in self.advice_category def any_category_enabled(self, *categories: AdviceCategory) -> bool: """Return true if any category is enabled.""" - return self.advice_category in categories + return all(category in self.advice_category for category in categories) def register_event_handlers(self) -> None: """Register event handlers.""" @@ -96,7 +97,7 @@ class ExecutionContext(Context): def __init__( self, *, - advice_category: AdviceCategory = AdviceCategory.ALL, + advice_category: set[AdviceCategory] = None, config_parameters: Mapping[str, Any] | None = None, working_dir: str | Path | None = None, event_handlers: list[EventHandler] | None = None, @@ -108,7 +109,7 @@ class ExecutionContext(Context): ) -> None: """Init execution context. - :param advice_category: requested advice category + :param advice_category: requested advice categories :param config_parameters: dictionary like object with input parameters :param working_dir: path to the directory that will be used as a place to store temporary files, logs, models. If not provided then @@ -124,13 +125,13 @@ class ExecutionContext(Context): :param action_resolver: instance of the action resolver that could make advice actionable """ - self._advice_category = advice_category + self._advice_category = advice_category or {AdviceCategory.COMPATIBILITY} self._config_parameters = config_parameters - self._working_dir_path = Path.cwd() if working_dir: self._working_dir_path = Path(working_dir) - self._working_dir_path.mkdir(exist_ok=True) + else: + self._working_dir_path = generate_temp_workdir() self._event_handlers = event_handlers self._event_publisher = event_publisher or DefaultEventPublisher() @@ -140,12 +141,17 @@ class ExecutionContext(Context): self._action_resolver = action_resolver or APIActionResolver() @property - def advice_category(self) -> AdviceCategory: + def working_dir(self) -> Path: + """Return working dir path.""" + return self._working_dir_path + + @property + def advice_category(self) -> set[AdviceCategory]: """Return advice category.""" return self._advice_category @advice_category.setter - def advice_category(self, advice_category: AdviceCategory) -> None: + def advice_category(self, advice_category: set[AdviceCategory]) -> None: """Setter for the advice category.""" self._advice_category = advice_category @@ -194,7 +200,7 @@ class ExecutionContext(Context): def update( self, *, - advice_category: AdviceCategory, + advice_category: set[AdviceCategory], event_handlers: list[EventHandler], config_parameters: Mapping[str, Any], ) -> None: @@ -206,7 +212,9 @@ class ExecutionContext(Context): def __str__(self) -> str: """Return string representation.""" category = ( - "" if self.advice_category is None else self.advice_category.name + "" + if self.advice_category is None + else {x.name for x in self.advice_category} ) return ( @@ -215,3 +223,9 @@ class ExecutionContext(Context): f"config_parameters={self.config_parameters}, " f"verbose={self.verbose}" ) + + +def generate_temp_workdir() -> Path: + """Generate a temporary working dir and returns the path.""" + working_dir = tempfile.mkdtemp(suffix=None, prefix="mlia-", dir=None) + return Path(working_dir) diff --git a/src/mlia/core/handlers.py b/src/mlia/core/handlers.py index a3255ae..6e50934 100644 --- a/src/mlia/core/handlers.py +++ b/src/mlia/core/handlers.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Event handlers module.""" from __future__ import annotations @@ -9,6 +9,7 @@ from typing import Callable from mlia.core.advice_generation import Advice from mlia.core.advice_generation import AdviceEvent +from mlia.core.common import FormattedFilePath from mlia.core.events import ActionFinishedEvent from mlia.core.events import ActionStartedEvent from mlia.core.events import AdviceStageFinishedEvent @@ -26,7 +27,6 @@ from mlia.core.events import ExecutionFinishedEvent from mlia.core.events import ExecutionStartedEvent from mlia.core.reporting import Report from mlia.core.reporting import Reporter -from mlia.core.reporting import resolve_output_format from mlia.core.typing import PathOrFileLike from mlia.utils.console import create_section_header @@ -101,12 +101,12 @@ class WorkflowEventsHandler(SystemEventsHandler): def __init__( self, formatter_resolver: Callable[[Any], Callable[[Any], Report]], - output: PathOrFileLike | None = None, + output: FormattedFilePath | None = None, ) -> None: """Init event handler.""" - output_format = resolve_output_format(output) + output_format = output.fmt if output else "plain_text" self.reporter = Reporter(formatter_resolver, output_format) - self.output = output + self.output = output.path if output else None self.advice: list[Advice] = [] diff --git a/src/mlia/core/helpers.py b/src/mlia/core/helpers.py index f4a9df6..ed43d04 100644 --- a/src/mlia/core/helpers.py +++ b/src/mlia/core/helpers.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Module for various helper classes.""" # pylint: disable=unused-argument @@ -14,10 +14,6 @@ class ActionResolver: """Return action details for applying optimizations.""" return [] - def supported_operators_info(self) -> list[str]: - """Return action details for generating supported ops report.""" - return [] - def check_performance(self) -> list[str]: """Return action details for checking performance.""" return [] diff --git a/src/mlia/core/reporting.py b/src/mlia/core/reporting.py index b96a6b5..19644b2 100644 --- a/src/mlia/core/reporting.py +++ b/src/mlia/core/reporting.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Reporting module.""" from __future__ import annotations @@ -639,14 +639,3 @@ def _apply_format_parameters( return report return wrapper - - -def resolve_output_format(output: PathOrFileLike | None) -> OutputFormat: - """Resolve output format based on the output name.""" - if isinstance(output, (str, Path)): - format_from_filename = Path(output).suffix.lstrip(".") - - if format_from_filename in OUTPUT_FORMATS: - return cast(OutputFormat, format_from_filename) - - return "plain_text" diff --git a/src/mlia/target/cortex_a/advice_generation.py b/src/mlia/target/cortex_a/advice_generation.py index b68106e..98e8c06 100644 --- a/src/mlia/target/cortex_a/advice_generation.py +++ b/src/mlia/target/cortex_a/advice_generation.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Cortex-A advice generation.""" from functools import singledispatchmethod @@ -29,7 +29,7 @@ class CortexAAdviceProducer(FactBasedAdviceProducer): """Produce advice.""" @produce_advice.register - @advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS) + @advice_category(AdviceCategory.COMPATIBILITY) def handle_model_is_cortex_a_compatible( self, data_item: ModelIsCortexACompatible ) -> None: @@ -43,7 +43,7 @@ class CortexAAdviceProducer(FactBasedAdviceProducer): ) @produce_advice.register - @advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS) + @advice_category(AdviceCategory.COMPATIBILITY) def handle_model_is_not_cortex_a_compatible( self, data_item: ModelIsNotCortexACompatible ) -> None: @@ -83,7 +83,7 @@ class CortexAAdviceProducer(FactBasedAdviceProducer): ) @produce_advice.register - @advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS) + @advice_category(AdviceCategory.COMPATIBILITY) def handle_model_is_not_tflite_compatible( self, data_item: ModelIsNotTFLiteCompatible ) -> None: @@ -127,7 +127,7 @@ class CortexAAdviceProducer(FactBasedAdviceProducer): ) @produce_advice.register - @advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS) + @advice_category(AdviceCategory.COMPATIBILITY) def handle_tflite_check_failed( self, _data_item: TFLiteCompatibilityCheckFailed ) -> None: @@ -140,7 +140,7 @@ class CortexAAdviceProducer(FactBasedAdviceProducer): ) @produce_advice.register - @advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS) + @advice_category(AdviceCategory.COMPATIBILITY) def handle_model_has_custom_operators( self, _data_item: ModelHasCustomOperators ) -> None: diff --git a/src/mlia/target/cortex_a/advisor.py b/src/mlia/target/cortex_a/advisor.py index 5912e38..b649f0d 100644 --- a/src/mlia/target/cortex_a/advisor.py +++ b/src/mlia/target/cortex_a/advisor.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Cortex-A MLIA module.""" from __future__ import annotations @@ -10,12 +10,12 @@ from mlia.core.advice_generation import AdviceProducer from mlia.core.advisor import DefaultInferenceAdvisor from mlia.core.advisor import InferenceAdvisor from mlia.core.common import AdviceCategory +from mlia.core.common import FormattedFilePath from mlia.core.context import Context from mlia.core.context import ExecutionContext from mlia.core.data_analysis import DataAnalyzer from mlia.core.data_collection import DataCollector from mlia.core.events import Event -from mlia.core.typing import PathOrFileLike from mlia.target.cortex_a.advice_generation import CortexAAdviceProducer from mlia.target.cortex_a.config import CortexAConfiguration from mlia.target.cortex_a.data_analysis import CortexADataAnalyzer @@ -38,7 +38,7 @@ class CortexAInferenceAdvisor(DefaultInferenceAdvisor): collectors: list[DataCollector] = [] - if AdviceCategory.OPERATORS in context.advice_category: + if context.category_enabled(AdviceCategory.COMPATIBILITY): collectors.append(CortexAOperatorCompatibility(model)) return collectors @@ -67,7 +67,7 @@ def configure_and_get_cortexa_advisor( context: ExecutionContext, target_profile: str, model: str | Path, - output: PathOrFileLike | None = None, + output: FormattedFilePath | None = None, **_extra_args: Any, ) -> InferenceAdvisor: """Create and configure Cortex-A advisor.""" diff --git a/src/mlia/target/cortex_a/handlers.py b/src/mlia/target/cortex_a/handlers.py index b2d5faa..d6acde5 100644 --- a/src/mlia/target/cortex_a/handlers.py +++ b/src/mlia/target/cortex_a/handlers.py @@ -1,13 +1,13 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Event handler.""" from __future__ import annotations import logging +from mlia.core.common import FormattedFilePath from mlia.core.events import CollectedDataEvent from mlia.core.handlers import WorkflowEventsHandler -from mlia.core.typing import PathOrFileLike from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo from mlia.target.cortex_a.events import CortexAAdvisorEventHandler from mlia.target.cortex_a.events import CortexAAdvisorStartedEvent @@ -20,7 +20,7 @@ logger = logging.getLogger(__name__) class CortexAEventHandler(WorkflowEventsHandler, CortexAAdvisorEventHandler): """CLI event handler.""" - def __init__(self, output: PathOrFileLike | None = None) -> None: + def __init__(self, output: FormattedFilePath | None = None) -> None: """Init event handler.""" super().__init__(cortex_a_formatters, output) diff --git a/src/mlia/target/ethos_u/advice_generation.py b/src/mlia/target/ethos_u/advice_generation.py index edd78fd..daae4f4 100644 --- a/src/mlia/target/ethos_u/advice_generation.py +++ b/src/mlia/target/ethos_u/advice_generation.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Ethos-U advice generation.""" from __future__ import annotations @@ -26,7 +26,7 @@ class EthosUAdviceProducer(FactBasedAdviceProducer): """Produce advice.""" @produce_advice.register - @advice_category(AdviceCategory.OPERATORS, AdviceCategory.ALL) + @advice_category(AdviceCategory.COMPATIBILITY) def handle_cpu_only_ops(self, data_item: HasCPUOnlyOperators) -> None: """Advice for CPU only operators.""" cpu_only_ops = ",".join(sorted(set(data_item.cpu_only_ops))) @@ -40,11 +40,10 @@ class EthosUAdviceProducer(FactBasedAdviceProducer): "Using operators that are supported by the NPU will " "improve performance.", ] - + self.context.action_resolver.supported_operators_info() ) @produce_advice.register - @advice_category(AdviceCategory.OPERATORS, AdviceCategory.ALL) + @advice_category(AdviceCategory.COMPATIBILITY) def handle_unsupported_operators( self, data_item: HasUnsupportedOnNPUOperators ) -> None: @@ -60,21 +59,25 @@ class EthosUAdviceProducer(FactBasedAdviceProducer): ) @produce_advice.register - @advice_category(AdviceCategory.OPERATORS, AdviceCategory.ALL) + @advice_category(AdviceCategory.COMPATIBILITY) def handle_all_operators_supported( self, _data_item: AllOperatorsSupportedOnNPU ) -> None: """Advice if all operators supported.""" - self.add_advice( - [ - "You don't have any unsupported operators, your model will " - "run completely on NPU." - ] - + self.context.action_resolver.check_performance() - ) + advice = [ + "You don't have any unsupported operators, your model will " + "run completely on NPU." + ] + if self.context.advice_category != ( + AdviceCategory.COMPATIBILITY, + AdviceCategory.PERFORMANCE, + ): + advice += self.context.action_resolver.check_performance() + + self.add_advice(advice) @produce_advice.register - @advice_category(AdviceCategory.OPTIMIZATION, AdviceCategory.ALL) + @advice_category(AdviceCategory.OPTIMIZATION) def handle_optimization_results(self, data_item: OptimizationResults) -> None: """Advice based on optimization results.""" if not data_item.diffs or len(data_item.diffs) != 1: @@ -202,5 +205,6 @@ class EthosUStaticAdviceProducer(ContextAwareAdviceProducer): ) ], } - - return advice_per_category.get(self.context.advice_category, []) + if len(self.context.advice_category) == 1: + return advice_per_category.get(list(self.context.advice_category)[0], []) + return [] diff --git a/src/mlia/target/ethos_u/advisor.py b/src/mlia/target/ethos_u/advisor.py index b9d64ff..640c3e1 100644 --- a/src/mlia/target/ethos_u/advisor.py +++ b/src/mlia/target/ethos_u/advisor.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Ethos-U MLIA module.""" from __future__ import annotations @@ -10,12 +10,12 @@ from mlia.core.advice_generation import AdviceProducer from mlia.core.advisor import DefaultInferenceAdvisor from mlia.core.advisor import InferenceAdvisor from mlia.core.common import AdviceCategory +from mlia.core.common import FormattedFilePath from mlia.core.context import Context from mlia.core.context import ExecutionContext from mlia.core.data_analysis import DataAnalyzer from mlia.core.data_collection import DataCollector from mlia.core.events import Event -from mlia.core.typing import PathOrFileLike from mlia.nn.tensorflow.utils import is_tflite_model from mlia.target.ethos_u.advice_generation import EthosUAdviceProducer from mlia.target.ethos_u.advice_generation import EthosUStaticAdviceProducer @@ -46,7 +46,7 @@ class EthosUInferenceAdvisor(DefaultInferenceAdvisor): collectors: list[DataCollector] = [] - if AdviceCategory.OPERATORS in context.advice_category: + if context.category_enabled(AdviceCategory.COMPATIBILITY): collectors.append(EthosUOperatorCompatibility(model, device)) # Performance and optimization are mutually exclusive. @@ -57,18 +57,18 @@ class EthosUInferenceAdvisor(DefaultInferenceAdvisor): raise Exception( "Command 'optimization' is not supported for TensorFlow Lite files." ) - if AdviceCategory.PERFORMANCE in context.advice_category: + if context.category_enabled(AdviceCategory.PERFORMANCE): collectors.append(EthosUPerformance(model, device, backends)) else: # Keras/SavedModel: Prefer optimization - if AdviceCategory.OPTIMIZATION in context.advice_category: + if context.category_enabled(AdviceCategory.OPTIMIZATION): optimization_settings = self._get_optimization_settings(context) collectors.append( EthosUOptimizationPerformance( model, device, optimization_settings, backends ) ) - elif AdviceCategory.PERFORMANCE in context.advice_category: + elif context.category_enabled(AdviceCategory.PERFORMANCE): collectors.append(EthosUPerformance(model, device, backends)) return collectors @@ -126,7 +126,7 @@ def configure_and_get_ethosu_advisor( context: ExecutionContext, target_profile: str, model: str | Path, - output: PathOrFileLike | None = None, + output: FormattedFilePath | None = None, **extra_args: Any, ) -> InferenceAdvisor: """Create and configure Ethos-U advisor.""" diff --git a/src/mlia/target/ethos_u/handlers.py b/src/mlia/target/ethos_u/handlers.py index 84a9554..91f6015 100644 --- a/src/mlia/target/ethos_u/handlers.py +++ b/src/mlia/target/ethos_u/handlers.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Event handler.""" from __future__ import annotations @@ -6,9 +6,9 @@ from __future__ import annotations import logging from mlia.backend.vela.compat import Operators +from mlia.core.common import FormattedFilePath from mlia.core.events import CollectedDataEvent from mlia.core.handlers import WorkflowEventsHandler -from mlia.core.typing import PathOrFileLike from mlia.target.ethos_u.events import EthosUAdvisorEventHandler from mlia.target.ethos_u.events import EthosUAdvisorStartedEvent from mlia.target.ethos_u.performance import OptimizationPerformanceMetrics @@ -21,7 +21,7 @@ logger = logging.getLogger(__name__) class EthosUEventHandler(WorkflowEventsHandler, EthosUAdvisorEventHandler): """CLI event handler.""" - def __init__(self, output: PathOrFileLike | None = None) -> None: + def __init__(self, output: FormattedFilePath | None = None) -> None: """Init event handler.""" super().__init__(ethos_u_formatters, output) diff --git a/src/mlia/target/tosa/advice_generation.py b/src/mlia/target/tosa/advice_generation.py index f531b84..b8b9abf 100644 --- a/src/mlia/target/tosa/advice_generation.py +++ b/src/mlia/target/tosa/advice_generation.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """TOSA advice generation.""" from functools import singledispatchmethod @@ -19,7 +19,7 @@ class TOSAAdviceProducer(FactBasedAdviceProducer): """Produce advice.""" @produce_advice.register - @advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS) + @advice_category(AdviceCategory.COMPATIBILITY) def handle_model_is_tosa_compatible( self, _data_item: ModelIsTOSACompatible ) -> None: @@ -27,7 +27,7 @@ class TOSAAdviceProducer(FactBasedAdviceProducer): self.add_advice(["Model is fully TOSA compatible."]) @produce_advice.register - @advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS) + @advice_category(AdviceCategory.COMPATIBILITY) def handle_model_is_not_tosa_compatible( self, _data_item: ModelIsNotTOSACompatible ) -> None: diff --git a/src/mlia/target/tosa/advisor.py b/src/mlia/target/tosa/advisor.py index 2739dfd..4851113 100644 --- a/src/mlia/target/tosa/advisor.py +++ b/src/mlia/target/tosa/advisor.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """TOSA advisor.""" from __future__ import annotations @@ -10,12 +10,12 @@ from mlia.core.advice_generation import AdviceCategory from mlia.core.advice_generation import AdviceProducer from mlia.core.advisor import DefaultInferenceAdvisor from mlia.core.advisor import InferenceAdvisor +from mlia.core.common import FormattedFilePath from mlia.core.context import Context from mlia.core.context import ExecutionContext from mlia.core.data_analysis import DataAnalyzer from mlia.core.data_collection import DataCollector from mlia.core.events import Event -from mlia.core.typing import PathOrFileLike from mlia.target.tosa.advice_generation import TOSAAdviceProducer from mlia.target.tosa.config import TOSAConfiguration from mlia.target.tosa.data_analysis import TOSADataAnalyzer @@ -38,7 +38,7 @@ class TOSAInferenceAdvisor(DefaultInferenceAdvisor): collectors: list[DataCollector] = [] - if AdviceCategory.OPERATORS in context.advice_category: + if context.category_enabled(AdviceCategory.COMPATIBILITY): collectors.append(TOSAOperatorCompatibility(model)) return collectors @@ -69,7 +69,7 @@ def configure_and_get_tosa_advisor( context: ExecutionContext, target_profile: str, model: str | Path, - output: PathOrFileLike | None = None, + output: FormattedFilePath | None = None, **_extra_args: Any, ) -> InferenceAdvisor: """Create and configure TOSA advisor.""" diff --git a/src/mlia/target/tosa/handlers.py b/src/mlia/target/tosa/handlers.py index 863558c..1037ba1 100644 --- a/src/mlia/target/tosa/handlers.py +++ b/src/mlia/target/tosa/handlers.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """TOSA Advisor event handlers.""" # pylint: disable=R0801 @@ -7,9 +7,9 @@ from __future__ import annotations import logging from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo +from mlia.core.common import FormattedFilePath from mlia.core.events import CollectedDataEvent from mlia.core.handlers import WorkflowEventsHandler -from mlia.core.typing import PathOrFileLike from mlia.target.tosa.events import TOSAAdvisorEventHandler from mlia.target.tosa.events import TOSAAdvisorStartedEvent from mlia.target.tosa.reporters import tosa_formatters @@ -20,7 +20,7 @@ logger = logging.getLogger(__name__) class TOSAEventHandler(WorkflowEventsHandler, TOSAAdvisorEventHandler): """Event handler for TOSA advisor.""" - def __init__(self, output: PathOrFileLike | None = None) -> None: + def __init__(self, output: FormattedFilePath | None = None) -> None: """Init event handler.""" super().__init__(tosa_formatters, output) diff --git a/src/mlia/utils/types.py b/src/mlia/utils/types.py index ea067b8..0769968 100644 --- a/src/mlia/utils/types.py +++ b/src/mlia/utils/types.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Types related utility functions.""" from __future__ import annotations @@ -19,7 +19,7 @@ def is_number(value: str) -> bool: """Return true if string contains a number.""" try: float(value) - except ValueError: + except (ValueError, TypeError): return False return True diff --git a/tests/test_api.py b/tests/test_api.py index fbc558b..0bbc3ae 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,15 +1,13 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the API functions.""" from __future__ import annotations from pathlib import Path from unittest.mock import MagicMock -from unittest.mock import patch import pytest -from mlia.api import generate_supported_operators_report from mlia.api import get_advice from mlia.api import get_advisor from mlia.core.common import AdviceCategory @@ -22,63 +20,68 @@ from mlia.target.tosa.advisor import TOSAInferenceAdvisor def test_get_advice_no_target_provided(test_keras_model: Path) -> None: """Test getting advice when no target provided.""" with pytest.raises(Exception, match="Target profile is not provided"): - get_advice(None, test_keras_model, "all") # type: ignore + get_advice(None, test_keras_model, {"compatibility"}) # type: ignore def test_get_advice_wrong_category(test_keras_model: Path) -> None: """Test getting advice when wrong advice category provided.""" with pytest.raises(Exception, match="Invalid advice category unknown"): - get_advice("ethos-u55-256", test_keras_model, "unknown") # type: ignore + get_advice("ethos-u55-256", test_keras_model, {"unknown"}) @pytest.mark.parametrize( "category, context, expected_category", [ [ - "all", + {"compatibility", "optimization"}, None, - AdviceCategory.ALL, + {AdviceCategory.COMPATIBILITY, AdviceCategory.OPTIMIZATION}, ], [ - "optimization", + {"optimization"}, None, - AdviceCategory.OPTIMIZATION, + {AdviceCategory.OPTIMIZATION}, ], [ - "operators", + {"compatibility"}, None, - AdviceCategory.OPERATORS, + {AdviceCategory.COMPATIBILITY}, ], [ - "performance", + {"performance"}, None, - AdviceCategory.PERFORMANCE, + {AdviceCategory.PERFORMANCE}, ], [ - "all", + {"compatibility", "optimization"}, ExecutionContext(), - AdviceCategory.ALL, + {AdviceCategory.COMPATIBILITY, AdviceCategory.OPTIMIZATION}, ], [ - "all", - ExecutionContext(advice_category=AdviceCategory.PERFORMANCE), - AdviceCategory.ALL, + {"compatibility", "optimization"}, + ExecutionContext( + advice_category={ + AdviceCategory.COMPATIBILITY, + AdviceCategory.OPTIMIZATION, + } + ), + {AdviceCategory.COMPATIBILITY, AdviceCategory.OPTIMIZATION}, ], [ - "all", + {"compatibility", "optimization"}, ExecutionContext(config_parameters={"param": "value"}), - AdviceCategory.ALL, + {AdviceCategory.COMPATIBILITY, AdviceCategory.OPTIMIZATION}, ], [ - "all", + {"compatibility", "optimization"}, ExecutionContext(event_handlers=[MagicMock()]), - AdviceCategory.ALL, + {AdviceCategory.COMPATIBILITY, AdviceCategory.OPTIMIZATION}, ], ], ) def test_get_advice( monkeypatch: pytest.MonkeyPatch, - category: str, + category: set[str], context: ExecutionContext, expected_category: AdviceCategory, test_keras_model: Path, @@ -90,7 +93,7 @@ def test_get_advice( get_advice( "ethos-u55-256", test_keras_model, - category, # type: ignore + category, context=context, ) @@ -111,50 +114,3 @@ def test_get_advisor( tosa_advisor = get_advisor(ExecutionContext(), "tosa", str(test_keras_model)) assert isinstance(tosa_advisor, TOSAInferenceAdvisor) - - -@pytest.mark.parametrize( - ["target_profile", "required_calls", "exception_msg"], - [ - [ - "ethos-u55-128", - "mlia.target.ethos_u.operators.generate_supported_operators_report", - None, - ], - [ - "ethos-u65-256", - "mlia.target.ethos_u.operators.generate_supported_operators_report", - None, - ], - [ - "tosa", - None, - "Generating a supported operators report is not " - "currently supported with TOSA target profile.", - ], - [ - "cortex-a", - None, - "Generating a supported operators report is not " - "currently supported with Cortex-A target profile.", - ], - [ - "Unknown", - None, - "Unable to find target profile Unknown", - ], - ], -) -def test_supported_ops_report_generator( - target_profile: str, required_calls: str | None, exception_msg: str | None -) -> None: - """Test supported operators report generator with different target profiles.""" - if exception_msg: - with pytest.raises(Exception) as exc: - generate_supported_operators_report(target_profile) - assert str(exc.value) == exception_msg - - if required_calls: - with patch(required_calls) as mock_method: - generate_supported_operators_report(target_profile) - mock_method.assert_called_once() diff --git a/tests/test_backend_config.py b/tests/test_backend_config.py index bd50945..700534f 100644 --- a/tests/test_backend_config.py +++ b/tests/test_backend_config.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the backend config module.""" from mlia.backend.config import BackendConfiguration @@ -20,14 +20,14 @@ def test_system() -> None: def test_backend_config() -> None: """Test the class 'BackendConfiguration'.""" cfg = BackendConfiguration( - [AdviceCategory.OPERATORS], [System.CURRENT], BackendType.CUSTOM + [AdviceCategory.COMPATIBILITY], [System.CURRENT], BackendType.CUSTOM ) - assert cfg.supported_advice == [AdviceCategory.OPERATORS] + assert cfg.supported_advice == [AdviceCategory.COMPATIBILITY] assert cfg.supported_systems == [System.CURRENT] assert cfg.type == BackendType.CUSTOM assert str(cfg) assert cfg.is_supported() - assert cfg.is_supported(advice=AdviceCategory.OPERATORS) + assert cfg.is_supported(advice=AdviceCategory.COMPATIBILITY) assert not cfg.is_supported(advice=AdviceCategory.PERFORMANCE) assert cfg.is_supported(check_system=True) assert cfg.is_supported(check_system=False) @@ -37,6 +37,6 @@ def test_backend_config() -> None: cfg.supported_systems = [UNSUPPORTED_SYSTEM] assert not cfg.is_supported(check_system=True) assert cfg.is_supported(check_system=False) - assert not cfg.is_supported(advice=AdviceCategory.OPERATORS, check_system=True) - assert cfg.is_supported(advice=AdviceCategory.OPERATORS, check_system=False) + assert not cfg.is_supported(advice=AdviceCategory.COMPATIBILITY, check_system=True) + assert cfg.is_supported(advice=AdviceCategory.COMPATIBILITY, check_system=False) assert not cfg.is_supported(advice=AdviceCategory.PERFORMANCE, check_system=False) diff --git a/tests/test_backend_registry.py b/tests/test_backend_registry.py index 31a20a0..703e699 100644 --- a/tests/test_backend_registry.py +++ b/tests/test_backend_registry.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the backend registry module.""" from __future__ import annotations @@ -18,7 +18,7 @@ from mlia.core.common import AdviceCategory ( ( "ArmNNTFLiteDelegate", - [AdviceCategory.OPERATORS], + [AdviceCategory.COMPATIBILITY], None, BackendType.BUILTIN, ), @@ -36,14 +36,14 @@ from mlia.core.common import AdviceCategory ), ( "TOSA-Checker", - [AdviceCategory.OPERATORS], + [AdviceCategory.COMPATIBILITY], [System.LINUX_AMD64], BackendType.WHEEL, ), ( "Vela", [ - AdviceCategory.OPERATORS, + AdviceCategory.COMPATIBILITY, AdviceCategory.PERFORMANCE, AdviceCategory.OPTIMIZATION, ], diff --git a/tests/test_cli_command_validators.py b/tests/test_cli_command_validators.py new file mode 100644 index 0000000..13514a5 --- /dev/null +++ b/tests/test_cli_command_validators.py @@ -0,0 +1,167 @@ +# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for cli.command_validators module.""" +from __future__ import annotations + +import argparse +from unittest.mock import MagicMock + +import pytest + +from mlia.cli.command_validators import validate_backend +from mlia.cli.command_validators import validate_check_target_profile + + +@pytest.mark.parametrize( + "target_profile, category, expected_warnings, sys_exits", + [ + ["ethos-u55-256", {"compatibility", "performance"}, [], False], + ["ethos-u55-256", {"compatibility"}, [], False], + ["ethos-u55-256", {"performance"}, [], False], + [ + "tosa", + {"compatibility", "performance"}, + [ + ( + "\nWARNING: Performance checks skipped as they cannot be " + "performed with target profile tosa." + ) + ], + False, + ], + [ + "tosa", + {"performance"}, + [ + ( + "\nWARNING: Performance checks skipped as they cannot be " + "performed with target profile tosa. No operation was performed." + ) + ], + True, + ], + ["tosa", "compatibility", [], False], + [ + "cortex-a", + {"performance"}, + [ + ( + "\nWARNING: Performance checks skipped as they cannot be " + "performed with target profile cortex-a. " + "No operation was performed." + ) + ], + True, + ], + [ + "cortex-a", + {"compatibility", "performance"}, + [ + ( + "\nWARNING: Performance checks skipped as they cannot be " + "performed with target profile cortex-a." + ) + ], + False, + ], + ["cortex-a", "compatibility", [], False], + ], +) +def test_validate_check_target_profile( + caplog: pytest.LogCaptureFixture, + target_profile: str, + category: set[str], + expected_warnings: list[str], + sys_exits: bool, +) -> None: + """Test outcomes of category dependent target profile validation.""" + # Capture if program terminates + if sys_exits: + with pytest.raises(SystemExit) as sys_ex: + validate_check_target_profile(target_profile, category) + assert sys_ex.value.code == 0 + return + + validate_check_target_profile(target_profile, category) + + log_records = caplog.records + # Get all log records with level 30 (warning level) + warning_messages = {x.message for x in log_records if x.levelno == 30} + # Ensure the warnings coincide with the expected ones + assert warning_messages == set(expected_warnings) + + +@pytest.mark.parametrize( + "input_target_profile, input_backends, throws_exception," + "exception_message, output_backends", + [ + [ + "tosa", + ["Vela"], + True, + "Vela backend not supported with target-profile tosa.", + None, + ], + [ + "tosa", + ["Corstone-300, Vela"], + True, + "Corstone-300, Vela backend not supported with target-profile tosa.", + None, + ], + [ + "cortex-a", + ["Corstone-310", "tosa-checker"], + True, + "Corstone-310, tosa-checker backend not supported " + "with target-profile cortex-a.", + None, + ], + [ + "ethos-u55-256", + ["tosa-checker", "Corstone-310"], + True, + "tosa-checker backend not supported with target-profile ethos-u55-256.", + None, + ], + ["tosa", None, False, None, ["tosa-checker"]], + ["cortex-a", None, False, None, ["armnn-tflitedelegate"]], + ["tosa", ["tosa-checker"], False, None, ["tosa-checker"]], + ["cortex-a", ["armnn-tflitedelegate"], False, None, ["armnn-tflitedelegate"]], + [ + "ethos-u55-256", + ["Vela", "Corstone-300"], + False, + None, + ["Vela", "Corstone-300"], + ], + [ + "ethos-u55-256", + None, + False, + None, + ["Vela", "Corstone-300"], + ], + ], +) +def test_validate_backend( + monkeypatch: pytest.MonkeyPatch, + input_target_profile: str, + input_backends: list[str] | None, + throws_exception: bool, + exception_message: str, + output_backends: list[str] | None, +) -> None: + """Test backend validation with target-profiles and backends.""" + monkeypatch.setattr( + "mlia.cli.config.get_available_backends", + MagicMock(return_value=["Vela", "Corstone-300"]), + ) + + if throws_exception: + with pytest.raises(argparse.ArgumentError) as err: + validate_backend(input_target_profile, input_backends) + assert str(err.value.message) == exception_message + return + + assert validate_backend(input_target_profile, input_backends) == output_backends diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py index aed5c42..03ee9d2 100644 --- a/tests/test_cli_commands.py +++ b/tests/test_cli_commands.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for cli.commands module.""" from __future__ import annotations @@ -14,9 +14,8 @@ from mlia.backend.manager import DefaultInstallationManager from mlia.cli.commands import backend_install from mlia.cli.commands import backend_list from mlia.cli.commands import backend_uninstall -from mlia.cli.commands import operators -from mlia.cli.commands import optimization -from mlia.cli.commands import performance +from mlia.cli.commands import check +from mlia.cli.commands import optimize from mlia.core.context import ExecutionContext from mlia.target.ethos_u.config import EthosUConfiguration from mlia.target.ethos_u.performance import MemoryUsage @@ -27,7 +26,7 @@ from mlia.target.ethos_u.performance import PerformanceMetrics def test_operators_expected_parameters(sample_context: ExecutionContext) -> None: """Test operators command wrong parameters.""" with pytest.raises(Exception, match="Model is not provided"): - operators(sample_context, "ethos-u55-256") + check(sample_context, "ethos-u55-256") def test_performance_unknown_target( @@ -35,93 +34,45 @@ def test_performance_unknown_target( ) -> None: """Test that command should fail if unknown target passed.""" with pytest.raises(Exception, match="Unable to find target profile unknown"): - performance( - sample_context, model=str(test_tflite_model), target_profile="unknown" + check( + sample_context, + model=str(test_tflite_model), + target_profile="unknown", + performance=True, ) @pytest.mark.parametrize( - "target_profile, optimization_type, optimization_target, expected_error", + "target_profile, pruning, clustering, pruning_target, clustering_target", [ - [ - "ethos-u55-256", - None, - "0.5", - pytest.raises(Exception, match="Optimization type is not provided"), - ], - [ - "ethos-u65-512", - "unknown", - "16", - pytest.raises(Exception, match="Unsupported optimization type: unknown"), - ], - [ - "ethos-u55-256", - "pruning", - None, - pytest.raises(Exception, match="Optimization target is not provided"), - ], - [ - "ethos-u65-512", - "clustering", - None, - pytest.raises(Exception, match="Optimization target is not provided"), - ], - [ - "unknown", - "clustering", - "16", - pytest.raises(Exception, match="Unable to find target profile unknown"), - ], - ], -) -def test_opt_expected_parameters( - sample_context: ExecutionContext, - target_profile: str, - monkeypatch: pytest.MonkeyPatch, - optimization_type: str, - optimization_target: str, - expected_error: Any, - test_keras_model: Path, -) -> None: - """Test that command should fail if no or unknown optimization type provided.""" - mock_performance_estimation(monkeypatch) - - with expected_error: - optimization( - ctx=sample_context, - target_profile=target_profile, - model=str(test_keras_model), - optimization_type=optimization_type, - optimization_target=optimization_target, - ) - - -@pytest.mark.parametrize( - "target_profile, optimization_type, optimization_target", - [ - ["ethos-u55-256", "pruning", "0.5"], - ["ethos-u65-512", "clustering", "32"], - ["ethos-u55-256", "pruning,clustering", "0.5,32"], + ["ethos-u55-256", True, False, 0.5, None], + ["ethos-u65-512", False, True, 0.5, 32], + ["ethos-u55-256", True, True, 0.5, None], + ["ethos-u55-256", False, False, 0.5, None], + ["ethos-u55-256", False, True, "invalid", 32], ], ) def test_opt_valid_optimization_target( target_profile: str, sample_context: ExecutionContext, - optimization_type: str, - optimization_target: str, + pruning: bool, + clustering: bool, + pruning_target: float | None, + clustering_target: int | None, monkeypatch: pytest.MonkeyPatch, test_keras_model: Path, ) -> None: """Test that command should not fail with valid optimization targets.""" mock_performance_estimation(monkeypatch) - optimization( + optimize( ctx=sample_context, target_profile=target_profile, model=str(test_keras_model), - optimization_type=optimization_type, - optimization_target=optimization_target, + pruning=pruning, + clustering=clustering, + pruning_target=pruning_target, + clustering_target=clustering_target, ) diff --git a/tests/test_cli_config.py b/tests/test_cli_config.py index 1a7cb3f..b007052 100644 --- a/tests/test_cli_config.py +++ b/tests/test_cli_config.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for cli.config module.""" from __future__ import annotations @@ -7,7 +7,7 @@ from unittest.mock import MagicMock import pytest -from mlia.cli.config import get_default_backends +from mlia.cli.config import get_ethos_u_default_backends from mlia.cli.config import is_corstone_backend @@ -29,7 +29,7 @@ from mlia.cli.config import is_corstone_backend ], ], ) -def test_get_default_backends( +def test_get_ethos_u_default_backends( monkeypatch: pytest.MonkeyPatch, available_backends: list[str], expected_default_backends: list[str], @@ -40,7 +40,7 @@ def test_get_default_backends( MagicMock(return_value=available_backends), ) - assert get_default_backends() == expected_default_backends + assert get_ethos_u_default_backends() == expected_default_backends def test_is_corstone_backend() -> None: diff --git a/tests/test_cli_helpers.py b/tests/test_cli_helpers.py index c8aeebe..8f7e4b0 100644 --- a/tests/test_cli_helpers.py +++ b/tests/test_cli_helpers.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the helper classes.""" from __future__ import annotations @@ -28,40 +28,39 @@ class TestCliActionResolver: {}, [ "Note: you will need a Keras model for that.", - "For example: mlia optimization --optimization-type " - "pruning,clustering --optimization-target 0.5,32 " - "/path/to/keras_model", - "For more info: mlia optimization --help", + "For example: mlia optimize /path/to/keras_model " + "--pruning --clustering " + "--pruning-target 0.5 --clustering-target 32", + "For more info: mlia optimize --help", ], ], [ {"model": "model.h5"}, {}, [ - "For example: mlia optimization --optimization-type " - "pruning,clustering --optimization-target 0.5,32 model.h5", - "For more info: mlia optimization --help", + "For example: mlia optimize model.h5 --pruning --clustering " + "--pruning-target 0.5 --clustering-target 32", + "For more info: mlia optimize --help", ], ], [ {"model": "model.h5"}, {"opt_settings": [OptimizationSettings("pruning", 0.5, None)]}, [ - "For more info: mlia optimization --help", + "For more info: mlia optimize --help", "Optimization command: " - "mlia optimization --optimization-type pruning " - "--optimization-target 0.5 model.h5", + "mlia optimize model.h5 --pruning " + "--pruning-target 0.5", ], ], [ {"model": "model.h5", "target_profile": "target_profile"}, {"opt_settings": [OptimizationSettings("pruning", 0.5, None)]}, [ - "For more info: mlia optimization --help", + "For more info: mlia optimize --help", "Optimization command: " - "mlia optimization --optimization-type pruning " - "--optimization-target 0.5 " - "--target-profile target_profile model.h5", + "mlia optimize model.h5 --target-profile target_profile " + "--pruning --pruning-target 0.5", ], ], ], @@ -75,21 +74,12 @@ class TestCliActionResolver: resolver = CLIActionResolver(args) assert resolver.apply_optimizations(**params) == expected_result - @staticmethod - def test_supported_operators_info() -> None: - """Test supported operators info.""" - resolver = CLIActionResolver({}) - assert resolver.supported_operators_info() == [ - "For guidance on supported operators, run: mlia operators " - "--supported-ops-report", - ] - @staticmethod def test_operator_compatibility_details() -> None: """Test operator compatibility details info.""" resolver = CLIActionResolver({}) assert resolver.operator_compatibility_details() == [ - "For more details, run: mlia operators --help" + "For more details, run: mlia check --help" ] @staticmethod @@ -97,7 +87,7 @@ class TestCliActionResolver: """Test optimization details info.""" resolver = CLIActionResolver({}) assert resolver.optimization_details() == [ - "For more info, see: mlia optimization --help" + "For more info, see: mlia optimize --help" ] @staticmethod @@ -108,20 +98,13 @@ class TestCliActionResolver: {}, [], ], - [ - {"model": "model.tflite"}, - [ - "Check the estimated performance by running the " - "following command: ", - "mlia performance model.tflite", - ], - ], [ {"model": "model.tflite", "target_profile": "target_profile"}, [ "Check the estimated performance by running the " "following command: ", - "mlia performance --target-profile target_profile model.tflite", + "mlia check model.tflite " + "--target-profile target_profile --performance", ], ], ], @@ -141,18 +124,11 @@ class TestCliActionResolver: {}, [], ], - [ - {"model": "model.tflite"}, - [ - "Try running the following command to verify that:", - "mlia operators model.tflite", - ], - ], [ {"model": "model.tflite", "target_profile": "target_profile"}, [ "Try running the following command to verify that:", - "mlia operators --target-profile target_profile model.tflite", + "mlia check model.tflite --target-profile target_profile", ], ], ], diff --git a/tests/test_cli_main.py b/tests/test_cli_main.py index 925f1e4..5a9c0c9 100644 --- a/tests/test_cli_main.py +++ b/tests/test_cli_main.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for main module.""" from __future__ import annotations @@ -19,7 +19,6 @@ from mlia.backend.errors import BackendUnavailableError from mlia.cli.main import backend_main from mlia.cli.main import CommandInfo from mlia.cli.main import main -from mlia.core.context import ExecutionContext from mlia.core.errors import InternalError from tests.utils.logging import clear_loggers @@ -62,35 +61,23 @@ def test_command_info(is_default: bool, expected_command_help: str) -> None: assert command_info.command_help == expected_command_help -def test_default_command(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: +def test_default_command(monkeypatch: pytest.MonkeyPatch) -> None: """Test adding default command.""" - def mock_command( - func_mock: MagicMock, name: str, with_working_dir: bool - ) -> Callable[..., None]: + def mock_command(func_mock: MagicMock, name: str) -> Callable[..., None]: """Mock cli command.""" def sample_cmd_1(*args: Any, **kwargs: Any) -> None: """Sample command.""" func_mock(*args, **kwargs) - def sample_cmd_2(ctx: ExecutionContext, **kwargs: Any) -> None: - """Another sample command.""" - func_mock(ctx=ctx, **kwargs) - - ret_func = sample_cmd_2 if with_working_dir else sample_cmd_1 + ret_func = sample_cmd_1 ret_func.__name__ = name - return ret_func # type: ignore + return ret_func - default_command = MagicMock() non_default_command = MagicMock() - def default_command_params(parser: argparse.ArgumentParser) -> None: - """Add parameters for default command.""" - parser.add_argument("--sample") - parser.add_argument("--default_arg", default="123") - def non_default_command_params(parser: argparse.ArgumentParser) -> None: """Add parameters for non default command.""" parser.add_argument("--param") @@ -100,15 +87,7 @@ def test_default_command(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> Non MagicMock( return_value=[ CommandInfo( - func=mock_command(default_command, "default_command", True), - aliases=["command1"], - opt_groups=[default_command_params], - is_default=True, - ), - CommandInfo( - func=mock_command( - non_default_command, "non_default_command", False - ), + func=mock_command(non_default_command, "non_default_command"), aliases=["command2"], opt_groups=[non_default_command_params], is_default=False, @@ -117,11 +96,7 @@ def test_default_command(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> Non ), ) - tmp_working_dir = str(tmp_path) - main(["--working-dir", tmp_working_dir, "--sample", "1"]) main(["command2", "--param", "test"]) - - default_command.assert_called_once_with(ctx=ANY, sample="1", default_arg="123") non_default_command.assert_called_once_with(param="test") @@ -140,134 +115,168 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable: "params, expected_call", [ [ - ["operators", "sample_model.tflite"], + ["check", "sample_model.tflite", "--target-profile", "ethos-u55-256"], call( ctx=ANY, target_profile="ethos-u55-256", model="sample_model.tflite", + compatibility=False, + performance=False, output=None, - supported_ops_report=False, + json=False, + backend=None, ), ], [ - ["ops", "sample_model.tflite"], - call( - ctx=ANY, - target_profile="ethos-u55-256", - model="sample_model.tflite", - output=None, - supported_ops_report=False, - ), - ], - [ - ["operators", "sample_model.tflite", "--target-profile", "ethos-u55-128"], + ["check", "sample_model.tflite", "--target-profile", "ethos-u55-128"], call( ctx=ANY, target_profile="ethos-u55-128", model="sample_model.tflite", + compatibility=False, + performance=False, output=None, - supported_ops_report=False, + json=False, + backend=None, ), ], [ - ["operators"], - call( - ctx=ANY, - target_profile="ethos-u55-256", - model=None, - output=None, - supported_ops_report=False, - ), - ], - [ - ["operators", "--supported-ops-report"], + [ + "check", + "sample_model.h5", + "--performance", + "--compatibility", + "--target-profile", + "ethos-u55-256", + ], call( ctx=ANY, target_profile="ethos-u55-256", - model=None, + model="sample_model.h5", output=None, - supported_ops_report=True, + json=False, + compatibility=True, + performance=True, + backend=None, ), ], [ [ - "all_tests", + "check", "sample_model.h5", - "--optimization-type", - "pruning", - "--optimization-target", - "0.5", + "--performance", + "--target-profile", + "ethos-u55-256", + "--output", + "result.json", + "--json", ], call( ctx=ANY, target_profile="ethos-u55-256", model="sample_model.h5", - optimization_type="pruning", - optimization_target="0.5", - output=None, - evaluate_on=["Vela"], + performance=True, + compatibility=False, + output=Path("result.json"), + json=True, + backend=None, ), ], [ - ["sample_model.h5"], + [ + "check", + "sample_model.h5", + "--performance", + "--target-profile", + "ethos-u55-128", + ], call( ctx=ANY, - target_profile="ethos-u55-256", + target_profile="ethos-u55-128", model="sample_model.h5", - optimization_type="pruning,clustering", - optimization_target="0.5,32", + compatibility=False, + performance=True, output=None, - evaluate_on=["Vela"], + json=False, + backend=None, ), ], [ - ["performance", "sample_model.h5", "--output", "result.json"], + [ + "optimize", + "sample_model.h5", + "--target-profile", + "ethos-u55-256", + "--pruning", + "--clustering", + ], call( ctx=ANY, target_profile="ethos-u55-256", model="sample_model.h5", - output="result.json", - evaluate_on=["Vela"], - ), - ], - [ - ["perf", "sample_model.h5", "--target-profile", "ethos-u55-128"], - call( - ctx=ANY, - target_profile="ethos-u55-128", - model="sample_model.h5", + pruning=True, + clustering=True, + pruning_target=None, + clustering_target=None, output=None, - evaluate_on=["Vela"], + json=False, + backend=None, ), ], [ - ["optimization", "sample_model.h5"], + [ + "optimize", + "sample_model.h5", + "--target-profile", + "ethos-u55-256", + "--pruning", + "--clustering", + "--pruning-target", + "0.5", + "--clustering-target", + "32", + ], call( ctx=ANY, target_profile="ethos-u55-256", model="sample_model.h5", - optimization_type="pruning,clustering", - optimization_target="0.5,32", + pruning=True, + clustering=True, + pruning_target=0.5, + clustering_target=32, output=None, - evaluate_on=["Vela"], + json=False, + backend=None, ), ], [ - ["optimization", "sample_model.h5", "--evaluate-on", "some_backend"], + [ + "optimize", + "sample_model.h5", + "--target-profile", + "ethos-u55-256", + "--pruning", + "--backend", + "some_backend", + ], call( ctx=ANY, target_profile="ethos-u55-256", model="sample_model.h5", - optimization_type="pruning,clustering", - optimization_target="0.5,32", + pruning=True, + clustering=False, + pruning_target=None, + clustering_target=None, output=None, - evaluate_on=["some_backend"], + json=False, + backend=["some_backend"], ), ], [ [ - "operators", + "check", "sample_model.h5", + "--compatibility", "--target-profile", "cortex-a", ], @@ -275,8 +284,11 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable: ctx=ANY, target_profile="cortex-a", model="sample_model.h5", + compatibility=True, + performance=False, output=None, - supported_ops_report=False, + json=False, + backend=None, ), ], ], @@ -287,16 +299,12 @@ def test_commands_execution( """Test calling commands from the main function.""" mock = MagicMock() - monkeypatch.setattr( - "mlia.cli.options.get_default_backends", MagicMock(return_value=["Vela"]) - ) - monkeypatch.setattr( "mlia.cli.options.get_available_backends", MagicMock(return_value=["Vela", "some_backend"]), ) - for command in ["all_tests", "operators", "performance", "optimization"]: + for command in ["check", "optimize"]: monkeypatch.setattr( f"mlia.cli.main.{command}", wrap_mock_command(mock, getattr(mlia.cli.main, command)), @@ -335,15 +343,15 @@ def test_commands_execution_backend_main( @pytest.mark.parametrize( - "verbose, exc_mock, expected_output", + "debug, exc_mock, expected_output", [ [ True, MagicMock(side_effect=Exception("Error")), [ "Execution finished with error: Error", - f"Please check the log files in the {Path.cwd()/'mlia_output/logs'} " - "for more details", + "Please check the log files in the /tmp/mlia-", + "/logs for more details", ], ], [ @@ -351,8 +359,8 @@ def test_commands_execution_backend_main( MagicMock(side_effect=Exception("Error")), [ "Execution finished with error: Error", - f"Please check the log files in the {Path.cwd()/'mlia_output/logs'} " - "for more details, or enable verbose mode (--verbose)", + "Please check the log files in the /tmp/mlia-", + "/logs for more details, or enable debug mode (--debug)", ], ], [ @@ -389,18 +397,18 @@ def test_commands_execution_backend_main( ], ], ) -def test_verbose_output( +def test_debug_output( monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture, - verbose: bool, + debug: bool, exc_mock: MagicMock, expected_output: list[str], ) -> None: - """Test flag --verbose.""" + """Test flag --debug.""" def command_params(parser: argparse.ArgumentParser) -> None: """Add parameters for non default command.""" - parser.add_argument("--verbose", action="store_true") + parser.add_argument("--debug", action="store_true") def command() -> None: """Run test command.""" @@ -420,8 +428,8 @@ def test_verbose_output( ) params = ["command"] - if verbose: - params.append("--verbose") + if debug: + params.append("--debug") exit_code = main(params) assert exit_code == 1 diff --git a/tests/test_cli_options.py b/tests/test_cli_options.py index d75f7c0..a889a93 100644 --- a/tests/test_cli_options.py +++ b/tests/test_cli_options.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for module options.""" from __future__ import annotations @@ -13,14 +13,19 @@ import pytest from mlia.cli.options import add_output_options from mlia.cli.options import get_target_profile_opts from mlia.cli.options import parse_optimization_parameters +from mlia.cli.options import parse_output_parameters +from mlia.core.common import FormattedFilePath @pytest.mark.parametrize( - "optimization_type, optimization_target, expected_error, expected_result", + "pruning, clustering, pruning_target, clustering_target, expected_error," + "expected_result", [ - ( - "pruning", - "0.5", + [ + False, + False, + None, + None, does_not_raise(), [ dict( @@ -29,39 +34,40 @@ from mlia.cli.options import parse_optimization_parameters layers_to_optimize=None, ) ], - ), - ( - "clustering", - "32", + ], + [ + True, + False, + None, + None, does_not_raise(), [ dict( - optimization_type="clustering", - optimization_target=32.0, + optimization_type="pruning", + optimization_target=0.5, layers_to_optimize=None, ) ], - ), - ( - "pruning,clustering", - "0.5,32", + ], + [ + False, + True, + None, + None, does_not_raise(), [ - dict( - optimization_type="pruning", - optimization_target=0.5, - layers_to_optimize=None, - ), dict( optimization_type="clustering", - optimization_target=32.0, + optimization_target=32, layers_to_optimize=None, - ), + ) ], - ), - ( - "pruning, clustering", - "0.5, 32", + ], + [ + True, + True, + None, + None, does_not_raise(), [ dict( @@ -71,50 +77,66 @@ from mlia.cli.options import parse_optimization_parameters ), dict( optimization_type="clustering", - optimization_target=32.0, + optimization_target=32, layers_to_optimize=None, ), ], - ), - ( - "pruning,clustering", - "0.5", - pytest.raises( - Exception, match="Wrong number of optimization targets and types" - ), - None, - ), - ( - "", - "0.5", - pytest.raises(Exception, match="Optimization type is not provided"), + ], + [ + False, + False, + 0.4, None, - ), - ( - "pruning,clustering", - "", - pytest.raises(Exception, match="Optimization target is not provided"), + does_not_raise(), + [ + dict( + optimization_type="pruning", + optimization_target=0.4, + layers_to_optimize=None, + ) + ], + ], + [ + False, + False, None, - ), - ( - "pruning,", - "0.5,abc", + 32, pytest.raises( - Exception, match="Non numeric value for the optimization target" + argparse.ArgumentError, + match="To enable clustering optimization you need to include " + "the `--clustering` flag in your command.", ), None, - ), + ], + [ + False, + True, + None, + 32.2, + does_not_raise(), + [ + dict( + optimization_type="clustering", + optimization_target=32.2, + layers_to_optimize=None, + ) + ], + ], ], ) def test_parse_optimization_parameters( - optimization_type: str, - optimization_target: str, + pruning: bool, + clustering: bool, + pruning_target: float | None, + clustering_target: int | None, expected_error: Any, expected_result: Any, ) -> None: """Test function parse_optimization_parameters.""" with expected_error: - result = parse_optimization_parameters(optimization_type, optimization_target) + result = parse_optimization_parameters( + pruning, clustering, pruning_target, clustering_target + ) assert result == expected_result @@ -155,28 +177,41 @@ def test_output_options(output_parameters: list[str], expected_path: str) -> Non add_output_options(parser) args = parser.parse_args(output_parameters) - assert args.output == expected_path + assert str(args.output) == expected_path @pytest.mark.parametrize( - "output_filename", + "path, json, expected_error, output", [ - "report.txt", - "report.TXT", - "report", - "report.pdf", + [ + None, + True, + pytest.raises( + argparse.ArgumentError, + match=r"To enable JSON output you need to specify the output path. " + r"\(e.g. --output out.json --json\)", + ), + None, + ], + [None, False, does_not_raise(), None], + [ + Path("test_path"), + False, + does_not_raise(), + FormattedFilePath(Path("test_path"), "plain_text"), + ], + [ + Path("test_path"), + True, + does_not_raise(), + FormattedFilePath(Path("test_path"), "json"), + ], ], ) -def test_output_options_bad_parameters( - output_filename: str, capsys: pytest.CaptureFixture +def test_parse_output_parameters( + path: Path | None, json: bool, expected_error: Any, output: FormattedFilePath | None ) -> None: - """Test that args parsing should fail if format is not supported.""" - parser = argparse.ArgumentParser() - add_output_options(parser) - - with pytest.raises(SystemExit): - parser.parse_args(["--output", output_filename]) - - err_output = capsys.readouterr().err - suffix = Path(output_filename).suffix[1:] - assert f"Unsupported format '{suffix}'" in err_output + """Test parsing for output parameters.""" + with expected_error: + formatted_output = parse_output_parameters(path, json) + assert formatted_output == output diff --git a/tests/test_core_advice_generation.py b/tests/test_core_advice_generation.py index 3d985eb..2e0038f 100644 --- a/tests/test_core_advice_generation.py +++ b/tests/test_core_advice_generation.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for module advice_generation.""" from __future__ import annotations @@ -35,17 +35,17 @@ def test_advice_generation() -> None: "category, expected_advice", [ [ - AdviceCategory.OPERATORS, + {AdviceCategory.COMPATIBILITY}, [Advice(["Good advice!"])], ], [ - AdviceCategory.PERFORMANCE, + {AdviceCategory.PERFORMANCE}, [], ], ], ) def test_advice_category_decorator( - category: AdviceCategory, + category: set[AdviceCategory], expected_advice: list[Advice], sample_context: Context, ) -> None: @@ -54,7 +54,7 @@ def test_advice_category_decorator( class SampleAdviceProducer(FactBasedAdviceProducer): """Sample advice producer.""" - @advice_category(AdviceCategory.OPERATORS) + @advice_category(AdviceCategory.COMPATIBILITY) def produce_advice(self, data_item: DataItem) -> None: """Produce the advice.""" self.add_advice(["Good advice!"]) diff --git a/tests/test_core_context.py b/tests/test_core_context.py index 44eb976..dcdbef3 100644 --- a/tests/test_core_context.py +++ b/tests/test_core_context.py @@ -1,17 +1,53 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the module context.""" +from __future__ import annotations + from pathlib import Path +import pytest + from mlia.core.common import AdviceCategory from mlia.core.context import ExecutionContext from mlia.core.events import DefaultEventPublisher +@pytest.mark.parametrize( + "context_advice_category, expected_enabled_categories", + [ + [ + { + AdviceCategory.COMPATIBILITY, + }, + [AdviceCategory.COMPATIBILITY], + ], + [ + { + AdviceCategory.PERFORMANCE, + }, + [AdviceCategory.PERFORMANCE], + ], + [ + {AdviceCategory.COMPATIBILITY, AdviceCategory.PERFORMANCE}, + [AdviceCategory.PERFORMANCE, AdviceCategory.COMPATIBILITY], + ], + ], +) +def test_execution_context_category_enabled( + context_advice_category: set[AdviceCategory], + expected_enabled_categories: list[AdviceCategory], +) -> None: + """Test category enabled method of execution context.""" + for category in expected_enabled_categories: + assert ExecutionContext( + advice_category=context_advice_category + ).category_enabled(category) + + def test_execution_context(tmpdir: str) -> None: """Test execution context.""" publisher = DefaultEventPublisher() - category = AdviceCategory.OPERATORS + category = {AdviceCategory.COMPATIBILITY} context = ExecutionContext( advice_category=category, @@ -35,13 +71,13 @@ def test_execution_context(tmpdir: str) -> None: assert str(context) == ( f"ExecutionContext: " f"working_dir={tmpdir}, " - "advice_category=OPERATORS, " + "advice_category={'COMPATIBILITY'}, " "config_parameters={'param': 'value'}, " "verbose=True" ) context_with_default_params = ExecutionContext(working_dir=tmpdir) - assert context_with_default_params.advice_category is AdviceCategory.ALL + assert context_with_default_params.advice_category == {AdviceCategory.COMPATIBILITY} assert context_with_default_params.config_parameters is None assert context_with_default_params.event_handlers is None assert isinstance( @@ -55,7 +91,7 @@ def test_execution_context(tmpdir: str) -> None: expected_str = ( f"ExecutionContext: working_dir={tmpdir}, " - "advice_category=ALL, " + "advice_category={'COMPATIBILITY'}, " "config_parameters=None, " "verbose=False" ) diff --git a/tests/test_core_helpers.py b/tests/test_core_helpers.py index 8577617..03ec3f0 100644 --- a/tests/test_core_helpers.py +++ b/tests/test_core_helpers.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the helper classes.""" from mlia.core.helpers import APIActionResolver @@ -10,7 +10,6 @@ def test_api_action_resolver() -> None: # pylint: disable=use-implicit-booleaness-not-comparison assert helper.apply_optimizations() == [] - assert helper.supported_operators_info() == [] assert helper.check_performance() == [] assert helper.check_operator_compatibility() == [] assert helper.operator_compatibility_details() == [] diff --git a/tests/test_core_mixins.py b/tests/test_core_mixins.py index 3834fb3..47ed815 100644 --- a/tests/test_core_mixins.py +++ b/tests/test_core_mixins.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the module mixins.""" import pytest @@ -36,7 +36,7 @@ class TestParameterResolverMixin: self.context = sample_context self.context.update( - advice_category=AdviceCategory.OPERATORS, + advice_category={AdviceCategory.COMPATIBILITY}, event_handlers=[], config_parameters={"section": {"param": 123}}, ) @@ -83,7 +83,7 @@ class TestParameterResolverMixin: """Init sample object.""" self.context = sample_context self.context.update( - advice_category=AdviceCategory.OPERATORS, + advice_category={AdviceCategory.COMPATIBILITY}, event_handlers=[], config_parameters={"section": ["param"]}, ) diff --git a/tests/test_core_reporting.py b/tests/test_core_reporting.py index feff5cc..7b26173 100644 --- a/tests/test_core_reporting.py +++ b/tests/test_core_reporting.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for reporting module.""" from __future__ import annotations @@ -13,11 +13,8 @@ from mlia.core.reporting import CyclesCell from mlia.core.reporting import Format from mlia.core.reporting import NestedReport from mlia.core.reporting import ReportItem -from mlia.core.reporting import resolve_output_format from mlia.core.reporting import SingleRow from mlia.core.reporting import Table -from mlia.core.typing import OutputFormat -from mlia.core.typing import PathOrFileLike from mlia.utils.console import remove_ascii_codes @@ -338,20 +335,3 @@ Single row example: alias="simple_row_example", ) wrong_single_row.to_plain_text() - - -@pytest.mark.parametrize( - "output, expected_output_format", - [ - [None, "plain_text"], - ["", "plain_text"], - ["some_file", "plain_text"], - ["some_format.some_ext", "plain_text"], - ["output.json", "json"], - ], -) -def test_resolve_output_format( - output: PathOrFileLike | None, expected_output_format: OutputFormat -) -> None: - """Test function resolve_output_format.""" - assert resolve_output_format(output) == expected_output_format diff --git a/tests/test_target_config.py b/tests/test_target_config.py index 66ebed6..48f0a58 100644 --- a/tests/test_target_config.py +++ b/tests/test_target_config.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the backend config module.""" from __future__ import annotations @@ -25,7 +25,7 @@ def test_ip_config() -> None: ( (None, False, True), (None, True, True), - (AdviceCategory.OPERATORS, True, True), + (AdviceCategory.COMPATIBILITY, True, True), (AdviceCategory.OPTIMIZATION, True, False), ), ) @@ -42,7 +42,7 @@ def test_target_info( backend_registry.register( "backend", BackendConfiguration( - [AdviceCategory.OPERATORS], + [AdviceCategory.COMPATIBILITY], [System.CURRENT], BackendType.BUILTIN, ), diff --git a/tests/test_target_cortex_a_advice_generation.py b/tests/test_target_cortex_a_advice_generation.py index 6effe4c..1997c52 100644 --- a/tests/test_target_cortex_a_advice_generation.py +++ b/tests/test_target_cortex_a_advice_generation.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for advice generation.""" from __future__ import annotations @@ -31,7 +31,7 @@ BACKEND_INFO = ( [ [ ModelIsNotCortexACompatible(BACKEND_INFO, {"UNSUPPORTED_OP"}, {}), - AdviceCategory.OPERATORS, + {AdviceCategory.COMPATIBILITY}, [ Advice( [ @@ -61,7 +61,7 @@ BACKEND_INFO = ( ) }, ), - AdviceCategory.OPERATORS, + {AdviceCategory.COMPATIBILITY}, [ Advice( [ @@ -93,7 +93,7 @@ BACKEND_INFO = ( ], [ ModelIsCortexACompatible(BACKEND_INFO), - AdviceCategory.OPERATORS, + {AdviceCategory.COMPATIBILITY}, [ Advice( [ @@ -108,7 +108,7 @@ BACKEND_INFO = ( flex_ops=["flex_op1", "flex_op2"], custom_ops=["custom_op1", "custom_op2"], ), - AdviceCategory.OPERATORS, + {AdviceCategory.COMPATIBILITY}, [ Advice( [ @@ -142,7 +142,7 @@ BACKEND_INFO = ( ], [ ModelIsNotTFLiteCompatible(), - AdviceCategory.OPERATORS, + {AdviceCategory.COMPATIBILITY}, [ Advice( [ @@ -154,7 +154,7 @@ BACKEND_INFO = ( ], [ ModelHasCustomOperators(), - AdviceCategory.OPERATORS, + {AdviceCategory.COMPATIBILITY}, [ Advice( [ @@ -166,7 +166,7 @@ BACKEND_INFO = ( ], [ TFLiteCompatibilityCheckFailed(), - AdviceCategory.OPERATORS, + {AdviceCategory.COMPATIBILITY}, [ Advice( [ @@ -181,7 +181,7 @@ BACKEND_INFO = ( def test_cortex_a_advice_producer( tmpdir: str, input_data: DataItem, - advice_category: AdviceCategory, + advice_category: set[AdviceCategory], expected_advice: list[Advice], ) -> None: """Test Cortex-A advice producer.""" diff --git a/tests/test_target_ethos_u_advice_generation.py b/tests/test_target_ethos_u_advice_generation.py index 1569592..e93eeba 100644 --- a/tests/test_target_ethos_u_advice_generation.py +++ b/tests/test_target_ethos_u_advice_generation.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for Ethos-U advice generation.""" from __future__ import annotations @@ -28,7 +28,7 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff [ [ AllOperatorsSupportedOnNPU(), - AdviceCategory.OPERATORS, + {AdviceCategory.COMPATIBILITY}, APIActionResolver(), [ Advice( @@ -41,7 +41,7 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff ], [ AllOperatorsSupportedOnNPU(), - AdviceCategory.OPERATORS, + {AdviceCategory.COMPATIBILITY}, CLIActionResolver( { "target_profile": "sample_target", @@ -55,15 +55,15 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff "run completely on NPU.", "Check the estimated performance by running the " "following command: ", - "mlia performance --target-profile sample_target " - "sample_model.tflite", + "mlia check sample_model.tflite --target-profile sample_target " + "--performance", ] ) ], ], [ HasCPUOnlyOperators(cpu_only_ops=["OP1", "OP2", "OP3"]), - AdviceCategory.OPERATORS, + {AdviceCategory.COMPATIBILITY}, APIActionResolver(), [ Advice( @@ -78,7 +78,7 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff ], [ HasCPUOnlyOperators(cpu_only_ops=["OP1", "OP2", "OP3"]), - AdviceCategory.OPERATORS, + {AdviceCategory.COMPATIBILITY}, CLIActionResolver({}), [ Advice( @@ -87,15 +87,13 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff "OP1,OP2,OP3.", "Using operators that are supported by the NPU will " "improve performance.", - "For guidance on supported operators, run: mlia operators " - "--supported-ops-report", ] ) ], ], [ HasUnsupportedOnNPUOperators(npu_unsupported_ratio=0.4), - AdviceCategory.OPERATORS, + {AdviceCategory.COMPATIBILITY}, APIActionResolver(), [ Advice( @@ -110,7 +108,7 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff ], [ HasUnsupportedOnNPUOperators(npu_unsupported_ratio=0.4), - AdviceCategory.OPERATORS, + {AdviceCategory.COMPATIBILITY}, CLIActionResolver({}), [ Advice( @@ -138,7 +136,7 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff ), ] ), - AdviceCategory.OPTIMIZATION, + {AdviceCategory.OPTIMIZATION}, APIActionResolver(), [ Advice( @@ -178,7 +176,7 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff ), ] ), - AdviceCategory.OPTIMIZATION, + {AdviceCategory.OPTIMIZATION}, CLIActionResolver({"model": "sample_model.h5"}), [ Advice( @@ -192,10 +190,10 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff "You can try to push the optimization target higher " "(e.g. pruning: 0.6) " "to check if those results can be further improved.", - "For more info: mlia optimization --help", + "For more info: mlia optimize --help", "Optimization command: " - "mlia optimization --optimization-type pruning " - "--optimization-target 0.6 sample_model.h5", + "mlia optimize sample_model.h5 --pruning " + "--pruning-target 0.6", ] ), Advice( @@ -225,7 +223,7 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff ), ] ), - AdviceCategory.OPTIMIZATION, + {AdviceCategory.OPTIMIZATION}, APIActionResolver(), [ Advice( @@ -267,7 +265,7 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff ), ] ), - AdviceCategory.OPTIMIZATION, + {AdviceCategory.OPTIMIZATION}, APIActionResolver(), [ Advice( @@ -304,7 +302,7 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff ), ] ), - AdviceCategory.OPTIMIZATION, + {AdviceCategory.OPTIMIZATION}, APIActionResolver(), [ Advice( @@ -354,7 +352,7 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff ), ] ), - AdviceCategory.OPTIMIZATION, + {AdviceCategory.OPTIMIZATION}, APIActionResolver(), [], # no advice for more than one optimization result ], @@ -364,7 +362,7 @@ def test_ethosu_advice_producer( tmpdir: str, input_data: DataItem, expected_advice: list[Advice], - advice_category: AdviceCategory, + advice_category: set[AdviceCategory] | None, action_resolver: ActionResolver, ) -> None: """Test Ethos-U Advice producer.""" @@ -386,17 +384,17 @@ def test_ethosu_advice_producer( "advice_category, action_resolver, expected_advice", [ [ - AdviceCategory.ALL, + {AdviceCategory.COMPATIBILITY, AdviceCategory.PERFORMANCE}, None, [], ], [ - AdviceCategory.OPERATORS, + {AdviceCategory.COMPATIBILITY}, None, [], ], [ - AdviceCategory.PERFORMANCE, + {AdviceCategory.PERFORMANCE}, APIActionResolver(), [ Advice( @@ -414,31 +412,33 @@ def test_ethosu_advice_producer( ], ], [ - AdviceCategory.PERFORMANCE, - CLIActionResolver({"model": "test_model.h5"}), + {AdviceCategory.PERFORMANCE}, + CLIActionResolver( + {"model": "test_model.h5", "target_profile": "sample_target"} + ), [ Advice( [ "You can improve the inference time by using only operators " "that are supported by the NPU.", "Try running the following command to verify that:", - "mlia operators test_model.h5", + "mlia check test_model.h5 --target-profile sample_target", ] ), Advice( [ "Check if you can improve the performance by applying " "tooling techniques to your model.", - "For example: mlia optimization --optimization-type " - "pruning,clustering --optimization-target 0.5,32 " - "test_model.h5", - "For more info: mlia optimization --help", + "For example: mlia optimize test_model.h5 " + "--pruning --clustering " + "--pruning-target 0.5 --clustering-target 32", + "For more info: mlia optimize --help", ] ), ], ], [ - AdviceCategory.OPTIMIZATION, + {AdviceCategory.OPTIMIZATION}, APIActionResolver(), [ Advice( @@ -450,14 +450,14 @@ def test_ethosu_advice_producer( ], ], [ - AdviceCategory.OPTIMIZATION, + {AdviceCategory.OPTIMIZATION}, CLIActionResolver({"model": "test_model.h5"}), [ Advice( [ "For better performance, make sure that all the operators " "of your final TensorFlow Lite model are supported by the NPU.", - "For more details, run: mlia operators --help", + "For more details, run: mlia check --help", ] ) ], @@ -466,7 +466,7 @@ def test_ethosu_advice_producer( ) def test_ethosu_static_advice_producer( tmpdir: str, - advice_category: AdviceCategory, + advice_category: set[AdviceCategory] | None, action_resolver: ActionResolver, expected_advice: list[Advice], ) -> None: diff --git a/tests/test_target_registry.py b/tests/test_target_registry.py index e6ee296..e6028a9 100644 --- a/tests/test_target_registry.py +++ b/tests/test_target_registry.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the target registry module.""" from __future__ import annotations @@ -26,11 +26,11 @@ def test_target_registry(expected_target: str) -> None: @pytest.mark.parametrize( ("target_name", "expected_advices"), ( - ("Cortex-A", [AdviceCategory.OPERATORS]), + ("Cortex-A", [AdviceCategory.COMPATIBILITY]), ( "Ethos-U55", [ - AdviceCategory.OPERATORS, + AdviceCategory.COMPATIBILITY, AdviceCategory.OPTIMIZATION, AdviceCategory.PERFORMANCE, ], @@ -38,12 +38,12 @@ def test_target_registry(expected_target: str) -> None: ( "Ethos-U65", [ - AdviceCategory.OPERATORS, + AdviceCategory.COMPATIBILITY, AdviceCategory.OPTIMIZATION, AdviceCategory.PERFORMANCE, ], ), - ("TOSA", [AdviceCategory.OPERATORS]), + ("TOSA", [AdviceCategory.COMPATIBILITY]), ), ) def test_supported_advice( @@ -72,7 +72,7 @@ def test_supported_backends(target_name: str, expected_backends: list[str]) -> N @pytest.mark.parametrize( ("advice", "expected_targets"), ( - (AdviceCategory.OPERATORS, ["Cortex-A", "Ethos-U55", "Ethos-U65", "TOSA"]), + (AdviceCategory.COMPATIBILITY, ["Cortex-A", "Ethos-U55", "Ethos-U65", "TOSA"]), (AdviceCategory.OPTIMIZATION, ["Ethos-U55", "Ethos-U65"]), (AdviceCategory.PERFORMANCE, ["Ethos-U55", "Ethos-U65"]), ), diff --git a/tests/test_target_tosa_advice_generation.py b/tests/test_target_tosa_advice_generation.py index e8e06f8..d5ebbd7 100644 --- a/tests/test_target_tosa_advice_generation.py +++ b/tests/test_target_tosa_advice_generation.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for advice generation.""" from __future__ import annotations @@ -19,7 +19,7 @@ from mlia.target.tosa.data_analysis import ModelIsTOSACompatible [ [ ModelIsNotTOSACompatible(), - AdviceCategory.OPERATORS, + {AdviceCategory.COMPATIBILITY}, [ Advice( [ @@ -31,7 +31,7 @@ from mlia.target.tosa.data_analysis import ModelIsTOSACompatible ], [ ModelIsTOSACompatible(), - AdviceCategory.OPERATORS, + {AdviceCategory.COMPATIBILITY}, [Advice(["Model is fully TOSA compatible."])], ], ], @@ -39,7 +39,7 @@ from mlia.target.tosa.data_analysis import ModelIsTOSACompatible def test_tosa_advice_producer( tmpdir: str, input_data: DataItem, - advice_category: AdviceCategory, + advice_category: set[AdviceCategory], expected_advice: list[Advice], ) -> None: """Test TOSA advice producer.""" diff --git a/tests_e2e/test_e2e.py b/tests_e2e/test_e2e.py index fb40735..26f5d29 100644 --- a/tests_e2e/test_e2e.py +++ b/tests_e2e/test_e2e.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """End to end tests for MLIA CLI.""" from __future__ import annotations @@ -20,7 +20,6 @@ from typing import Iterable import pytest from mlia.cli.config import get_available_backends -from mlia.cli.config import get_default_backends from mlia.cli.main import get_commands from mlia.cli.main import get_possible_command_names from mlia.cli.main import init_commands @@ -230,19 +229,19 @@ def check_args(args: list[str], no_skip: bool) -> None: """Check the arguments and skip/fail test cases based on that.""" parser = argparse.ArgumentParser() parser.add_argument( - "--evaluate-on", - help="Backends to use for evaluation (default: %(default)s)", - nargs="*", - default=get_default_backends(), + "--backend", + help="Backends to use for evaluation.", + nargs="+", ) parsed_args, _ = parser.parse_known_args(args) - required_backends = set(parsed_args.evaluate_on) - available_backends = set(get_available_backends()) - missing_backends = required_backends.difference(available_backends) + if parsed_args.backend: + required_backends = set(parsed_args.backend) + available_backends = set(get_available_backends()) + missing_backends = required_backends.difference(available_backends) - if missing_backends and not no_skip: - pytest.skip(f"Missing backend(s): {','.join(missing_backends)}") + if missing_backends and not no_skip: + pytest.skip(f"Missing backend(s): {','.join(missing_backends)}") def get_execution_definitions() -> Generator[list[str], None, None]: -- cgit v1.2.1