author     Raul Farkas <raul.farkas@arm.com>  2022-11-29 13:29:04 +0000
committer  Raul Farkas <raul.farkas@arm.com>  2023-01-10 10:46:07 +0000
commit     5800fc990ed1e36ce7d06670f911fbb12a0ec771 (patch)
tree       294605295cd2624ba63e6ad3df335a2a4b2700ab
parent     dcd0bd31985c27e1d07333351b26cf8ad12ad1fd (diff)
download   mlia-5800fc990ed1e36ce7d06670f911fbb12a0ec771.tar.gz
MLIA-650 Implement new CLI changes
Breaking change in the CLI and API: the sub-commands "optimization",
"operators", and "performance" were replaced by "check", which incorporates
compatibility and performance checks, and "optimize", which is used for
optimization. The "get_advice" API was adapted to these CLI changes.

API changes:

* Remove the previous advice category "all" that would perform all three
  operations (when possible). Replace it with the ability to pass a set of
  advice categories.
* Update the api.get_advice method docstring to reflect the new changes.
* Set the default advice category to COMPATIBILITY.
* Update core.common.AdviceCategory by renaming the "OPERATORS" advice
  category to "COMPATIBILITY" and removing the "ALL" enum type. Update all
  methods that previously used "OPERATORS" to use "COMPATIBILITY".
* Update core.context.ExecutionContext to use "COMPATIBILITY" as the default
  advice_category instead of "ALL".
* Remove api.generate_supported_operators_report and all related functions
  from cli.commands, cli.helpers, cli.main, cli.options and core.helpers.
* Update tests to reflect the new API changes.

CLI changes:

* Update README.md to describe the new CLI.
* Remove the ability to generate a supported operators report from the MLIA CLI.
* Replace `mlia ops` and `mlia perf` with the new `mlia check` command, which
  can perform both operations.
* Replace `mlia opt` with the new `mlia optimize` command.
* Replace the `--evaluate-on` flag with the `--backend` flag.
* Replace the `--verbose` flag with the `--debug` flag (no behaviour change).
* Remove the ability for the user to select the MLIA working directory.
  Create and use a temporary directory in /tmp instead.
* Change the behaviour of the `--output` flag so that it no longer formats the
  content automatically based on the file extension; it now simply redirects
  output to the given file.
* Add the `--json` flag to specify that the output format should be JSON.
* Add command validators that are used to validate inter-dependent flags
  (e.g. backend validation based on target_profile).
* Add support for selecting built-in backends for both the `check` and
  `optimize` commands.
* Add new unit tests and update old ones to cover the new CLI changes.
* Update RELEASES.md.
* Update the copyright notice.

Change-Id: Ia6340797c7bee3acbbd26601950e5a16ad5602db
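For orientation, the sketch below pairs the removed invocations with their closest equivalents under the new CLI. It is illustrative only, assembled from the commit message and the README examples in this change; the model paths are placeholders.

```bash
# Before this change (sub-commands removed by this commit)
mlia ops  --target-profile ethos-u55-256 model.tflite
mlia perf --target-profile ethos-u65-512 --evaluate-on "Vela" model.tflite
mlia opt  --optimization-type pruning,clustering --optimization-target 0.6,16 model.h5

# After this change
mlia check model.tflite --target-profile ethos-u55-256      # compatibility check (default)
mlia check model.tflite --target-profile ethos-u65-512 \
    --performance --backend "Vela" \
    --output report.json --json                              # --json requires --output
mlia optimize model.h5 --target-profile ethos-u55-256 \
    --pruning --pruning-target 0.6 \
    --clustering --clustering-target 16
```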
-rw-r--r--  README.md  74
-rw-r--r--  RELEASES.md  8
-rw-r--r--  src/mlia/api.py  57
-rw-r--r--  src/mlia/backend/armnn_tflite_delegate/__init__.py  4
-rw-r--r--  src/mlia/backend/tosa_checker/__init__.py  4
-rw-r--r--  src/mlia/backend/vela/__init__.py  4
-rw-r--r--  src/mlia/cli/command_validators.py  113
-rw-r--r--  src/mlia/cli/commands.py  208
-rw-r--r--  src/mlia/cli/config.py  38
-rw-r--r--  src/mlia/cli/helpers.py  36
-rw-r--r--  src/mlia/cli/main.py  81
-rw-r--r--  src/mlia/cli/options.py  199
-rw-r--r--  src/mlia/core/common.py  53
-rw-r--r--  src/mlia/core/context.py  42
-rw-r--r--  src/mlia/core/handlers.py  10
-rw-r--r--  src/mlia/core/helpers.py  6
-rw-r--r--  src/mlia/core/reporting.py  13
-rw-r--r--  src/mlia/target/cortex_a/advice_generation.py  12
-rw-r--r--  src/mlia/target/cortex_a/advisor.py  8
-rw-r--r--  src/mlia/target/cortex_a/handlers.py  6
-rw-r--r--  src/mlia/target/ethos_u/advice_generation.py  34
-rw-r--r--  src/mlia/target/ethos_u/advisor.py  14
-rw-r--r--  src/mlia/target/ethos_u/handlers.py  6
-rw-r--r--  src/mlia/target/tosa/advice_generation.py  6
-rw-r--r--  src/mlia/target/tosa/advisor.py  8
-rw-r--r--  src/mlia/target/tosa/handlers.py  6
-rw-r--r--  src/mlia/utils/types.py  4
-rw-r--r--  tests/test_api.py  98
-rw-r--r--  tests/test_backend_config.py  12
-rw-r--r--  tests/test_backend_registry.py  8
-rw-r--r--  tests/test_cli_command_validators.py  167
-rw-r--r--  tests/test_cli_commands.py  97
-rw-r--r--  tests/test_cli_config.py  8
-rw-r--r--  tests/test_cli_helpers.py  62
-rw-r--r--  tests/test_cli_main.py  228
-rw-r--r--  tests/test_cli_options.py  179
-rw-r--r--  tests/test_core_advice_generation.py  10
-rw-r--r--  tests/test_core_context.py  46
-rw-r--r--  tests/test_core_helpers.py  3
-rw-r--r--  tests/test_core_mixins.py  6
-rw-r--r--  tests/test_core_reporting.py  22
-rw-r--r--  tests/test_target_config.py  6
-rw-r--r--  tests/test_target_cortex_a_advice_generation.py  18
-rw-r--r--  tests/test_target_ethos_u_advice_generation.py  70
-rw-r--r--  tests/test_target_registry.py  12
-rw-r--r--  tests/test_target_tosa_advice_generation.py  8
-rw-r--r--  tests_e2e/test_e2e.py  21
47 files changed, 1155 insertions, 980 deletions
diff --git a/README.md b/README.md
index 501c8c5..d163728 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,5 @@
<!---
-SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
SPDX-License-Identifier: Apache-2.0
--->
# ML Inference Advisor - Introduction
@@ -96,10 +96,8 @@ mlia [sub-command] [arguments]
Where the following sub-commands are available:
-* ["operators"](#operators-ops): show the model's operator list
-* ["optimization"](#model-optimization-opt): run the specified optimizations
-* ["performance"](#performance-perf): measure the performance of inference on hardware
-* ["all_tests"](#all-tests-all): have a full report
+* ["check"](#check): perform compatibility or performance checks on the model
+* ["optimize"](#optimize): apply specified optimizations
Detailed help about the different sub-commands can be shown like this:
@@ -113,25 +111,27 @@ The following sections go into further detail regarding the usage of MLIA.
This section gives an overview of the available sub-commands for MLIA.
-## **operators** (ops)
+## **check**
-Lists the model's operators with information about their compatibility with the
-specified target.
+### compatibility
+
+Default check that MLIA runs. It lists the model's operators with information
+about their compatibility with the specified target.
*Examples:*
```bash
# List operator compatibility with Ethos-U55 with 256 MAC
-mlia operators --target-profile ethos-u55-256 ~/models/mobilenet_v1_1.0_224_quant.tflite
+mlia check ~/models/mobilenet_v1_1.0_224_quant.tflite --target-profile ethos-u55-256
# List operator compatibility with Cortex-A
-mlia ops --target-profile cortex-a ~/models/mobilenet_v1_1.0_224_quant.tflite
+mlia check ~/models/mobilenet_v1_1.0_224_quant.tflite --target-profile cortex-a
# Get help and further information
-mlia ops --help
+mlia check --help
```
-## **performance** (perf)
+### performance
Estimate the model's performance on the specified target and print out
statistics.
@@ -140,18 +140,21 @@ statistics.
```bash
# Use default parameters
-mlia performance ~/models/mobilenet_v1_1.0_224_quant.tflite
+mlia check ~/models/mobilenet_v1_1.0_224_quant.tflite \
+ --target-profile ethos-u55-256 \
+ --performance
-# Explicitly specify the target profile and backend(s) to use with --evaluate-on
-mlia perf ~/models/ds_cnn_large_fully_quantized_int8.tflite \
- --evaluate-on "Vela" "Corstone-310" \
- --target-profile ethos-u65-512
+# Explicitly specify the target profile and backend(s) to use with --backend
+mlia check ~/models/ds_cnn_large_fully_quantized_int8.tflite \
+ --target-profile ethos-u65-512 \
+ --performance \
+ --backend "Vela" "Corstone-310"
# Get help and further information
-mlia perf --help
+mlia check --help
```
-## **optimization** (opt)
+## **optimize**
This sub-command applies optimizations to a Keras model (.h5 or SavedModel) and
shows the performance improvements compared to the original unoptimized model.
@@ -175,35 +178,20 @@ supported.
```bash
# Custom optimization parameters: pruning=0.6, clustering=16
-mlia optimization \
- --optimization-type pruning,clustering \
- --optimization-target 0.6,16 \
- ~/models/ds_cnn_l.h5
-
-# Get help and further information
-mlia opt --help
-```
-
-## **all_tests** (all)
-
-Combine sub-commands described above to generate a full report of the input
-model with all information available for the specified target. E.g. for Ethos-U
-this combines sub-commands *operators* and *optimization*. Therefore most
-command line arguments are shared with other sub-commands.
-
-*Examples:*
-
-```bash
-# Create full report and save it as JSON file
-mlia all_tests --output ./report.json ~/models/ds_cnn_l.h5
+mlia optimize ~/models/ds_cnn_l.h5 \
+ --target-profile ethos-u55-256 \
+ --pruning \
+ --pruning-target 0.6 \
+ --clustering \
+ --clustering-target 16
# Get help and further information
-mlia all --help
+mlia optimize --help
```
# Target profiles
-Most sub-commands accept the name of a target profile as input parameter. The
+All sub-commands require the name of a target profile as input parameter. The
profiles currently available are described in the following sections.
The support of the above sub-commands for different targets is provided via
@@ -232,7 +220,7 @@ attributes:
Example:
```bash
-mlia perf --target-profile ethos-u65-512 ~/model.tflite
+mlia check ~/model.tflite --target-profile ethos-u65-512 --performance
```
Ethos-U is supported by these backends:
diff --git a/RELEASES.md b/RELEASES.md
index 7f4c752..4d04c89 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -1,5 +1,5 @@
<!---
-SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
SPDX-License-Identifier: Apache-2.0
--->
# MLIA Releases
@@ -16,6 +16,12 @@ scheme.
of Arm® Limited (or its subsidiaries) in the U.S. and/or elsewhere.
* TensorFlow™ is a trademark of Google® LLC.
+## Release 0.6.0
+
+### Interface changes
+
+* **Breaking change:** Implement new CLI changes (MLIA-650)
+
## Release 0.5.0
### Feature changes
diff --git a/src/mlia/api.py b/src/mlia/api.py
index c7be9ec..2cabf37 100644
--- a/src/mlia/api.py
+++ b/src/mlia/api.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for the API functions."""
from __future__ import annotations
@@ -6,18 +6,14 @@ from __future__ import annotations
import logging
from pathlib import Path
from typing import Any
-from typing import Literal
from mlia.core.advisor import InferenceAdvisor
from mlia.core.common import AdviceCategory
+from mlia.core.common import FormattedFilePath
from mlia.core.context import ExecutionContext
-from mlia.core.typing import PathOrFileLike
from mlia.target.cortex_a.advisor import configure_and_get_cortexa_advisor
-from mlia.target.cortex_a.operators import report as cortex_a_report
from mlia.target.ethos_u.advisor import configure_and_get_ethosu_advisor
-from mlia.target.ethos_u.operators import report as ethos_u_report
from mlia.target.tosa.advisor import configure_and_get_tosa_advisor
-from mlia.target.tosa.operators import report as tosa_report
from mlia.utils.filesystem import get_target
logger = logging.getLogger(__name__)
@@ -26,10 +22,9 @@ logger = logging.getLogger(__name__)
def get_advice(
target_profile: str,
model: str | Path,
- category: Literal["all", "operators", "performance", "optimization"] = "all",
+ category: set[str],
optimization_targets: list[dict[str, Any]] | None = None,
- working_dir: str | Path = "mlia_output",
- output: PathOrFileLike | None = None,
+ output: FormattedFilePath | None = None,
context: ExecutionContext | None = None,
backends: list[str] | None = None,
) -> None:
@@ -42,17 +37,13 @@ def get_advice(
:param target_profile: target profile identifier
:param model: path to the NN model
- :param category: category of the advice. MLIA supports four categories:
- "all", "operators", "performance", "optimization". If not provided
- category "all" is used by default.
+ :param category: set of categories of the advice. MLIA supports three categories:
+ "compatibility", "performance", "optimization". If not provided
+ category "compatibility" is used by default.
:param optimization_targets: optional model optimization targets that
- could be used for generating advice in categories
- "all" and "optimization."
- :param working_dir: path to the directory that will be used for storing
- intermediate files during execution (e.g. converted models)
- :param output: path to the report file. If provided MLIA will save
- report in this location. Format of the report automatically
- detected based on file extension.
+ could be used for generating advice in "optimization" category.
+ :param output: path to the report file. If provided, MLIA will save
+ report in this location.
:param context: optional parameter which represents execution context,
could be used for advanced use cases
:param backends: A list of backends that should be used for the given
@@ -63,13 +54,14 @@ def get_advice(
Getting the advice for the provided target profile and the model
- >>> get_advice("ethos-u55-256", "path/to/the/model")
+ >>> get_advice("ethos-u55-256", "path/to/the/model",
+ {"optimization", "compatibility"})
Getting the advice for the category "performance" and save result report in file
"report.json"
- >>> get_advice("ethos-u55-256", "path/to/the/model", "performance",
- output="report.json")
+ >>> get_advice("ethos-u55-256", "path/to/the/model", {"performance"},
+ output=FormattedFilePath("report.json")
"""
advice_category = AdviceCategory.from_string(category)
@@ -78,10 +70,7 @@ def get_advice(
context.advice_category = advice_category
if context is None:
- context = ExecutionContext(
- advice_category=advice_category,
- working_dir=working_dir,
- )
+ context = ExecutionContext(advice_category=advice_category)
advisor = get_advisor(
context,
@@ -99,7 +88,7 @@ def get_advisor(
context: ExecutionContext,
target_profile: str,
model: str | Path,
- output: PathOrFileLike | None = None,
+ output: FormattedFilePath | None = None,
**extra_args: Any,
) -> InferenceAdvisor:
"""Find appropriate advisor for the target."""
@@ -123,17 +112,3 @@ def get_advisor(
output,
**extra_args,
)
-
-
-def generate_supported_operators_report(target_profile: str) -> None:
- """Generate a supported operators report based on given target profile."""
- generators_map = {
- "ethos-u55": ethos_u_report,
- "ethos-u65": ethos_u_report,
- "cortex-a": cortex_a_report,
- "tosa": tosa_report,
- }
-
- target = get_target(target_profile)
-
- generators_map[target]()
diff --git a/src/mlia/backend/armnn_tflite_delegate/__init__.py b/src/mlia/backend/armnn_tflite_delegate/__init__.py
index 6d5af42..ccb7e38 100644
--- a/src/mlia/backend/armnn_tflite_delegate/__init__.py
+++ b/src/mlia/backend/armnn_tflite_delegate/__init__.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Arm NN TensorFlow Lite delegate backend module."""
from mlia.backend.config import BackendConfiguration
@@ -9,7 +9,7 @@ from mlia.core.common import AdviceCategory
registry.register(
"ArmNNTFLiteDelegate",
BackendConfiguration(
- supported_advice=[AdviceCategory.OPERATORS],
+ supported_advice=[AdviceCategory.COMPATIBILITY],
supported_systems=None,
backend_type=BackendType.BUILTIN,
),
diff --git a/src/mlia/backend/tosa_checker/__init__.py b/src/mlia/backend/tosa_checker/__init__.py
index 19fc8be..c06a122 100644
--- a/src/mlia/backend/tosa_checker/__init__.py
+++ b/src/mlia/backend/tosa_checker/__init__.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""TOSA checker backend module."""
from mlia.backend.config import BackendConfiguration
@@ -10,7 +10,7 @@ from mlia.core.common import AdviceCategory
registry.register(
"TOSA-Checker",
BackendConfiguration(
- supported_advice=[AdviceCategory.OPERATORS],
+ supported_advice=[AdviceCategory.COMPATIBILITY],
supported_systems=[System.LINUX_AMD64],
backend_type=BackendType.WHEEL,
),
diff --git a/src/mlia/backend/vela/__init__.py b/src/mlia/backend/vela/__init__.py
index 38a623e..68fbcba 100644
--- a/src/mlia/backend/vela/__init__.py
+++ b/src/mlia/backend/vela/__init__.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Vela backend module."""
from mlia.backend.config import BackendConfiguration
@@ -11,7 +11,7 @@ registry.register(
"Vela",
BackendConfiguration(
supported_advice=[
- AdviceCategory.OPERATORS,
+ AdviceCategory.COMPATIBILITY,
AdviceCategory.PERFORMANCE,
AdviceCategory.OPTIMIZATION,
],
diff --git a/src/mlia/cli/command_validators.py b/src/mlia/cli/command_validators.py
new file mode 100644
index 0000000..1974a1d
--- /dev/null
+++ b/src/mlia/cli/command_validators.py
@@ -0,0 +1,113 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""CLI command validators module."""
+from __future__ import annotations
+
+import argparse
+import logging
+import sys
+
+from mlia.cli.config import get_default_backends
+from mlia.target.registry import supported_backends
+from mlia.utils.filesystem import get_target
+
+logger = logging.getLogger(__name__)
+
+
+def validate_backend(
+ target_profile: str, backend: list[str] | None
+) -> list[str] | None:
+ """Validate backend with given target profile.
+
+ This validator checks whether the given target-profile and backend are
+ compatible with each other.
+ It assumes that prior checks where made on the validity of the target-profile.
+ """
+ target_map = {
+ "ethos-u55": "Ethos-U55",
+ "ethos-u65": "Ethos-U65",
+ "cortex-a": "Cortex-A",
+ "tosa": "TOSA",
+ }
+ target = get_target(target_profile)
+
+ if not backend:
+ return get_default_backends()[target]
+
+ compatible_backends = supported_backends(target_map[target])
+
+ nor_backend = list(map(normalize_string, backend))
+ nor_compat_backend = list(map(normalize_string, compatible_backends))
+
+ incompatible_backends = [
+ backend[i] for i, x in enumerate(nor_backend) if x not in nor_compat_backend
+ ]
+ # Throw an error if any unsupported backends are used
+ if incompatible_backends:
+ raise argparse.ArgumentError(
+ None,
+ f"{', '.join(incompatible_backends)} backend not supported "
+ f"with target-profile {target_profile}.",
+ )
+ return backend
+
+
+def validate_check_target_profile(target_profile: str, category: set[str]) -> None:
+ """Validate whether advice category is compatible with the provided target_profile.
+
+ This validator function raises warnings if any desired advice category is not
+ compatible with the selected target profile. If no operation can be
+ performed as a result of the validation, MLIA exits with error code 0.
+ """
+ incompatible_targets_performance: list[str] = ["tosa", "cortex-a"]
+ incompatible_targets_compatibility: list[str] = []
+
+ # Check which check operation should be performed
+ try_performance = "performance" in category
+ try_compatibility = "compatibility" in category
+
+ # Cross check which of the desired operations can be performed on given
+ # target-profile
+ do_performance = (
+ try_performance and target_profile not in incompatible_targets_performance
+ )
+ do_compatibility = (
+ try_compatibility and target_profile not in incompatible_targets_compatibility
+ )
+
+ # Case: desired operations can be performed with given target profile
+ if (try_performance == do_performance) and (try_compatibility == do_compatibility):
+ return
+
+ warning_message = "\nWARNING: "
+ # Case: performance operation to be skipped
+ if try_performance and not do_performance:
+ warning_message += (
+ "Performance checks skipped as they cannot be "
+ f"performed with target profile {target_profile}."
+ )
+
+ # Case: compatibility operation to be skipped
+ if try_compatibility and not do_compatibility:
+ warning_message += (
+ "Compatibility checks skipped as they cannot be "
+ f"performed with target profile {target_profile}."
+ )
+
+ # Case: at least one operation will be performed
+ if do_compatibility or do_performance:
+ logger.warning(warning_message)
+ return
+
+ # Case: no operation will be performed
+ warning_message += " No operation was performed."
+ logger.warning(warning_message)
+ sys.exit(0)
+
+
+def normalize_string(value: str) -> str:
+ """Given a string return the normalized version.
+
+ E.g. Given "ToSa-cHecker" -> "tosachecker"
+ """
+ return value.lower().replace("-", "")
diff --git a/src/mlia/cli/commands.py b/src/mlia/cli/commands.py
index 09fe9de..d2242ba 100644
--- a/src/mlia/cli/commands.py
+++ b/src/mlia/cli/commands.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""CLI commands module.
@@ -13,7 +13,7 @@ be configured. Function 'setup_logging' from module
>>> from mlia.cli.logging import setup_logging
>>> setup_logging(verbose=True)
>>> import mlia.cli.commands as mlia
->>> mlia.all_tests(ExecutionContext(working_dir="mlia_output"), "ethos-u55-256",
+>>> mlia.check(ExecutionContext(), "ethos-u55-256",
"path/to/model")
"""
from __future__ import annotations
@@ -22,11 +22,12 @@ import logging
from pathlib import Path
from mlia.api import ExecutionContext
-from mlia.api import generate_supported_operators_report
from mlia.api import get_advice
-from mlia.api import PathOrFileLike
+from mlia.cli.command_validators import validate_backend
+from mlia.cli.command_validators import validate_check_target_profile
from mlia.cli.config import get_installation_manager
from mlia.cli.options import parse_optimization_parameters
+from mlia.cli.options import parse_output_parameters
from mlia.utils.console import create_section_header
logger = logging.getLogger(__name__)
@@ -34,14 +35,15 @@ logger = logging.getLogger(__name__)
CONFIG = create_section_header("ML Inference Advisor configuration")
-def all_tests(
+def check(
ctx: ExecutionContext,
target_profile: str,
- model: str,
- optimization_type: str = "pruning,clustering",
- optimization_target: str = "0.5,32",
- output: PathOrFileLike | None = None,
- evaluate_on: list[str] | None = None,
+ model: str | None = None,
+ compatibility: bool = False,
+ performance: bool = False,
+ output: Path | None = None,
+ json: bool = False,
+ backend: list[str] | None = None,
) -> None:
"""Generate a full report on the input model.
@@ -50,8 +52,6 @@ def all_tests(
- converts the input Keras model into TensorFlow Lite format
- checks the model for operator compatibility on the specified device
- - applies optimizations to the model and estimates the resulting performance
- on both the original and the optimized models
- generates a final report on the steps above
- provides advice on how to (possibly) improve the inference performance
@@ -59,140 +59,63 @@ def all_tests(
:param target_profile: target profile identifier. Will load appropriate parameters
from the profile.json file based on this argument.
:param model: path to the Keras model
- :param optimization_type: list of the optimization techniques separated
- by comma, e.g. 'pruning,clustering'
- :param optimization_target: list of the corresponding targets for
- the provided optimization techniques, e.g. '0.5,32'
+ :param compatibility: flag that identifies whether to run compatibility checks
+ :param performance: flag that identifies whether to run performance checks
:param output: path to the file where the report will be saved
- :param evaluate_on: list of the backends to use for evaluation
+ :param backend: list of the backends to use for evaluation
Example:
- Run command for the target profile ethos-u55-256 with two model optimizations
- and save report in json format locally in the file report.json
+ Run command for the target profile ethos-u55-256 to verify both performance
+ and operator compatibility.
>>> from mlia.api import ExecutionContext
>>> from mlia.cli.logging import setup_logging
>>> setup_logging()
- >>> from mlia.cli.commands import all_tests
- >>> all_tests(ExecutionContext(working_dir="mlia_output"), "ethos-u55-256",
- "model.h5", "pruning,clustering", "0.5,32",
+ >>> from mlia.cli.commands import check
+ >>> check(ExecutionContext(), "ethos-u55-256",
+ "model.h5", compatibility=True, performance=True,
output="report.json")
"""
- opt_params = parse_optimization_parameters(
- optimization_type,
- optimization_target,
- )
-
- get_advice(
- target_profile,
- model,
- "all",
- optimization_targets=opt_params,
- output=output,
- context=ctx,
- backends=evaluate_on,
- )
-
-
-def operators(
- ctx: ExecutionContext,
- target_profile: str,
- model: str | None = None,
- output: PathOrFileLike | None = None,
- supported_ops_report: bool = False,
-) -> None:
- """Print the model's operator list.
-
- This command checks the operator compatibility of the input model with
- the specific target profile. Generates a report of the operator placement
- (NPU or CPU fallback) and advice on how to improve it (if necessary).
-
- :param ctx: execution context
- :param target_profile: target profile identifier. Will load appropriate parameters
- from the profile.json file based on this argument.
- :param model: path to the model, which can be TensorFlow Lite or Keras
- :param output: path to the file where the report will be saved
- :param supported_ops_report: if True then generates supported operators
- report in current directory and exits
-
- Example:
- Run command for the target profile ethos-u55-256 and the provided
- TensorFlow Lite model and print report on the standard output
-
- >>> from mlia.api import ExecutionContext
- >>> from mlia.cli.logging import setup_logging
- >>> setup_logging()
- >>> from mlia.cli.commands import operators
- >>> operators(ExecutionContext(working_dir="mlia_output"), "ethos-u55-256",
- "model.tflite")
- """
- if supported_ops_report:
- generate_supported_operators_report(target_profile)
- logger.info("Report saved into SUPPORTED_OPS.md")
- return
-
if not model:
raise Exception("Model is not provided")
- get_advice(
- target_profile,
- model,
- "operators",
- output=output,
- context=ctx,
- )
-
+ formatted_output = parse_output_parameters(output, json)
-def performance(
- ctx: ExecutionContext,
- target_profile: str,
- model: str,
- output: PathOrFileLike | None = None,
- evaluate_on: list[str] | None = None,
-) -> None:
- """Print the model's performance stats.
-
- This command estimates the inference performance of the input model
- on the specified target profile, and generates a report with advice on how
- to improve it.
-
- :param ctx: execution context
- :param target_profile: target profile identifier. Will load appropriate parameters
- from the profile.json file based on this argument.
- :param model: path to the model, which can be TensorFlow Lite or Keras
- :param output: path to the file where the report will be saved
- :param evaluate_on: list of the backends to use for evaluation
+ # Set category based on checks to perform (i.e. "compatibility" and/or
+ # "performance").
+ # If no check type is specified, "compatibility" is the default category.
+ if compatibility and performance:
+ category = {"compatibility", "performance"}
+ elif performance:
+ category = {"performance"}
+ else:
+ category = {"compatibility"}
- Example:
- Run command for the target profile ethos-u55-256 and
- the provided TensorFlow Lite model and print report on the standard output
+ validate_check_target_profile(target_profile, category)
+ validated_backend = validate_backend(target_profile, backend)
- >>> from mlia.api import ExecutionContext
- >>> from mlia.cli.logging import setup_logging
- >>> setup_logging()
- >>> from mlia.cli.commands import performance
- >>> performance(ExecutionContext(working_dir="mlia_output"), "ethos-u55-256",
- "model.tflite")
- """
get_advice(
target_profile,
model,
- "performance",
- output=output,
+ category,
+ output=formatted_output,
context=ctx,
- backends=evaluate_on,
+ backends=validated_backend,
)
-def optimization(
+def optimize( # pylint: disable=too-many-arguments
ctx: ExecutionContext,
target_profile: str,
model: str,
- optimization_type: str,
- optimization_target: str,
+ pruning: bool,
+ clustering: bool,
+ pruning_target: float | None,
+ clustering_target: int | None,
layers_to_optimize: list[str] | None = None,
- output: PathOrFileLike | None = None,
- evaluate_on: list[str] | None = None,
+ output: Path | None = None,
+ json: bool = False,
+ backend: list[str] | None = None,
) -> None:
"""Show the performance improvements (if any) after applying the optimizations.
@@ -201,43 +124,54 @@ def optimization(
the inference performance (if possible).
:param ctx: execution context
- :param target: target profile identifier. Will load appropriate parameters
+ :param target_profile: target profile identifier. Will load appropriate parameters
from the profile.json file based on this argument.
:param model: path to the TensorFlow Lite model
- :param optimization_type: list of the optimization techniques separated
- by comma, e.g. 'pruning,clustering'
- :param optimization_target: list of the corresponding targets for
- the provided optimization techniques, e.g. '0.5,32'
+ :param pruning: perform pruning optimization (default if no option specified)
+ :param clustering: perform clustering optimization
+ :param clustering_target: clustering optimization target
+ :param pruning_target: pruning optimization target
:param layers_to_optimize: list of the layers of the model which should be
optimized, if None then all layers are used
:param output: path to the file where the report will be saved
- :param evaluate_on: list of the backends to use for evaluation
+ :param json: set the output format to json
+ :param backend: list of the backends to use for evaluation
Example:
Run command for the target profile ethos-u55-256 and
the provided TensorFlow Lite model and print report on the standard output
>>> from mlia.cli.logging import setup_logging
+ >>> from mlia.api import ExecutionContext
>>> setup_logging()
- >>> from mlia.cli.commands import optimization
- >>> optimization(ExecutionContext(working_dir="mlia_output"),
- target="ethos-u55-256",
- "model.tflite", "pruning", "0.5")
+ >>> from mlia.cli.commands import optimize
+ >>> optimize(ExecutionContext(),
+ target_profile="ethos-u55-256",
+ model="model.tflite", pruning=True,
+ clustering=False, pruning_target=0.5,
+ clustering_target=None)
"""
- opt_params = parse_optimization_parameters(
- optimization_type,
- optimization_target,
- layers_to_optimize=layers_to_optimize,
+ opt_params = (
+ parse_optimization_parameters( # pylint: disable=too-many-function-args
+ pruning,
+ clustering,
+ pruning_target,
+ clustering_target,
+ layers_to_optimize,
+ )
)
+ formatted_output = parse_output_parameters(output, json)
+ validated_backend = validate_backend(target_profile, backend)
+
get_advice(
target_profile,
model,
- "optimization",
+ {"optimization"},
optimization_targets=opt_params,
- output=output,
+ output=formatted_output,
context=ctx,
- backends=evaluate_on,
+ backends=validated_backend,
)
diff --git a/src/mlia/cli/config.py b/src/mlia/cli/config.py
index 2d694dc..680b4b6 100644
--- a/src/mlia/cli/config.py
+++ b/src/mlia/cli/config.py
@@ -1,10 +1,13 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Environment configuration functions."""
from __future__ import annotations
import logging
from functools import lru_cache
+from typing import List
+from typing import Optional
+from typing import TypedDict
from mlia.backend.corstone.install import get_corstone_installations
from mlia.backend.install import supported_backends
@@ -14,6 +17,9 @@ from mlia.backend.tosa_checker.install import get_tosa_backend_installation
logger = logging.getLogger(__name__)
+DEFAULT_PRUNING_TARGET = 0.5
+DEFAULT_CLUSTERING_TARGET = 32
+
def get_installation_manager(noninteractive: bool = False) -> InstallationManager:
"""Return installation manager."""
@@ -26,7 +32,7 @@ def get_installation_manager(noninteractive: bool = False) -> InstallationManage
@lru_cache
def get_available_backends() -> list[str]:
"""Return list of the available backends."""
- available_backends = ["Vela"]
+ available_backends = ["Vela", "tosa-checker", "armnn-tflitedelegate"]
# Add backends using backend manager
manager = get_installation_manager()
@@ -41,9 +47,10 @@ def get_available_backends() -> list[str]:
# List of mutually exclusive Corstone backends ordered by priority
_CORSTONE_EXCLUSIVE_PRIORITY = ("Corstone-310", "Corstone-300")
+_NON_ETHOS_U_BACKENDS = ("tosa-checker", "armnn-tflitedelegate")
-def get_default_backends() -> list[str]:
+def get_ethos_u_default_backends() -> list[str]:
"""Get default backends for evaluation."""
backends = get_available_backends()
@@ -57,9 +64,34 @@ def get_default_backends() -> list[str]:
]
break
+ # Filter out non ethos-u backends
+ backends = [x for x in backends if x not in _NON_ETHOS_U_BACKENDS]
return backends
def is_corstone_backend(backend: str) -> bool:
"""Check if the given backend is a Corstone backend."""
return backend in _CORSTONE_EXCLUSIVE_PRIORITY
+
+
+BackendCompatibility = TypedDict(
+ "BackendCompatibility",
+ {
+ "partial-match": bool,
+ "backends": List[str],
+ "default-return": Optional[List[str]],
+ "use-custom-return": bool,
+ "custom-return": Optional[List[str]],
+ },
+)
+
+
+def get_default_backends() -> dict[str, list[str]]:
+ """Return default backends for all targets."""
+ ethos_u_defaults = get_ethos_u_default_backends()
+ return {
+ "ethos-u55": ethos_u_defaults,
+ "ethos-u65": ethos_u_defaults,
+ "tosa": ["tosa-checker"],
+ "cortex-a": ["armnn-tflitedelegate"],
+ }
diff --git a/src/mlia/cli/helpers.py b/src/mlia/cli/helpers.py
index acec837..ac64581 100644
--- a/src/mlia/cli/helpers.py
+++ b/src/mlia/cli/helpers.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for various helper classes."""
from __future__ import annotations
@@ -29,9 +29,9 @@ class CLIActionResolver(ActionResolver):
return [
*keras_note,
- "For example: mlia optimization --optimization-type "
- f"pruning,clustering --optimization-target 0.5,32 {model_path}",
- "For more info: mlia optimization --help",
+ f"For example: mlia optimize {model_path} --pruning --clustering "
+ "--pruning-target 0.5 --clustering-target 32",
+ "For more info: mlia optimize --help",
]
@staticmethod
@@ -41,14 +41,17 @@ class CLIActionResolver(ActionResolver):
opt_settings: list[OptimizationSettings],
) -> list[str]:
"""Return specific optimization command description."""
- opt_types = ",".join(opt.optimization_type for opt in opt_settings)
- opt_targs = ",".join(str(opt.optimization_target) for opt in opt_settings)
+ opt_types = " ".join("--" + opt.optimization_type for opt in opt_settings)
+ opt_targs_strings = ["--pruning-target", "--clustering-target"]
+ opt_targs = ",".join(
+ f"{opt_targs_strings[i]} {opt.optimization_target}"
+ for i, opt in enumerate(opt_settings)
+ )
return [
- "For more info: mlia optimization --help",
+ "For more info: mlia optimize --help",
"Optimization command: "
- f"mlia optimization --optimization-type {opt_types} "
- f"--optimization-target {opt_targs}{device_opts} {model_path}",
+ f"mlia optimize {model_path}{device_opts} {opt_types} {opt_targs}",
]
def apply_optimizations(self, **kwargs: Any) -> list[str]:
@@ -65,13 +68,6 @@ class CLIActionResolver(ActionResolver):
return []
- def supported_operators_info(self) -> list[str]:
- """Return command details for generating supported ops report."""
- return [
- "For guidance on supported operators, run: mlia operators "
- "--supported-ops-report",
- ]
-
def check_performance(self) -> list[str]:
"""Return command details for checking performance."""
model_path, device_opts = self._get_model_and_device_opts()
@@ -80,7 +76,7 @@ class CLIActionResolver(ActionResolver):
return [
"Check the estimated performance by running the following command: ",
- f"mlia performance{device_opts} {model_path}",
+ f"mlia check {model_path}{device_opts} --performance",
]
def check_operator_compatibility(self) -> list[str]:
@@ -91,16 +87,16 @@ class CLIActionResolver(ActionResolver):
return [
"Try running the following command to verify that:",
- f"mlia operators{device_opts} {model_path}",
+ f"mlia check {model_path}{device_opts}",
]
def operator_compatibility_details(self) -> list[str]:
"""Return command details for op compatibility."""
- return ["For more details, run: mlia operators --help"]
+ return ["For more details, run: mlia check --help"]
def optimization_details(self) -> list[str]:
"""Return command details for optimization."""
- return ["For more info, see: mlia optimization --help"]
+ return ["For more info, see: mlia optimize --help"]
def _get_model_and_device_opts(
self, separate_device_opts: bool = True
diff --git a/src/mlia/cli/main.py b/src/mlia/cli/main.py
index ac60308..1102d45 100644
--- a/src/mlia/cli/main.py
+++ b/src/mlia/cli/main.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""CLI main entry point."""
from __future__ import annotations
@@ -8,32 +8,28 @@ import logging
import sys
from functools import partial
from inspect import signature
-from pathlib import Path
from mlia import __version__
from mlia.backend.errors import BackendUnavailableError
from mlia.backend.registry import registry as backend_registry
-from mlia.cli.commands import all_tests
from mlia.cli.commands import backend_install
from mlia.cli.commands import backend_list
from mlia.cli.commands import backend_uninstall
-from mlia.cli.commands import operators
-from mlia.cli.commands import optimization
-from mlia.cli.commands import performance
+from mlia.cli.commands import check
+from mlia.cli.commands import optimize
from mlia.cli.common import CommandInfo
from mlia.cli.helpers import CLIActionResolver
from mlia.cli.logging import setup_logging
from mlia.cli.options import add_backend_install_options
+from mlia.cli.options import add_backend_options
from mlia.cli.options import add_backend_uninstall_options
-from mlia.cli.options import add_custom_supported_operators_options
+from mlia.cli.options import add_check_category_options
from mlia.cli.options import add_debug_options
-from mlia.cli.options import add_evaluation_options
from mlia.cli.options import add_keras_model_options
+from mlia.cli.options import add_model_options
from mlia.cli.options import add_multi_optimization_options
-from mlia.cli.options import add_optional_tflite_model_options
from mlia.cli.options import add_output_options
from mlia.cli.options import add_target_options
-from mlia.cli.options import add_tflite_model_options
from mlia.core.context import ExecutionContext
from mlia.core.errors import ConfigurationError
from mlia.core.errors import InternalError
@@ -60,50 +56,30 @@ def get_commands() -> list[CommandInfo]:
"""Return commands configuration."""
return [
CommandInfo(
- all_tests,
- ["all"],
- [
- add_target_options,
- add_keras_model_options,
- add_multi_optimization_options,
- add_output_options,
- add_debug_options,
- add_evaluation_options,
- ],
- True,
- ),
- CommandInfo(
- operators,
- ["ops"],
+ check,
+ [],
[
+ add_model_options,
add_target_options,
- add_optional_tflite_model_options,
+ add_backend_options,
+ add_check_category_options,
add_output_options,
- add_custom_supported_operators_options,
add_debug_options,
],
),
CommandInfo(
- performance,
- ["perf"],
- [
- partial(add_target_options, profiles_to_skip=["tosa", "cortex-a"]),
- add_tflite_model_options,
- add_output_options,
- add_debug_options,
- add_evaluation_options,
- ],
- ),
- CommandInfo(
- optimization,
- ["opt"],
+ optimize,
+ [],
[
- partial(add_target_options, profiles_to_skip=["tosa", "cortex-a"]),
add_keras_model_options,
+ partial(add_target_options, profiles_to_skip=["tosa", "cortex-a"]),
+ partial(
+ add_backend_options,
+ backends_to_skip=["tosa-checker", "armnn-tflitedelegate"],
+ ),
add_multi_optimization_options,
add_output_options,
add_debug_options,
- add_evaluation_options,
],
),
]
@@ -184,13 +160,12 @@ def setup_context(
) -> tuple[ExecutionContext, dict]:
"""Set up context and resolve function parameters."""
ctx = ExecutionContext(
- working_dir=args.working_dir,
- verbose="verbose" in args and args.verbose,
+ verbose="debug" in args and args.debug,
action_resolver=CLIActionResolver(vars(args)),
)
# these parameters should not be passed into command function
- skipped_params = ["func", "command", "working_dir", "verbose"]
+ skipped_params = ["func", "command", "debug"]
# pass these parameters only if command expects them
expected_params = [context_var_name]
@@ -219,6 +194,9 @@ def run_command(args: argparse.Namespace) -> int:
try:
logger.info(INFO_MESSAGE)
+ logger.info(
+ "\nThis execution of MLIA uses working directory: %s", ctx.working_dir
+ )
args.func(**func_args)
return 0
except KeyboardInterrupt:
@@ -246,22 +224,19 @@ def run_command(args: argparse.Namespace) -> int:
f"Please check the log files in the {ctx.logs_path} for more details"
)
if not ctx.verbose:
- err_advice_message += ", or enable verbose mode (--verbose)"
+ err_advice_message += ", or enable debug mode (--debug)"
logger.error(err_advice_message)
-
+ finally:
+ logger.info(
+ "This execution of MLIA used working directory: %s", ctx.working_dir
+ )
return 1
def init_common_parser() -> argparse.ArgumentParser:
"""Init common parser."""
parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
- parser.add_argument(
- "--working-dir",
- default=f"{Path.cwd() / 'mlia_output'}",
- help="Path to the directory where MLIA will store logs, "
- "models, etc. (default: %(default)s)",
- )
return parser
diff --git a/src/mlia/cli/options.py b/src/mlia/cli/options.py
index 8ea4250..bae6219 100644
--- a/src/mlia/cli/options.py
+++ b/src/mlia/cli/options.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for the CLI options."""
from __future__ import annotations
@@ -8,37 +8,48 @@ from pathlib import Path
from typing import Any
from typing import Callable
+from mlia.cli.config import DEFAULT_CLUSTERING_TARGET
+from mlia.cli.config import DEFAULT_PRUNING_TARGET
from mlia.cli.config import get_available_backends
-from mlia.cli.config import get_default_backends
from mlia.cli.config import is_corstone_backend
-from mlia.core.reporting import OUTPUT_FORMATS
+from mlia.core.common import FormattedFilePath
from mlia.utils.filesystem import get_supported_profile_names
-from mlia.utils.types import is_number
+
+
+def add_check_category_options(parser: argparse.ArgumentParser) -> None:
+ """Add check category type options."""
+ parser.add_argument(
+ "--performance", action="store_true", help="Perform performance checks."
+ )
+
+ parser.add_argument(
+ "--compatibility",
+ action="store_true",
+ help="Perform compatibility checks. (default)",
+ )
def add_target_options(
- parser: argparse.ArgumentParser, profiles_to_skip: list[str] | None = None
+ parser: argparse.ArgumentParser,
+ profiles_to_skip: list[str] | None = None,
+ required: bool = True,
) -> None:
"""Add target specific options."""
target_profiles = get_supported_profile_names()
if profiles_to_skip:
target_profiles = [tp for tp in target_profiles if tp not in profiles_to_skip]
- default_target_profile = None
- default_help = ""
- if target_profiles:
- default_target_profile = target_profiles[0]
- default_help = " (default: %(default)s)"
-
target_group = parser.add_argument_group("target options")
target_group.add_argument(
+ "-t",
"--target-profile",
choices=target_profiles,
- default=default_target_profile,
+ required=required,
+ default="",
help="Target profile that will set the target options "
"such as target, mac value, memory mode, etc. "
- f"For the values associated with each target profile "
- f" please refer to the documentation {default_help}.",
+ "For the values associated with each target profile "
+ "please refer to the documentation.",
)
@@ -47,59 +58,47 @@ def add_multi_optimization_options(parser: argparse.ArgumentParser) -> None:
multi_optimization_group = parser.add_argument_group("optimization options")
multi_optimization_group.add_argument(
- "--optimization-type",
- default="pruning,clustering",
- help="List of the optimization types separated by comma (default: %(default)s)",
+ "--pruning", action="store_true", help="Apply pruning optimization."
)
+
multi_optimization_group.add_argument(
- "--optimization-target",
- default="0.5,32",
- help="""List of the optimization targets separated by comma,
- (for pruning this is sparsity between (0,1),
- for clustering this is the number of clusters (positive integer))
- (default: %(default)s)""",
+ "--clustering", action="store_true", help="Apply clustering optimization."
)
+ multi_optimization_group.add_argument(
+ "--pruning-target",
+ type=float,
+ help="Sparsity to be reached during optimization "
+ f"(default: {DEFAULT_PRUNING_TARGET})",
+ )
-def add_optional_tflite_model_options(parser: argparse.ArgumentParser) -> None:
- """Add optional model specific options."""
- model_group = parser.add_argument_group("TensorFlow Lite model options")
- # make model parameter optional
- model_group.add_argument(
- "model", nargs="?", help="TensorFlow Lite model (optional)"
+ multi_optimization_group.add_argument(
+ "--clustering-target",
+ type=int,
+ help="Number of clusters to reach during optimization "
+ f"(default: {DEFAULT_CLUSTERING_TARGET})",
)
-def add_tflite_model_options(parser: argparse.ArgumentParser) -> None:
+def add_model_options(parser: argparse.ArgumentParser) -> None:
"""Add model specific options."""
- model_group = parser.add_argument_group("TensorFlow Lite model options")
- model_group.add_argument("model", help="TensorFlow Lite model")
+ parser.add_argument("model", help="TensorFlow Lite model or Keras model")
def add_output_options(parser: argparse.ArgumentParser) -> None:
"""Add output specific options."""
- valid_extensions = OUTPUT_FORMATS
-
- def check_extension(filename: str) -> str:
- """Check extension of the provided file."""
- suffix = Path(filename).suffix
- if suffix.startswith("."):
- suffix = suffix[1:]
-
- if suffix.lower() not in valid_extensions:
- parser.error(f"Unsupported format '{suffix}'")
-
- return filename
-
output_group = parser.add_argument_group("output options")
output_group.add_argument(
+ "-o",
"--output",
- type=check_extension,
- help=(
- "Name of the file where report will be saved. "
- "Report format is automatically detected based on the file extension. "
- f"Supported formats are: {', '.join(valid_extensions)}"
- ),
+ type=Path,
+ help=("Name of the file where the report will be saved."),
+ )
+
+ output_group.add_argument(
+ "--json",
+ action="store_true",
+ help=("Format to use for the output (requires --output argument to be set)."),
)
@@ -107,7 +106,11 @@ def add_debug_options(parser: argparse.ArgumentParser) -> None:
"""Add debug options."""
debug_group = parser.add_argument_group("debug options")
debug_group.add_argument(
- "--verbose", default=False, action="store_true", help="Produce verbose output"
+ "-d",
+ "--debug",
+ default=False,
+ action="store_true",
+ help="Produce verbose output",
)
@@ -117,20 +120,6 @@ def add_keras_model_options(parser: argparse.ArgumentParser) -> None:
model_group.add_argument("model", help="Keras model")
-def add_custom_supported_operators_options(parser: argparse.ArgumentParser) -> None:
- """Add custom options for the command 'operators'."""
- parser.add_argument(
- "--supported-ops-report",
- action="store_true",
- default=False,
- help=(
- "Generate the SUPPORTED_OPS.md file in the "
- "current working directory and exit "
- "(Ethos-U target profiles only)"
- ),
- )
-
-
def add_backend_install_options(parser: argparse.ArgumentParser) -> None:
"""Add options for the backends configuration."""
@@ -176,10 +165,11 @@ def add_backend_uninstall_options(parser: argparse.ArgumentParser) -> None:
)
-def add_evaluation_options(parser: argparse.ArgumentParser) -> None:
+def add_backend_options(
+ parser: argparse.ArgumentParser, backends_to_skip: list[str] | None = None
+) -> None:
"""Add evaluation options."""
available_backends = get_available_backends()
- default_backends = get_default_backends()
def only_one_corstone_checker() -> Callable:
"""
@@ -202,41 +192,70 @@ def add_evaluation_options(parser: argparse.ArgumentParser) -> None:
return check
- evaluation_group = parser.add_argument_group("evaluation options")
+ # Remove backends to skip
+ if backends_to_skip:
+ available_backends = [
+ x for x in available_backends if x not in backends_to_skip
+ ]
+
+ evaluation_group = parser.add_argument_group("backend options")
evaluation_group.add_argument(
- "--evaluate-on",
- help="Backends to use for evaluation (default: %(default)s)",
- nargs="*",
+ "-b",
+ "--backend",
+ help="Backends to use for evaluation.",
+ nargs="+",
choices=available_backends,
- default=default_backends,
type=only_one_corstone_checker(),
)
+def parse_output_parameters(path: Path | None, json: bool) -> FormattedFilePath | None:
+ """Parse and return path and file format as FormattedFilePath."""
+ if not path and json:
+ raise argparse.ArgumentError(
+ None,
+ "To enable JSON output you need to specify the output path. "
+ "(e.g. --output out.json --json)",
+ )
+ if not path:
+ return None
+ if json:
+ return FormattedFilePath(path, "json")
+
+ return FormattedFilePath(path, "plain_text")
+
+
def parse_optimization_parameters(
- optimization_type: str,
- optimization_target: str,
- sep: str = ",",
+ pruning: bool = False,
+ clustering: bool = False,
+ pruning_target: float | None = None,
+ clustering_target: int | None = None,
layers_to_optimize: list[str] | None = None,
) -> list[dict[str, Any]]:
"""Parse provided optimization parameters."""
- if not optimization_type:
- raise Exception("Optimization type is not provided")
+ opt_types = []
+ opt_targets = []
- if not optimization_target:
- raise Exception("Optimization target is not provided")
+ if clustering_target and not clustering:
+ raise argparse.ArgumentError(
+ None,
+ "To enable clustering optimization you need to include the "
+ "`--clustering` flag in your command.",
+ )
- opt_types = optimization_type.split(sep)
- opt_targets = optimization_target.split(sep)
+ if not pruning_target:
+ pruning_target = DEFAULT_PRUNING_TARGET
- if len(opt_types) != len(opt_targets):
- raise Exception("Wrong number of optimization targets and types")
+ if not clustering_target:
+ clustering_target = DEFAULT_CLUSTERING_TARGET
- non_numeric_targets = [
- opt_target for opt_target in opt_targets if not is_number(opt_target)
- ]
- if len(non_numeric_targets) > 0:
- raise Exception("Non numeric value for the optimization target")
+ if (pruning is False and clustering is False) or pruning:
+ opt_types.append("pruning")
+ opt_targets.append(pruning_target)
+
+ if clustering:
+ opt_types.append("clustering")
+ opt_targets.append(clustering_target)
optimizer_params = [
{
@@ -256,7 +275,7 @@ def get_target_profile_opts(device_args: dict | None) -> list[str]:
return []
parser = argparse.ArgumentParser()
- add_target_options(parser)
+ add_target_options(parser, required=False)
args = parser.parse_args([])
params_name = {
diff --git a/src/mlia/core/common.py b/src/mlia/core/common.py
index 6c9dde1..53df001 100644
--- a/src/mlia/core/common.py
+++ b/src/mlia/core/common.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Common module.
@@ -13,6 +13,9 @@ from enum import auto
from enum import Flag
from typing import Any
+from mlia.core.typing import OutputFormat
+from mlia.core.typing import PathOrFileLike
+
# This type is used as type alias for the items which are being passed around
# in advisor workflow. There are no restrictions on the type of the
# object. This alias used only to emphasize the nature of the input/output
@@ -20,31 +23,55 @@ from typing import Any
DataItem = Any
+class FormattedFilePath:
+ """Class used to keep track of the format that a path points to."""
+
+ def __init__(self, path: PathOrFileLike, fmt: OutputFormat = "plain_text") -> None:
+ """Init FormattedFilePath."""
+ self._path = path
+ self._fmt = fmt
+
+ @property
+ def fmt(self) -> OutputFormat:
+ """Return file format."""
+ return self._fmt
+
+ @property
+ def path(self) -> PathOrFileLike:
+ """Return file path."""
+ return self._path
+
+ def __eq__(self, other: object) -> bool:
+ """Check for equality with other objects."""
+ if isinstance(other, FormattedFilePath):
+ return other.fmt == self.fmt and other.path == self.path
+
+ return False
+
+ def __repr__(self) -> str:
+ """Represent object."""
+ return f"FormattedFilePath {self.path=}, {self.fmt=}"
+
+
class AdviceCategory(Flag):
"""Advice category.
Enumeration of advice categories supported by ML Inference Advisor.
"""
- OPERATORS = auto()
+ COMPATIBILITY = auto()
PERFORMANCE = auto()
OPTIMIZATION = auto()
- ALL = (
- # pylint: disable=unsupported-binary-operation
- OPERATORS
- | PERFORMANCE
- | OPTIMIZATION
- # pylint: enable=unsupported-binary-operation
- )
@classmethod
- def from_string(cls, value: str) -> AdviceCategory:
+ def from_string(cls, values: set[str]) -> set[AdviceCategory]:
"""Resolve enum value from string value."""
category_names = [item.name for item in AdviceCategory]
- if not value or value.upper() not in category_names:
- raise Exception(f"Invalid advice category {value}")
+ for advice_value in values:
+ if advice_value.upper() not in category_names:
+ raise Exception(f"Invalid advice category {advice_value}")
- return AdviceCategory[value.upper()]
+ return {AdviceCategory[value.upper()] for value in values}
class NamedEntity(ABC):
diff --git a/src/mlia/core/context.py b/src/mlia/core/context.py
index a4737bb..94aa885 100644
--- a/src/mlia/core/context.py
+++ b/src/mlia/core/context.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Context module.
@@ -10,6 +10,7 @@ parameters).
from __future__ import annotations
import logging
+import tempfile
from abc import ABC
from abc import abstractmethod
from pathlib import Path
@@ -54,7 +55,7 @@ class Context(ABC):
@property
@abstractmethod
- def advice_category(self) -> AdviceCategory:
+ def advice_category(self) -> set[AdviceCategory]:
"""Return advice category."""
@property
@@ -71,7 +72,7 @@ class Context(ABC):
def update(
self,
*,
- advice_category: AdviceCategory,
+ advice_category: set[AdviceCategory],
event_handlers: list[EventHandler],
config_parameters: Mapping[str, Any],
) -> None:
@@ -79,11 +80,11 @@ class Context(ABC):
def category_enabled(self, category: AdviceCategory) -> bool:
"""Check if category enabled."""
- return category == self.advice_category
+ return category in self.advice_category
def any_category_enabled(self, *categories: AdviceCategory) -> bool:
"""Return true if any category is enabled."""
- return self.advice_category in categories
+ return all(category in self.advice_category for category in categories)
def register_event_handlers(self) -> None:
"""Register event handlers."""
@@ -96,7 +97,7 @@ class ExecutionContext(Context):
def __init__(
self,
*,
- advice_category: AdviceCategory = AdviceCategory.ALL,
+ advice_category: set[AdviceCategory] = None,
config_parameters: Mapping[str, Any] | None = None,
working_dir: str | Path | None = None,
event_handlers: list[EventHandler] | None = None,
@@ -108,7 +109,7 @@ class ExecutionContext(Context):
) -> None:
"""Init execution context.
- :param advice_category: requested advice category
+ :param advice_category: requested advice categories
:param config_parameters: dictionary like object with input parameters
:param working_dir: path to the directory that will be used as a place
to store temporary files, logs, models. If not provided then
@@ -124,13 +125,13 @@ class ExecutionContext(Context):
:param action_resolver: instance of the action resolver that could make
advice actionable
"""
- self._advice_category = advice_category
+ self._advice_category = advice_category or {AdviceCategory.COMPATIBILITY}
self._config_parameters = config_parameters
- self._working_dir_path = Path.cwd()
if working_dir:
self._working_dir_path = Path(working_dir)
- self._working_dir_path.mkdir(exist_ok=True)
+ else:
+ self._working_dir_path = generate_temp_workdir()
self._event_handlers = event_handlers
self._event_publisher = event_publisher or DefaultEventPublisher()
@@ -140,12 +141,17 @@ class ExecutionContext(Context):
self._action_resolver = action_resolver or APIActionResolver()
@property
- def advice_category(self) -> AdviceCategory:
+ def working_dir(self) -> Path:
+ """Return working dir path."""
+ return self._working_dir_path
+
+ @property
+ def advice_category(self) -> set[AdviceCategory]:
"""Return advice category."""
return self._advice_category
@advice_category.setter
- def advice_category(self, advice_category: AdviceCategory) -> None:
+ def advice_category(self, advice_category: set[AdviceCategory]) -> None:
"""Setter for the advice category."""
self._advice_category = advice_category
@@ -194,7 +200,7 @@ class ExecutionContext(Context):
def update(
self,
*,
- advice_category: AdviceCategory,
+ advice_category: set[AdviceCategory],
event_handlers: list[EventHandler],
config_parameters: Mapping[str, Any],
) -> None:
@@ -206,7 +212,9 @@ class ExecutionContext(Context):
def __str__(self) -> str:
"""Return string representation."""
category = (
- "<not set>" if self.advice_category is None else self.advice_category.name
+ "<not set>"
+ if self.advice_category is None
+ else {x.name for x in self.advice_category}
)
return (
@@ -215,3 +223,9 @@ class ExecutionContext(Context):
f"config_parameters={self.config_parameters}, "
f"verbose={self.verbose}"
)
+
+
+def generate_temp_workdir() -> Path:
+ """Generate a temporary working directory and return its path."""
+ working_dir = tempfile.mkdtemp(suffix=None, prefix="mlia-", dir=None)
+ return Path(working_dir)
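With these defaults an ExecutionContext created without arguments runs compatibility checks only and works out of a throwaway "mlia-" directory under the system temp location. A short sketch:

from mlia.core.common import AdviceCategory
from mlia.core.context import ExecutionContext

ctx = ExecutionContext()  # no advice_category or working_dir supplied

assert ctx.advice_category == {AdviceCategory.COMPATIBILITY}
assert ctx.category_enabled(AdviceCategory.COMPATIBILITY)
assert not ctx.category_enabled(AdviceCategory.PERFORMANCE)

# Freshly created temporary working directory, e.g. /tmp/mlia-xxxxxxxx
# (the exact location is platform dependent).
print(ctx.working_dir)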
diff --git a/src/mlia/core/handlers.py b/src/mlia/core/handlers.py
index a3255ae..6e50934 100644
--- a/src/mlia/core/handlers.py
+++ b/src/mlia/core/handlers.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Event handlers module."""
from __future__ import annotations
@@ -9,6 +9,7 @@ from typing import Callable
from mlia.core.advice_generation import Advice
from mlia.core.advice_generation import AdviceEvent
+from mlia.core.common import FormattedFilePath
from mlia.core.events import ActionFinishedEvent
from mlia.core.events import ActionStartedEvent
from mlia.core.events import AdviceStageFinishedEvent
@@ -26,7 +27,6 @@ from mlia.core.events import ExecutionFinishedEvent
from mlia.core.events import ExecutionStartedEvent
from mlia.core.reporting import Report
from mlia.core.reporting import Reporter
-from mlia.core.reporting import resolve_output_format
from mlia.core.typing import PathOrFileLike
from mlia.utils.console import create_section_header
@@ -101,12 +101,12 @@ class WorkflowEventsHandler(SystemEventsHandler):
def __init__(
self,
formatter_resolver: Callable[[Any], Callable[[Any], Report]],
- output: PathOrFileLike | None = None,
+ output: FormattedFilePath | None = None,
) -> None:
"""Init event handler."""
- output_format = resolve_output_format(output)
+ output_format = output.fmt if output else "plain_text"
self.reporter = Reporter(formatter_resolver, output_format)
- self.output = output
+ self.output = output.path if output else None
self.advice: list[Advice] = []
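The handler no longer guesses the report format from a file suffix; it reads it straight off the FormattedFilePath it is given. A sketch of the two expressions above, assuming the constructor takes the path followed by the format (as the option tests later in this patch do):

from pathlib import Path

from mlia.core.common import FormattedFilePath

output = FormattedFilePath(Path("report.json"), "json")

# Mirrors the handler logic: the format comes from the wrapper, not the suffix.
output_format = output.fmt if output else "plain_text"
output_path = output.path if output else None
assert output_format == "json" and output_path == Path("report.json")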
diff --git a/src/mlia/core/helpers.py b/src/mlia/core/helpers.py
index f4a9df6..ed43d04 100644
--- a/src/mlia/core/helpers.py
+++ b/src/mlia/core/helpers.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for various helper classes."""
# pylint: disable=unused-argument
@@ -14,10 +14,6 @@ class ActionResolver:
"""Return action details for applying optimizations."""
return []
- def supported_operators_info(self) -> list[str]:
- """Return action details for generating supported ops report."""
- return []
-
def check_performance(self) -> list[str]:
"""Return action details for checking performance."""
return []
diff --git a/src/mlia/core/reporting.py b/src/mlia/core/reporting.py
index b96a6b5..19644b2 100644
--- a/src/mlia/core/reporting.py
+++ b/src/mlia/core/reporting.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Reporting module."""
from __future__ import annotations
@@ -639,14 +639,3 @@ def _apply_format_parameters(
return report
return wrapper
-
-
-def resolve_output_format(output: PathOrFileLike | None) -> OutputFormat:
- """Resolve output format based on the output name."""
- if isinstance(output, (str, Path)):
- format_from_filename = Path(output).suffix.lstrip(".")
-
- if format_from_filename in OUTPUT_FORMATS:
- return cast(OutputFormat, format_from_filename)
-
- return "plain_text"
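resolve_output_format could go because the format is no longer inferred from the file name at all; the caller states it explicitly via the --json flag and it travels with the path. Under the new scheme even a .txt path can carry JSON output:

from pathlib import Path

from mlia.core.common import FormattedFilePath

report = FormattedFilePath(Path("advice.txt"), "json")
assert report.fmt == "json"
assert report.path == Path("advice.txt")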
diff --git a/src/mlia/target/cortex_a/advice_generation.py b/src/mlia/target/cortex_a/advice_generation.py
index b68106e..98e8c06 100644
--- a/src/mlia/target/cortex_a/advice_generation.py
+++ b/src/mlia/target/cortex_a/advice_generation.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Cortex-A advice generation."""
from functools import singledispatchmethod
@@ -29,7 +29,7 @@ class CortexAAdviceProducer(FactBasedAdviceProducer):
"""Produce advice."""
@produce_advice.register
- @advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS)
+ @advice_category(AdviceCategory.COMPATIBILITY)
def handle_model_is_cortex_a_compatible(
self, data_item: ModelIsCortexACompatible
) -> None:
@@ -43,7 +43,7 @@ class CortexAAdviceProducer(FactBasedAdviceProducer):
)
@produce_advice.register
- @advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS)
+ @advice_category(AdviceCategory.COMPATIBILITY)
def handle_model_is_not_cortex_a_compatible(
self, data_item: ModelIsNotCortexACompatible
) -> None:
@@ -83,7 +83,7 @@ class CortexAAdviceProducer(FactBasedAdviceProducer):
)
@produce_advice.register
- @advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS)
+ @advice_category(AdviceCategory.COMPATIBILITY)
def handle_model_is_not_tflite_compatible(
self, data_item: ModelIsNotTFLiteCompatible
) -> None:
@@ -127,7 +127,7 @@ class CortexAAdviceProducer(FactBasedAdviceProducer):
)
@produce_advice.register
- @advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS)
+ @advice_category(AdviceCategory.COMPATIBILITY)
def handle_tflite_check_failed(
self, _data_item: TFLiteCompatibilityCheckFailed
) -> None:
@@ -140,7 +140,7 @@ class CortexAAdviceProducer(FactBasedAdviceProducer):
)
@produce_advice.register
- @advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS)
+ @advice_category(AdviceCategory.COMPATIBILITY)
def handle_model_has_custom_operators(
self, _data_item: ModelHasCustomOperators
) -> None:
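Every producer above is now gated on AdviceCategory.COMPATIBILITY alone. The gating pattern is the one exercised by the core advice-generation tests at the end of this patch; roughly, assuming the decorator and base class are exported from mlia.core.advice_generation:

from mlia.core.advice_generation import FactBasedAdviceProducer
from mlia.core.advice_generation import advice_category
from mlia.core.common import AdviceCategory


class SampleAdviceProducer(FactBasedAdviceProducer):
    """Toy producer: emits advice only when compatibility is requested."""

    @advice_category(AdviceCategory.COMPATIBILITY)
    def produce_advice(self, data_item) -> None:
        """Produce the advice."""
        self.add_advice(["Good advice!"])

With a context whose advice_category contains COMPATIBILITY the producer emits its advice; with {AdviceCategory.PERFORMANCE} only, nothing is produced, exactly as test_advice_category_decorator asserts.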
diff --git a/src/mlia/target/cortex_a/advisor.py b/src/mlia/target/cortex_a/advisor.py
index 5912e38..b649f0d 100644
--- a/src/mlia/target/cortex_a/advisor.py
+++ b/src/mlia/target/cortex_a/advisor.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Cortex-A MLIA module."""
from __future__ import annotations
@@ -10,12 +10,12 @@ from mlia.core.advice_generation import AdviceProducer
from mlia.core.advisor import DefaultInferenceAdvisor
from mlia.core.advisor import InferenceAdvisor
from mlia.core.common import AdviceCategory
+from mlia.core.common import FormattedFilePath
from mlia.core.context import Context
from mlia.core.context import ExecutionContext
from mlia.core.data_analysis import DataAnalyzer
from mlia.core.data_collection import DataCollector
from mlia.core.events import Event
-from mlia.core.typing import PathOrFileLike
from mlia.target.cortex_a.advice_generation import CortexAAdviceProducer
from mlia.target.cortex_a.config import CortexAConfiguration
from mlia.target.cortex_a.data_analysis import CortexADataAnalyzer
@@ -38,7 +38,7 @@ class CortexAInferenceAdvisor(DefaultInferenceAdvisor):
collectors: list[DataCollector] = []
- if AdviceCategory.OPERATORS in context.advice_category:
+ if context.category_enabled(AdviceCategory.COMPATIBILITY):
collectors.append(CortexAOperatorCompatibility(model))
return collectors
@@ -67,7 +67,7 @@ def configure_and_get_cortexa_advisor(
context: ExecutionContext,
target_profile: str,
model: str | Path,
- output: PathOrFileLike | None = None,
+ output: FormattedFilePath | None = None,
**_extra_args: Any,
) -> InferenceAdvisor:
"""Create and configure Cortex-A advisor."""
diff --git a/src/mlia/target/cortex_a/handlers.py b/src/mlia/target/cortex_a/handlers.py
index b2d5faa..d6acde5 100644
--- a/src/mlia/target/cortex_a/handlers.py
+++ b/src/mlia/target/cortex_a/handlers.py
@@ -1,13 +1,13 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Event handler."""
from __future__ import annotations
import logging
+from mlia.core.common import FormattedFilePath
from mlia.core.events import CollectedDataEvent
from mlia.core.handlers import WorkflowEventsHandler
-from mlia.core.typing import PathOrFileLike
from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
from mlia.target.cortex_a.events import CortexAAdvisorEventHandler
from mlia.target.cortex_a.events import CortexAAdvisorStartedEvent
@@ -20,7 +20,7 @@ logger = logging.getLogger(__name__)
class CortexAEventHandler(WorkflowEventsHandler, CortexAAdvisorEventHandler):
"""CLI event handler."""
- def __init__(self, output: PathOrFileLike | None = None) -> None:
+ def __init__(self, output: FormattedFilePath | None = None) -> None:
"""Init event handler."""
super().__init__(cortex_a_formatters, output)
diff --git a/src/mlia/target/ethos_u/advice_generation.py b/src/mlia/target/ethos_u/advice_generation.py
index edd78fd..daae4f4 100644
--- a/src/mlia/target/ethos_u/advice_generation.py
+++ b/src/mlia/target/ethos_u/advice_generation.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Ethos-U advice generation."""
from __future__ import annotations
@@ -26,7 +26,7 @@ class EthosUAdviceProducer(FactBasedAdviceProducer):
"""Produce advice."""
@produce_advice.register
- @advice_category(AdviceCategory.OPERATORS, AdviceCategory.ALL)
+ @advice_category(AdviceCategory.COMPATIBILITY)
def handle_cpu_only_ops(self, data_item: HasCPUOnlyOperators) -> None:
"""Advice for CPU only operators."""
cpu_only_ops = ",".join(sorted(set(data_item.cpu_only_ops)))
@@ -40,11 +40,10 @@ class EthosUAdviceProducer(FactBasedAdviceProducer):
"Using operators that are supported by the NPU will "
"improve performance.",
]
- + self.context.action_resolver.supported_operators_info()
)
@produce_advice.register
- @advice_category(AdviceCategory.OPERATORS, AdviceCategory.ALL)
+ @advice_category(AdviceCategory.COMPATIBILITY)
def handle_unsupported_operators(
self, data_item: HasUnsupportedOnNPUOperators
) -> None:
@@ -60,21 +59,25 @@ class EthosUAdviceProducer(FactBasedAdviceProducer):
)
@produce_advice.register
- @advice_category(AdviceCategory.OPERATORS, AdviceCategory.ALL)
+ @advice_category(AdviceCategory.COMPATIBILITY)
def handle_all_operators_supported(
self, _data_item: AllOperatorsSupportedOnNPU
) -> None:
"""Advice if all operators supported."""
- self.add_advice(
- [
- "You don't have any unsupported operators, your model will "
- "run completely on NPU."
- ]
- + self.context.action_resolver.check_performance()
- )
+ advice = [
+ "You don't have any unsupported operators, your model will "
+ "run completely on NPU."
+ ]
+ if self.context.advice_category != {
+ AdviceCategory.COMPATIBILITY,
+ AdviceCategory.PERFORMANCE,
+ }:
+ advice += self.context.action_resolver.check_performance()
+
+ self.add_advice(advice)
@produce_advice.register
- @advice_category(AdviceCategory.OPTIMIZATION, AdviceCategory.ALL)
+ @advice_category(AdviceCategory.OPTIMIZATION)
def handle_optimization_results(self, data_item: OptimizationResults) -> None:
"""Advice based on optimization results."""
if not data_item.diffs or len(data_item.diffs) != 1:
@@ -202,5 +205,6 @@ class EthosUStaticAdviceProducer(ContextAwareAdviceProducer):
)
],
}
-
- return advice_per_category.get(self.context.advice_category, [])
+ if len(self.context.advice_category) == 1:
+ return advice_per_category.get(list(self.context.advice_category)[0], [])
+ return []
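One subtlety in handle_all_operators_supported above: the comparison only behaves as intended when both sides are sets, because a set never compares equal to a tuple. A quick illustration:

from mlia.core.common import AdviceCategory

enabled = {AdviceCategory.COMPATIBILITY, AdviceCategory.PERFORMANCE}

# A set literal matches when exactly these two categories are enabled...
assert enabled == {AdviceCategory.COMPATIBILITY, AdviceCategory.PERFORMANCE}
# ...whereas a tuple on the right-hand side would never match, so the
# performance hint would be appended even when it should be skipped.
assert enabled != (AdviceCategory.COMPATIBILITY, AdviceCategory.PERFORMANCE)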
diff --git a/src/mlia/target/ethos_u/advisor.py b/src/mlia/target/ethos_u/advisor.py
index b9d64ff..640c3e1 100644
--- a/src/mlia/target/ethos_u/advisor.py
+++ b/src/mlia/target/ethos_u/advisor.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Ethos-U MLIA module."""
from __future__ import annotations
@@ -10,12 +10,12 @@ from mlia.core.advice_generation import AdviceProducer
from mlia.core.advisor import DefaultInferenceAdvisor
from mlia.core.advisor import InferenceAdvisor
from mlia.core.common import AdviceCategory
+from mlia.core.common import FormattedFilePath
from mlia.core.context import Context
from mlia.core.context import ExecutionContext
from mlia.core.data_analysis import DataAnalyzer
from mlia.core.data_collection import DataCollector
from mlia.core.events import Event
-from mlia.core.typing import PathOrFileLike
from mlia.nn.tensorflow.utils import is_tflite_model
from mlia.target.ethos_u.advice_generation import EthosUAdviceProducer
from mlia.target.ethos_u.advice_generation import EthosUStaticAdviceProducer
@@ -46,7 +46,7 @@ class EthosUInferenceAdvisor(DefaultInferenceAdvisor):
collectors: list[DataCollector] = []
- if AdviceCategory.OPERATORS in context.advice_category:
+ if context.category_enabled(AdviceCategory.COMPATIBILITY):
collectors.append(EthosUOperatorCompatibility(model, device))
# Performance and optimization are mutually exclusive.
@@ -57,18 +57,18 @@ class EthosUInferenceAdvisor(DefaultInferenceAdvisor):
raise Exception(
"Command 'optimization' is not supported for TensorFlow Lite files."
)
- if AdviceCategory.PERFORMANCE in context.advice_category:
+ if context.category_enabled(AdviceCategory.PERFORMANCE):
collectors.append(EthosUPerformance(model, device, backends))
else:
# Keras/SavedModel: Prefer optimization
- if AdviceCategory.OPTIMIZATION in context.advice_category:
+ if context.category_enabled(AdviceCategory.OPTIMIZATION):
optimization_settings = self._get_optimization_settings(context)
collectors.append(
EthosUOptimizationPerformance(
model, device, optimization_settings, backends
)
)
- elif AdviceCategory.PERFORMANCE in context.advice_category:
+ elif context.category_enabled(AdviceCategory.PERFORMANCE):
collectors.append(EthosUPerformance(model, device, backends))
return collectors
@@ -126,7 +126,7 @@ def configure_and_get_ethosu_advisor(
context: ExecutionContext,
target_profile: str,
model: str | Path,
- output: PathOrFileLike | None = None,
+ output: FormattedFilePath | None = None,
**extra_args: Any,
) -> InferenceAdvisor:
"""Create and configure Ethos-U advisor."""
diff --git a/src/mlia/target/ethos_u/handlers.py b/src/mlia/target/ethos_u/handlers.py
index 84a9554..91f6015 100644
--- a/src/mlia/target/ethos_u/handlers.py
+++ b/src/mlia/target/ethos_u/handlers.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Event handler."""
from __future__ import annotations
@@ -6,9 +6,9 @@ from __future__ import annotations
import logging
from mlia.backend.vela.compat import Operators
+from mlia.core.common import FormattedFilePath
from mlia.core.events import CollectedDataEvent
from mlia.core.handlers import WorkflowEventsHandler
-from mlia.core.typing import PathOrFileLike
from mlia.target.ethos_u.events import EthosUAdvisorEventHandler
from mlia.target.ethos_u.events import EthosUAdvisorStartedEvent
from mlia.target.ethos_u.performance import OptimizationPerformanceMetrics
@@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
class EthosUEventHandler(WorkflowEventsHandler, EthosUAdvisorEventHandler):
"""CLI event handler."""
- def __init__(self, output: PathOrFileLike | None = None) -> None:
+ def __init__(self, output: FormattedFilePath | None = None) -> None:
"""Init event handler."""
super().__init__(ethos_u_formatters, output)
diff --git a/src/mlia/target/tosa/advice_generation.py b/src/mlia/target/tosa/advice_generation.py
index f531b84..b8b9abf 100644
--- a/src/mlia/target/tosa/advice_generation.py
+++ b/src/mlia/target/tosa/advice_generation.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""TOSA advice generation."""
from functools import singledispatchmethod
@@ -19,7 +19,7 @@ class TOSAAdviceProducer(FactBasedAdviceProducer):
"""Produce advice."""
@produce_advice.register
- @advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS)
+ @advice_category(AdviceCategory.COMPATIBILITY)
def handle_model_is_tosa_compatible(
self, _data_item: ModelIsTOSACompatible
) -> None:
@@ -27,7 +27,7 @@ class TOSAAdviceProducer(FactBasedAdviceProducer):
self.add_advice(["Model is fully TOSA compatible."])
@produce_advice.register
- @advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS)
+ @advice_category(AdviceCategory.COMPATIBILITY)
def handle_model_is_not_tosa_compatible(
self, _data_item: ModelIsNotTOSACompatible
) -> None:
diff --git a/src/mlia/target/tosa/advisor.py b/src/mlia/target/tosa/advisor.py
index 2739dfd..4851113 100644
--- a/src/mlia/target/tosa/advisor.py
+++ b/src/mlia/target/tosa/advisor.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""TOSA advisor."""
from __future__ import annotations
@@ -10,12 +10,12 @@ from mlia.core.advice_generation import AdviceCategory
from mlia.core.advice_generation import AdviceProducer
from mlia.core.advisor import DefaultInferenceAdvisor
from mlia.core.advisor import InferenceAdvisor
+from mlia.core.common import FormattedFilePath
from mlia.core.context import Context
from mlia.core.context import ExecutionContext
from mlia.core.data_analysis import DataAnalyzer
from mlia.core.data_collection import DataCollector
from mlia.core.events import Event
-from mlia.core.typing import PathOrFileLike
from mlia.target.tosa.advice_generation import TOSAAdviceProducer
from mlia.target.tosa.config import TOSAConfiguration
from mlia.target.tosa.data_analysis import TOSADataAnalyzer
@@ -38,7 +38,7 @@ class TOSAInferenceAdvisor(DefaultInferenceAdvisor):
collectors: list[DataCollector] = []
- if AdviceCategory.OPERATORS in context.advice_category:
+ if context.category_enabled(AdviceCategory.COMPATIBILITY):
collectors.append(TOSAOperatorCompatibility(model))
return collectors
@@ -69,7 +69,7 @@ def configure_and_get_tosa_advisor(
context: ExecutionContext,
target_profile: str,
model: str | Path,
- output: PathOrFileLike | None = None,
+ output: FormattedFilePath | None = None,
**_extra_args: Any,
) -> InferenceAdvisor:
"""Create and configure TOSA advisor."""
diff --git a/src/mlia/target/tosa/handlers.py b/src/mlia/target/tosa/handlers.py
index 863558c..1037ba1 100644
--- a/src/mlia/target/tosa/handlers.py
+++ b/src/mlia/target/tosa/handlers.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""TOSA Advisor event handlers."""
# pylint: disable=R0801
@@ -7,9 +7,9 @@ from __future__ import annotations
import logging
from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo
+from mlia.core.common import FormattedFilePath
from mlia.core.events import CollectedDataEvent
from mlia.core.handlers import WorkflowEventsHandler
-from mlia.core.typing import PathOrFileLike
from mlia.target.tosa.events import TOSAAdvisorEventHandler
from mlia.target.tosa.events import TOSAAdvisorStartedEvent
from mlia.target.tosa.reporters import tosa_formatters
@@ -20,7 +20,7 @@ logger = logging.getLogger(__name__)
class TOSAEventHandler(WorkflowEventsHandler, TOSAAdvisorEventHandler):
"""Event handler for TOSA advisor."""
- def __init__(self, output: PathOrFileLike | None = None) -> None:
+ def __init__(self, output: FormattedFilePath | None = None) -> None:
"""Init event handler."""
super().__init__(tosa_formatters, output)
diff --git a/src/mlia/utils/types.py b/src/mlia/utils/types.py
index ea067b8..0769968 100644
--- a/src/mlia/utils/types.py
+++ b/src/mlia/utils/types.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Types related utility functions."""
from __future__ import annotations
@@ -19,7 +19,7 @@ def is_number(value: str) -> bool:
"""Return true if string contains a number."""
try:
float(value)
- except ValueError:
+ except (ValueError, TypeError):
return False
return True
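Catching TypeError as well lets is_number degrade gracefully on non-string input such as None instead of raising. For example:

from mlia.utils.types import is_number

assert is_number("42")
assert is_number("3.14")
assert not is_number("not-a-number")
assert not is_number(None)  # float(None) raises TypeError, now handled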
diff --git a/tests/test_api.py b/tests/test_api.py
index fbc558b..0bbc3ae 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -1,15 +1,13 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the API functions."""
from __future__ import annotations
from pathlib import Path
from unittest.mock import MagicMock
-from unittest.mock import patch
import pytest
-from mlia.api import generate_supported_operators_report
from mlia.api import get_advice
from mlia.api import get_advisor
from mlia.core.common import AdviceCategory
@@ -22,63 +20,68 @@ from mlia.target.tosa.advisor import TOSAInferenceAdvisor
def test_get_advice_no_target_provided(test_keras_model: Path) -> None:
"""Test getting advice when no target provided."""
with pytest.raises(Exception, match="Target profile is not provided"):
- get_advice(None, test_keras_model, "all") # type: ignore
+ get_advice(None, test_keras_model, {"compatibility"}) # type: ignore
def test_get_advice_wrong_category(test_keras_model: Path) -> None:
"""Test getting advice when wrong advice category provided."""
with pytest.raises(Exception, match="Invalid advice category unknown"):
- get_advice("ethos-u55-256", test_keras_model, "unknown") # type: ignore
+ get_advice("ethos-u55-256", test_keras_model, {"unknown"})
@pytest.mark.parametrize(
"category, context, expected_category",
[
[
- "all",
+ {"compatibility", "optimization"},
None,
- AdviceCategory.ALL,
+ {AdviceCategory.COMPATIBILITY, AdviceCategory.OPTIMIZATION},
],
[
- "optimization",
+ {"optimization"},
None,
- AdviceCategory.OPTIMIZATION,
+ {AdviceCategory.OPTIMIZATION},
],
[
- "operators",
+ {"compatibility"},
None,
- AdviceCategory.OPERATORS,
+ {AdviceCategory.COMPATIBILITY},
],
[
- "performance",
+ {"performance"},
None,
- AdviceCategory.PERFORMANCE,
+ {AdviceCategory.PERFORMANCE},
],
[
- "all",
+ {"compatibility", "optimization"},
ExecutionContext(),
- AdviceCategory.ALL,
+ {AdviceCategory.COMPATIBILITY, AdviceCategory.OPTIMIZATION},
],
[
- "all",
- ExecutionContext(advice_category=AdviceCategory.PERFORMANCE),
- AdviceCategory.ALL,
+ {"compatibility", "optimization"},
+ ExecutionContext(
+ advice_category={
+ AdviceCategory.COMPATIBILITY,
+ AdviceCategory.OPTIMIZATION,
+ }
+ ),
+ {AdviceCategory.COMPATIBILITY, AdviceCategory.OPTIMIZATION},
],
[
- "all",
+ {"compatibility", "optimization"},
ExecutionContext(config_parameters={"param": "value"}),
- AdviceCategory.ALL,
+ {AdviceCategory.COMPATIBILITY, AdviceCategory.OPTIMIZATION},
],
[
- "all",
+ {"compatibility", "optimization"},
ExecutionContext(event_handlers=[MagicMock()]),
- AdviceCategory.ALL,
+ {AdviceCategory.COMPATIBILITY, AdviceCategory.OPTIMIZATION},
],
],
)
def test_get_advice(
monkeypatch: pytest.MonkeyPatch,
- category: str,
+ category: set[str],
context: ExecutionContext,
expected_category: AdviceCategory,
test_keras_model: Path,
@@ -90,7 +93,7 @@ def test_get_advice(
get_advice(
"ethos-u55-256",
test_keras_model,
- category, # type: ignore
+ category,
context=context,
)
@@ -111,50 +114,3 @@ def test_get_advisor(
tosa_advisor = get_advisor(ExecutionContext(), "tosa", str(test_keras_model))
assert isinstance(tosa_advisor, TOSAInferenceAdvisor)
-
-
-@pytest.mark.parametrize(
- ["target_profile", "required_calls", "exception_msg"],
- [
- [
- "ethos-u55-128",
- "mlia.target.ethos_u.operators.generate_supported_operators_report",
- None,
- ],
- [
- "ethos-u65-256",
- "mlia.target.ethos_u.operators.generate_supported_operators_report",
- None,
- ],
- [
- "tosa",
- None,
- "Generating a supported operators report is not "
- "currently supported with TOSA target profile.",
- ],
- [
- "cortex-a",
- None,
- "Generating a supported operators report is not "
- "currently supported with Cortex-A target profile.",
- ],
- [
- "Unknown",
- None,
- "Unable to find target profile Unknown",
- ],
- ],
-)
-def test_supported_ops_report_generator(
- target_profile: str, required_calls: str | None, exception_msg: str | None
-) -> None:
- """Test supported operators report generator with different target profiles."""
- if exception_msg:
- with pytest.raises(Exception) as exc:
- generate_supported_operators_report(target_profile)
- assert str(exc.value) == exception_msg
-
- if required_calls:
- with patch(required_calls) as mock_method:
- generate_supported_operators_report(target_profile)
- mock_method.assert_called_once()
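With the supported-operators report gone, the API tests revolve around passing a set of category names straight to get_advice. A usage sketch (the model path is hypothetical; the tests use a Keras fixture):

from mlia.api import get_advice

get_advice(
    "ethos-u55-256",
    "sample_model.h5",
    {"compatibility", "performance"},
)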
diff --git a/tests/test_backend_config.py b/tests/test_backend_config.py
index bd50945..700534f 100644
--- a/tests/test_backend_config.py
+++ b/tests/test_backend_config.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the backend config module."""
from mlia.backend.config import BackendConfiguration
@@ -20,14 +20,14 @@ def test_system() -> None:
def test_backend_config() -> None:
"""Test the class 'BackendConfiguration'."""
cfg = BackendConfiguration(
- [AdviceCategory.OPERATORS], [System.CURRENT], BackendType.CUSTOM
+ [AdviceCategory.COMPATIBILITY], [System.CURRENT], BackendType.CUSTOM
)
- assert cfg.supported_advice == [AdviceCategory.OPERATORS]
+ assert cfg.supported_advice == [AdviceCategory.COMPATIBILITY]
assert cfg.supported_systems == [System.CURRENT]
assert cfg.type == BackendType.CUSTOM
assert str(cfg)
assert cfg.is_supported()
- assert cfg.is_supported(advice=AdviceCategory.OPERATORS)
+ assert cfg.is_supported(advice=AdviceCategory.COMPATIBILITY)
assert not cfg.is_supported(advice=AdviceCategory.PERFORMANCE)
assert cfg.is_supported(check_system=True)
assert cfg.is_supported(check_system=False)
@@ -37,6 +37,6 @@ def test_backend_config() -> None:
cfg.supported_systems = [UNSUPPORTED_SYSTEM]
assert not cfg.is_supported(check_system=True)
assert cfg.is_supported(check_system=False)
- assert not cfg.is_supported(advice=AdviceCategory.OPERATORS, check_system=True)
- assert cfg.is_supported(advice=AdviceCategory.OPERATORS, check_system=False)
+ assert not cfg.is_supported(advice=AdviceCategory.COMPATIBILITY, check_system=True)
+ assert cfg.is_supported(advice=AdviceCategory.COMPATIBILITY, check_system=False)
assert not cfg.is_supported(advice=AdviceCategory.PERFORMANCE, check_system=False)
diff --git a/tests/test_backend_registry.py b/tests/test_backend_registry.py
index 31a20a0..703e699 100644
--- a/tests/test_backend_registry.py
+++ b/tests/test_backend_registry.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the backend registry module."""
from __future__ import annotations
@@ -18,7 +18,7 @@ from mlia.core.common import AdviceCategory
(
(
"ArmNNTFLiteDelegate",
- [AdviceCategory.OPERATORS],
+ [AdviceCategory.COMPATIBILITY],
None,
BackendType.BUILTIN,
),
@@ -36,14 +36,14 @@ from mlia.core.common import AdviceCategory
),
(
"TOSA-Checker",
- [AdviceCategory.OPERATORS],
+ [AdviceCategory.COMPATIBILITY],
[System.LINUX_AMD64],
BackendType.WHEEL,
),
(
"Vela",
[
- AdviceCategory.OPERATORS,
+ AdviceCategory.COMPATIBILITY,
AdviceCategory.PERFORMANCE,
AdviceCategory.OPTIMIZATION,
],
diff --git a/tests/test_cli_command_validators.py b/tests/test_cli_command_validators.py
new file mode 100644
index 0000000..13514a5
--- /dev/null
+++ b/tests/test_cli_command_validators.py
@@ -0,0 +1,167 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for cli.command_validators module."""
+from __future__ import annotations
+
+import argparse
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.cli.command_validators import validate_backend
+from mlia.cli.command_validators import validate_check_target_profile
+
+
+@pytest.mark.parametrize(
+ "target_profile, category, expected_warnings, sys_exits",
+ [
+ ["ethos-u55-256", {"compatibility", "performance"}, [], False],
+ ["ethos-u55-256", {"compatibility"}, [], False],
+ ["ethos-u55-256", {"performance"}, [], False],
+ [
+ "tosa",
+ {"compatibility", "performance"},
+ [
+ (
+ "\nWARNING: Performance checks skipped as they cannot be "
+ "performed with target profile tosa."
+ )
+ ],
+ False,
+ ],
+ [
+ "tosa",
+ {"performance"},
+ [
+ (
+ "\nWARNING: Performance checks skipped as they cannot be "
+ "performed with target profile tosa. No operation was performed."
+ )
+ ],
+ True,
+ ],
+ ["tosa", {"compatibility"}, [], False],
+ [
+ "cortex-a",
+ {"performance"},
+ [
+ (
+ "\nWARNING: Performance checks skipped as they cannot be "
+ "performed with target profile cortex-a. "
+ "No operation was performed."
+ )
+ ],
+ True,
+ ],
+ [
+ "cortex-a",
+ {"compatibility", "performance"},
+ [
+ (
+ "\nWARNING: Performance checks skipped as they cannot be "
+ "performed with target profile cortex-a."
+ )
+ ],
+ False,
+ ],
+ ["cortex-a", {"compatibility"}, [], False],
+ ],
+)
+def test_validate_check_target_profile(
+ caplog: pytest.LogCaptureFixture,
+ target_profile: str,
+ category: set[str],
+ expected_warnings: list[str],
+ sys_exits: bool,
+) -> None:
+ """Test outcomes of category dependent target profile validation."""
+ # Capture if program terminates
+ if sys_exits:
+ with pytest.raises(SystemExit) as sys_ex:
+ validate_check_target_profile(target_profile, category)
+ assert sys_ex.value.code == 0
+ return
+
+ validate_check_target_profile(target_profile, category)
+
+ log_records = caplog.records
+ # Get all log records with level 30 (warning level)
+ warning_messages = {x.message for x in log_records if x.levelno == 30}
+ # Ensure the warnings coincide with the expected ones
+ assert warning_messages == set(expected_warnings)
+
+
+@pytest.mark.parametrize(
+ "input_target_profile, input_backends, throws_exception,"
+ "exception_message, output_backends",
+ [
+ [
+ "tosa",
+ ["Vela"],
+ True,
+ "Vela backend not supported with target-profile tosa.",
+ None,
+ ],
+ [
+ "tosa",
+ ["Corstone-300, Vela"],
+ True,
+ "Corstone-300, Vela backend not supported with target-profile tosa.",
+ None,
+ ],
+ [
+ "cortex-a",
+ ["Corstone-310", "tosa-checker"],
+ True,
+ "Corstone-310, tosa-checker backend not supported "
+ "with target-profile cortex-a.",
+ None,
+ ],
+ [
+ "ethos-u55-256",
+ ["tosa-checker", "Corstone-310"],
+ True,
+ "tosa-checker backend not supported with target-profile ethos-u55-256.",
+ None,
+ ],
+ ["tosa", None, False, None, ["tosa-checker"]],
+ ["cortex-a", None, False, None, ["armnn-tflitedelegate"]],
+ ["tosa", ["tosa-checker"], False, None, ["tosa-checker"]],
+ ["cortex-a", ["armnn-tflitedelegate"], False, None, ["armnn-tflitedelegate"]],
+ [
+ "ethos-u55-256",
+ ["Vela", "Corstone-300"],
+ False,
+ None,
+ ["Vela", "Corstone-300"],
+ ],
+ [
+ "ethos-u55-256",
+ None,
+ False,
+ None,
+ ["Vela", "Corstone-300"],
+ ],
+ ],
+)
+def test_validate_backend(
+ monkeypatch: pytest.MonkeyPatch,
+ input_target_profile: str,
+ input_backends: list[str] | None,
+ throws_exception: bool,
+ exception_message: str,
+ output_backends: list[str] | None,
+) -> None:
+ """Test backend validation with target-profiles and backends."""
+ monkeypatch.setattr(
+ "mlia.cli.config.get_available_backends",
+ MagicMock(return_value=["Vela", "Corstone-300"]),
+ )
+
+ if throws_exception:
+ with pytest.raises(argparse.ArgumentError) as err:
+ validate_backend(input_target_profile, input_backends)
+ assert str(err.value.message) == exception_message
+ return
+
+ assert validate_backend(input_target_profile, input_backends) == output_backends
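The validators either normalise the backend list or abort with an argparse error, as the parametrised cases above spell out. A condensed sketch of calling them directly (backend availability is mocked in the tests):

import argparse

from mlia.cli.command_validators import validate_backend
from mlia.cli.command_validators import validate_check_target_profile

# Falls back to the built-in backend when none is given for the profile.
assert validate_backend("tosa", None) == ["tosa-checker"]

try:
    validate_backend("tosa", ["Vela"])
except argparse.ArgumentError as err:
    print(err.message)  # Vela backend not supported with target-profile tosa.

# No warning for a profile that supports both requested categories.
validate_check_target_profile("ethos-u55-256", {"compatibility", "performance"})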
diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py
index aed5c42..03ee9d2 100644
--- a/tests/test_cli_commands.py
+++ b/tests/test_cli_commands.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for cli.commands module."""
from __future__ import annotations
@@ -14,9 +14,8 @@ from mlia.backend.manager import DefaultInstallationManager
from mlia.cli.commands import backend_install
from mlia.cli.commands import backend_list
from mlia.cli.commands import backend_uninstall
-from mlia.cli.commands import operators
-from mlia.cli.commands import optimization
-from mlia.cli.commands import performance
+from mlia.cli.commands import check
+from mlia.cli.commands import optimize
from mlia.core.context import ExecutionContext
from mlia.target.ethos_u.config import EthosUConfiguration
from mlia.target.ethos_u.performance import MemoryUsage
@@ -27,7 +26,7 @@ from mlia.target.ethos_u.performance import PerformanceMetrics
def test_operators_expected_parameters(sample_context: ExecutionContext) -> None:
"""Test operators command wrong parameters."""
with pytest.raises(Exception, match="Model is not provided"):
- operators(sample_context, "ethos-u55-256")
+ check(sample_context, "ethos-u55-256")
def test_performance_unknown_target(
@@ -35,93 +34,45 @@ def test_performance_unknown_target(
) -> None:
"""Test that command should fail if unknown target passed."""
with pytest.raises(Exception, match="Unable to find target profile unknown"):
- performance(
- sample_context, model=str(test_tflite_model), target_profile="unknown"
+ check(
+ sample_context,
+ model=str(test_tflite_model),
+ target_profile="unknown",
+ performance=True,
)
@pytest.mark.parametrize(
- "target_profile, optimization_type, optimization_target, expected_error",
+ "target_profile, pruning, clustering, pruning_target, clustering_target",
[
- [
- "ethos-u55-256",
- None,
- "0.5",
- pytest.raises(Exception, match="Optimization type is not provided"),
- ],
- [
- "ethos-u65-512",
- "unknown",
- "16",
- pytest.raises(Exception, match="Unsupported optimization type: unknown"),
- ],
- [
- "ethos-u55-256",
- "pruning",
- None,
- pytest.raises(Exception, match="Optimization target is not provided"),
- ],
- [
- "ethos-u65-512",
- "clustering",
- None,
- pytest.raises(Exception, match="Optimization target is not provided"),
- ],
- [
- "unknown",
- "clustering",
- "16",
- pytest.raises(Exception, match="Unable to find target profile unknown"),
- ],
- ],
-)
-def test_opt_expected_parameters(
- sample_context: ExecutionContext,
- target_profile: str,
- monkeypatch: pytest.MonkeyPatch,
- optimization_type: str,
- optimization_target: str,
- expected_error: Any,
- test_keras_model: Path,
-) -> None:
- """Test that command should fail if no or unknown optimization type provided."""
- mock_performance_estimation(monkeypatch)
-
- with expected_error:
- optimization(
- ctx=sample_context,
- target_profile=target_profile,
- model=str(test_keras_model),
- optimization_type=optimization_type,
- optimization_target=optimization_target,
- )
-
-
-@pytest.mark.parametrize(
- "target_profile, optimization_type, optimization_target",
- [
- ["ethos-u55-256", "pruning", "0.5"],
- ["ethos-u65-512", "clustering", "32"],
- ["ethos-u55-256", "pruning,clustering", "0.5,32"],
+ ["ethos-u55-256", True, False, 0.5, None],
+ ["ethos-u65-512", False, True, 0.5, 32],
+ ["ethos-u55-256", True, True, 0.5, None],
+ ["ethos-u55-256", False, False, 0.5, None],
+ ["ethos-u55-256", False, True, "invalid", 32],
],
)
def test_opt_valid_optimization_target(
target_profile: str,
sample_context: ExecutionContext,
- optimization_type: str,
- optimization_target: str,
+ pruning: bool,
+ clustering: bool,
+ pruning_target: float | None,
+ clustering_target: int | None,
monkeypatch: pytest.MonkeyPatch,
test_keras_model: Path,
) -> None:
"""Test that command should not fail with valid optimization targets."""
mock_performance_estimation(monkeypatch)
- optimization(
+ optimize(
ctx=sample_context,
target_profile=target_profile,
model=str(test_keras_model),
- optimization_type=optimization_type,
- optimization_target=optimization_target,
+ pruning=pruning,
+ clustering=clustering,
+ pruning_target=pruning_target,
+ clustering_target=clustering_target,
)
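The optimize entry point now takes explicit pruning/clustering switches instead of comma-separated type and target strings. A call shaped like the happy-path cases above (sketch only; the tests mock the performance estimation behind it):

from mlia.cli.commands import optimize
from mlia.core.context import ExecutionContext

optimize(
    ctx=ExecutionContext(),
    target_profile="ethos-u55-256",
    model="sample_model.h5",  # hypothetical Keras model path
    pruning=True,
    clustering=True,
    pruning_target=0.5,
    clustering_target=32,
)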
diff --git a/tests/test_cli_config.py b/tests/test_cli_config.py
index 1a7cb3f..b007052 100644
--- a/tests/test_cli_config.py
+++ b/tests/test_cli_config.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for cli.config module."""
from __future__ import annotations
@@ -7,7 +7,7 @@ from unittest.mock import MagicMock
import pytest
-from mlia.cli.config import get_default_backends
+from mlia.cli.config import get_ethos_u_default_backends
from mlia.cli.config import is_corstone_backend
@@ -29,7 +29,7 @@ from mlia.cli.config import is_corstone_backend
],
],
)
-def test_get_default_backends(
+def test_get_ethos_u_default_backends(
monkeypatch: pytest.MonkeyPatch,
available_backends: list[str],
expected_default_backends: list[str],
@@ -40,7 +40,7 @@ def test_get_default_backends(
MagicMock(return_value=available_backends),
)
- assert get_default_backends() == expected_default_backends
+ assert get_ethos_u_default_backends() == expected_default_backends
def test_is_corstone_backend() -> None:
diff --git a/tests/test_cli_helpers.py b/tests/test_cli_helpers.py
index c8aeebe..8f7e4b0 100644
--- a/tests/test_cli_helpers.py
+++ b/tests/test_cli_helpers.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the helper classes."""
from __future__ import annotations
@@ -28,40 +28,39 @@ class TestCliActionResolver:
{},
[
"Note: you will need a Keras model for that.",
- "For example: mlia optimization --optimization-type "
- "pruning,clustering --optimization-target 0.5,32 "
- "/path/to/keras_model",
- "For more info: mlia optimization --help",
+ "For example: mlia optimize /path/to/keras_model "
+ "--pruning --clustering "
+ "--pruning-target 0.5 --clustering-target 32",
+ "For more info: mlia optimize --help",
],
],
[
{"model": "model.h5"},
{},
[
- "For example: mlia optimization --optimization-type "
- "pruning,clustering --optimization-target 0.5,32 model.h5",
- "For more info: mlia optimization --help",
+ "For example: mlia optimize model.h5 --pruning --clustering "
+ "--pruning-target 0.5 --clustering-target 32",
+ "For more info: mlia optimize --help",
],
],
[
{"model": "model.h5"},
{"opt_settings": [OptimizationSettings("pruning", 0.5, None)]},
[
- "For more info: mlia optimization --help",
+ "For more info: mlia optimize --help",
"Optimization command: "
- "mlia optimization --optimization-type pruning "
- "--optimization-target 0.5 model.h5",
+ "mlia optimize model.h5 --pruning "
+ "--pruning-target 0.5",
],
],
[
{"model": "model.h5", "target_profile": "target_profile"},
{"opt_settings": [OptimizationSettings("pruning", 0.5, None)]},
[
- "For more info: mlia optimization --help",
+ "For more info: mlia optimize --help",
"Optimization command: "
- "mlia optimization --optimization-type pruning "
- "--optimization-target 0.5 "
- "--target-profile target_profile model.h5",
+ "mlia optimize model.h5 --target-profile target_profile "
+ "--pruning --pruning-target 0.5",
],
],
],
@@ -76,20 +75,11 @@ class TestCliActionResolver:
assert resolver.apply_optimizations(**params) == expected_result
@staticmethod
- def test_supported_operators_info() -> None:
- """Test supported operators info."""
- resolver = CLIActionResolver({})
- assert resolver.supported_operators_info() == [
- "For guidance on supported operators, run: mlia operators "
- "--supported-ops-report",
- ]
-
- @staticmethod
def test_operator_compatibility_details() -> None:
"""Test operator compatibility details info."""
resolver = CLIActionResolver({})
assert resolver.operator_compatibility_details() == [
- "For more details, run: mlia operators --help"
+ "For more details, run: mlia check --help"
]
@staticmethod
@@ -97,7 +87,7 @@ class TestCliActionResolver:
"""Test optimization details info."""
resolver = CLIActionResolver({})
assert resolver.optimization_details() == [
- "For more info, see: mlia optimization --help"
+ "For more info, see: mlia optimize --help"
]
@staticmethod
@@ -109,19 +99,12 @@ class TestCliActionResolver:
[],
],
[
- {"model": "model.tflite"},
- [
- "Check the estimated performance by running the "
- "following command: ",
- "mlia performance model.tflite",
- ],
- ],
- [
{"model": "model.tflite", "target_profile": "target_profile"},
[
"Check the estimated performance by running the "
"following command: ",
- "mlia performance --target-profile target_profile model.tflite",
+ "mlia check model.tflite "
+ "--target-profile target_profile --performance",
],
],
],
@@ -142,17 +125,10 @@ class TestCliActionResolver:
[],
],
[
- {"model": "model.tflite"},
- [
- "Try running the following command to verify that:",
- "mlia operators model.tflite",
- ],
- ],
- [
{"model": "model.tflite", "target_profile": "target_profile"},
[
"Try running the following command to verify that:",
- "mlia operators --target-profile target_profile model.tflite",
+ "mlia check model.tflite --target-profile target_profile",
],
],
],
diff --git a/tests/test_cli_main.py b/tests/test_cli_main.py
index 925f1e4..5a9c0c9 100644
--- a/tests/test_cli_main.py
+++ b/tests/test_cli_main.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for main module."""
from __future__ import annotations
@@ -19,7 +19,6 @@ from mlia.backend.errors import BackendUnavailableError
from mlia.cli.main import backend_main
from mlia.cli.main import CommandInfo
from mlia.cli.main import main
-from mlia.core.context import ExecutionContext
from mlia.core.errors import InternalError
from tests.utils.logging import clear_loggers
@@ -62,35 +61,23 @@ def test_command_info(is_default: bool, expected_command_help: str) -> None:
assert command_info.command_help == expected_command_help
-def test_default_command(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
+def test_default_command(monkeypatch: pytest.MonkeyPatch) -> None:
"""Test adding default command."""
- def mock_command(
- func_mock: MagicMock, name: str, with_working_dir: bool
- ) -> Callable[..., None]:
+ def mock_command(func_mock: MagicMock, name: str) -> Callable[..., None]:
"""Mock cli command."""
def sample_cmd_1(*args: Any, **kwargs: Any) -> None:
"""Sample command."""
func_mock(*args, **kwargs)
- def sample_cmd_2(ctx: ExecutionContext, **kwargs: Any) -> None:
- """Another sample command."""
- func_mock(ctx=ctx, **kwargs)
-
- ret_func = sample_cmd_2 if with_working_dir else sample_cmd_1
+ ret_func = sample_cmd_1
ret_func.__name__ = name
- return ret_func # type: ignore
+ return ret_func
- default_command = MagicMock()
non_default_command = MagicMock()
- def default_command_params(parser: argparse.ArgumentParser) -> None:
- """Add parameters for default command."""
- parser.add_argument("--sample")
- parser.add_argument("--default_arg", default="123")
-
def non_default_command_params(parser: argparse.ArgumentParser) -> None:
"""Add parameters for non default command."""
parser.add_argument("--param")
@@ -100,15 +87,7 @@ def test_default_command(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> Non
MagicMock(
return_value=[
CommandInfo(
- func=mock_command(default_command, "default_command", True),
- aliases=["command1"],
- opt_groups=[default_command_params],
- is_default=True,
- ),
- CommandInfo(
- func=mock_command(
- non_default_command, "non_default_command", False
- ),
+ func=mock_command(non_default_command, "non_default_command"),
aliases=["command2"],
opt_groups=[non_default_command_params],
is_default=False,
@@ -117,11 +96,7 @@ def test_default_command(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> Non
),
)
- tmp_working_dir = str(tmp_path)
- main(["--working-dir", tmp_working_dir, "--sample", "1"])
main(["command2", "--param", "test"])
-
- default_command.assert_called_once_with(ctx=ANY, sample="1", default_arg="123")
non_default_command.assert_called_once_with(param="test")
@@ -140,134 +115,168 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
"params, expected_call",
[
[
- ["operators", "sample_model.tflite"],
+ ["check", "sample_model.tflite", "--target-profile", "ethos-u55-256"],
call(
ctx=ANY,
target_profile="ethos-u55-256",
model="sample_model.tflite",
+ compatibility=False,
+ performance=False,
output=None,
- supported_ops_report=False,
+ json=False,
+ backend=None,
),
],
[
- ["ops", "sample_model.tflite"],
- call(
- ctx=ANY,
- target_profile="ethos-u55-256",
- model="sample_model.tflite",
- output=None,
- supported_ops_report=False,
- ),
- ],
- [
- ["operators", "sample_model.tflite", "--target-profile", "ethos-u55-128"],
+ ["check", "sample_model.tflite", "--target-profile", "ethos-u55-128"],
call(
ctx=ANY,
target_profile="ethos-u55-128",
model="sample_model.tflite",
+ compatibility=False,
+ performance=False,
output=None,
- supported_ops_report=False,
+ json=False,
+ backend=None,
),
],
[
- ["operators"],
- call(
- ctx=ANY,
- target_profile="ethos-u55-256",
- model=None,
- output=None,
- supported_ops_report=False,
- ),
- ],
- [
- ["operators", "--supported-ops-report"],
+ [
+ "check",
+ "sample_model.h5",
+ "--performance",
+ "--compatibility",
+ "--target-profile",
+ "ethos-u55-256",
+ ],
call(
ctx=ANY,
target_profile="ethos-u55-256",
- model=None,
+ model="sample_model.h5",
output=None,
- supported_ops_report=True,
+ json=False,
+ compatibility=True,
+ performance=True,
+ backend=None,
),
],
[
[
- "all_tests",
+ "check",
"sample_model.h5",
- "--optimization-type",
- "pruning",
- "--optimization-target",
- "0.5",
+ "--performance",
+ "--target-profile",
+ "ethos-u55-256",
+ "--output",
+ "result.json",
+ "--json",
],
call(
ctx=ANY,
target_profile="ethos-u55-256",
model="sample_model.h5",
- optimization_type="pruning",
- optimization_target="0.5",
- output=None,
- evaluate_on=["Vela"],
+ performance=True,
+ compatibility=False,
+ output=Path("result.json"),
+ json=True,
+ backend=None,
),
],
[
- ["sample_model.h5"],
+ [
+ "check",
+ "sample_model.h5",
+ "--performance",
+ "--target-profile",
+ "ethos-u55-128",
+ ],
call(
ctx=ANY,
- target_profile="ethos-u55-256",
+ target_profile="ethos-u55-128",
model="sample_model.h5",
- optimization_type="pruning,clustering",
- optimization_target="0.5,32",
+ compatibility=False,
+ performance=True,
output=None,
- evaluate_on=["Vela"],
+ json=False,
+ backend=None,
),
],
[
- ["performance", "sample_model.h5", "--output", "result.json"],
+ [
+ "optimize",
+ "sample_model.h5",
+ "--target-profile",
+ "ethos-u55-256",
+ "--pruning",
+ "--clustering",
+ ],
call(
ctx=ANY,
target_profile="ethos-u55-256",
model="sample_model.h5",
- output="result.json",
- evaluate_on=["Vela"],
- ),
- ],
- [
- ["perf", "sample_model.h5", "--target-profile", "ethos-u55-128"],
- call(
- ctx=ANY,
- target_profile="ethos-u55-128",
- model="sample_model.h5",
+ pruning=True,
+ clustering=True,
+ pruning_target=None,
+ clustering_target=None,
output=None,
- evaluate_on=["Vela"],
+ json=False,
+ backend=None,
),
],
[
- ["optimization", "sample_model.h5"],
+ [
+ "optimize",
+ "sample_model.h5",
+ "--target-profile",
+ "ethos-u55-256",
+ "--pruning",
+ "--clustering",
+ "--pruning-target",
+ "0.5",
+ "--clustering-target",
+ "32",
+ ],
call(
ctx=ANY,
target_profile="ethos-u55-256",
model="sample_model.h5",
- optimization_type="pruning,clustering",
- optimization_target="0.5,32",
+ pruning=True,
+ clustering=True,
+ pruning_target=0.5,
+ clustering_target=32,
output=None,
- evaluate_on=["Vela"],
+ json=False,
+ backend=None,
),
],
[
- ["optimization", "sample_model.h5", "--evaluate-on", "some_backend"],
+ [
+ "optimize",
+ "sample_model.h5",
+ "--target-profile",
+ "ethos-u55-256",
+ "--pruning",
+ "--backend",
+ "some_backend",
+ ],
call(
ctx=ANY,
target_profile="ethos-u55-256",
model="sample_model.h5",
- optimization_type="pruning,clustering",
- optimization_target="0.5,32",
+ pruning=True,
+ clustering=False,
+ pruning_target=None,
+ clustering_target=None,
output=None,
- evaluate_on=["some_backend"],
+ json=False,
+ backend=["some_backend"],
),
],
[
[
- "operators",
+ "check",
"sample_model.h5",
+ "--compatibility",
"--target-profile",
"cortex-a",
],
@@ -275,8 +284,11 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
ctx=ANY,
target_profile="cortex-a",
model="sample_model.h5",
+ compatibility=True,
+ performance=False,
output=None,
- supported_ops_report=False,
+ json=False,
+ backend=None,
),
],
],
@@ -288,15 +300,11 @@ def test_commands_execution(
mock = MagicMock()
monkeypatch.setattr(
- "mlia.cli.options.get_default_backends", MagicMock(return_value=["Vela"])
- )
-
- monkeypatch.setattr(
"mlia.cli.options.get_available_backends",
MagicMock(return_value=["Vela", "some_backend"]),
)
- for command in ["all_tests", "operators", "performance", "optimization"]:
+ for command in ["check", "optimize"]:
monkeypatch.setattr(
f"mlia.cli.main.{command}",
wrap_mock_command(mock, getattr(mlia.cli.main, command)),
@@ -335,15 +343,15 @@ def test_commands_execution_backend_main(
@pytest.mark.parametrize(
- "verbose, exc_mock, expected_output",
+ "debug, exc_mock, expected_output",
[
[
True,
MagicMock(side_effect=Exception("Error")),
[
"Execution finished with error: Error",
- f"Please check the log files in the {Path.cwd()/'mlia_output/logs'} "
- "for more details",
+ "Please check the log files in the /tmp/mlia-",
+ "/logs for more details",
],
],
[
@@ -351,8 +359,8 @@ def test_commands_execution_backend_main(
MagicMock(side_effect=Exception("Error")),
[
"Execution finished with error: Error",
- f"Please check the log files in the {Path.cwd()/'mlia_output/logs'} "
- "for more details, or enable verbose mode (--verbose)",
+ "Please check the log files in the /tmp/mlia-",
+ "/logs for more details, or enable debug mode (--debug)",
],
],
[
@@ -389,18 +397,18 @@ def test_commands_execution_backend_main(
],
],
)
-def test_verbose_output(
+def test_debug_output(
monkeypatch: pytest.MonkeyPatch,
capsys: pytest.CaptureFixture,
- verbose: bool,
+ debug: bool,
exc_mock: MagicMock,
expected_output: list[str],
) -> None:
- """Test flag --verbose."""
+ """Test flag --debug."""
def command_params(parser: argparse.ArgumentParser) -> None:
"""Add parameters for non default command."""
- parser.add_argument("--verbose", action="store_true")
+ parser.add_argument("--debug", action="store_true")
def command() -> None:
"""Run test command."""
@@ -420,8 +428,8 @@ def test_verbose_output(
)
params = ["command"]
- if verbose:
- params.append("--verbose")
+ if debug:
+ params.append("--debug")
exit_code = main(params)
assert exit_code == 1
diff --git a/tests/test_cli_options.py b/tests/test_cli_options.py
index d75f7c0..a889a93 100644
--- a/tests/test_cli_options.py
+++ b/tests/test_cli_options.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for module options."""
from __future__ import annotations
@@ -13,14 +13,19 @@ import pytest
from mlia.cli.options import add_output_options
from mlia.cli.options import get_target_profile_opts
from mlia.cli.options import parse_optimization_parameters
+from mlia.cli.options import parse_output_parameters
+from mlia.core.common import FormattedFilePath
@pytest.mark.parametrize(
- "optimization_type, optimization_target, expected_error, expected_result",
+ "pruning, clustering, pruning_target, clustering_target, expected_error,"
+ "expected_result",
[
- (
- "pruning",
- "0.5",
+ [
+ False,
+ False,
+ None,
+ None,
does_not_raise(),
[
dict(
@@ -29,39 +34,40 @@ from mlia.cli.options import parse_optimization_parameters
layers_to_optimize=None,
)
],
- ),
- (
- "clustering",
- "32",
+ ],
+ [
+ True,
+ False,
+ None,
+ None,
does_not_raise(),
[
dict(
- optimization_type="clustering",
- optimization_target=32.0,
+ optimization_type="pruning",
+ optimization_target=0.5,
layers_to_optimize=None,
)
],
- ),
- (
- "pruning,clustering",
- "0.5,32",
+ ],
+ [
+ False,
+ True,
+ None,
+ None,
does_not_raise(),
[
dict(
- optimization_type="pruning",
- optimization_target=0.5,
- layers_to_optimize=None,
- ),
- dict(
optimization_type="clustering",
- optimization_target=32.0,
+ optimization_target=32,
layers_to_optimize=None,
- ),
+ )
],
- ),
- (
- "pruning, clustering",
- "0.5, 32",
+ ],
+ [
+ True,
+ True,
+ None,
+ None,
does_not_raise(),
[
dict(
@@ -71,50 +77,66 @@ from mlia.cli.options import parse_optimization_parameters
),
dict(
optimization_type="clustering",
- optimization_target=32.0,
+ optimization_target=32,
layers_to_optimize=None,
),
],
- ),
- (
- "pruning,clustering",
- "0.5",
- pytest.raises(
- Exception, match="Wrong number of optimization targets and types"
- ),
- None,
- ),
- (
- "",
- "0.5",
- pytest.raises(Exception, match="Optimization type is not provided"),
+ ],
+ [
+ False,
+ False,
+ 0.4,
None,
- ),
- (
- "pruning,clustering",
- "",
- pytest.raises(Exception, match="Optimization target is not provided"),
+ does_not_raise(),
+ [
+ dict(
+ optimization_type="pruning",
+ optimization_target=0.4,
+ layers_to_optimize=None,
+ )
+ ],
+ ],
+ [
+ False,
+ False,
None,
- ),
- (
- "pruning,",
- "0.5,abc",
+ 32,
pytest.raises(
- Exception, match="Non numeric value for the optimization target"
+ argparse.ArgumentError,
+ match="To enable clustering optimization you need to include "
+ "the `--clustering` flag in your command.",
),
None,
- ),
+ ],
+ [
+ False,
+ True,
+ None,
+ 32.2,
+ does_not_raise(),
+ [
+ dict(
+ optimization_type="clustering",
+ optimization_target=32.2,
+ layers_to_optimize=None,
+ )
+ ],
+ ],
],
)
def test_parse_optimization_parameters(
- optimization_type: str,
- optimization_target: str,
+ pruning: bool,
+ clustering: bool,
+ pruning_target: float | None,
+ clustering_target: int | None,
expected_error: Any,
expected_result: Any,
) -> None:
"""Test function parse_optimization_parameters."""
with expected_error:
- result = parse_optimization_parameters(optimization_type, optimization_target)
+ result = parse_optimization_parameters(
+ pruning, clustering, pruning_target, clustering_target
+ )
assert result == expected_result
@@ -155,28 +177,41 @@ def test_output_options(output_parameters: list[str], expected_path: str) -> Non
add_output_options(parser)
args = parser.parse_args(output_parameters)
- assert args.output == expected_path
+ assert str(args.output) == expected_path
@pytest.mark.parametrize(
- "output_filename",
+ "path, json, expected_error, output",
[
- "report.txt",
- "report.TXT",
- "report",
- "report.pdf",
+ [
+ None,
+ True,
+ pytest.raises(
+ argparse.ArgumentError,
+ match=r"To enable JSON output you need to specify the output path. "
+ r"\(e.g. --output out.json --json\)",
+ ),
+ None,
+ ],
+ [None, False, does_not_raise(), None],
+ [
+ Path("test_path"),
+ False,
+ does_not_raise(),
+ FormattedFilePath(Path("test_path"), "plain_text"),
+ ],
+ [
+ Path("test_path"),
+ True,
+ does_not_raise(),
+ FormattedFilePath(Path("test_path"), "json"),
+ ],
],
)
-def test_output_options_bad_parameters(
- output_filename: str, capsys: pytest.CaptureFixture
+def test_parse_output_parameters(
+ path: Path | None, json: bool, expected_error: Any, output: FormattedFilePath | None
) -> None:
- """Test that args parsing should fail if format is not supported."""
- parser = argparse.ArgumentParser()
- add_output_options(parser)
-
- with pytest.raises(SystemExit):
- parser.parse_args(["--output", output_filename])
-
- err_output = capsys.readouterr().err
- suffix = Path(output_filename).suffix[1:]
- assert f"Unsupported format '{suffix}'" in err_output
+ """Test parsing for output parameters."""
+ with expected_error:
+ formatted_output = parse_output_parameters(path, json)
+ assert formatted_output == output
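
For orientation, the cases above exercise the reworked parsers roughly as follows. This is an illustrative sketch assembled from the test data, not an excerpt of the patch; only positional argument order is shown in the tests, so keyword names are deliberately avoided.

    from pathlib import Path

    from mlia.cli.options import parse_optimization_parameters
    from mlia.cli.options import parse_output_parameters
    from mlia.core.common import FormattedFilePath

    # --pruning with no explicit target falls back to the default target of 0.5.
    assert parse_optimization_parameters(True, False, None, None) == [
        dict(optimization_type="pruning", optimization_target=0.5, layers_to_optimize=None)
    ]

    # --output plus --json wraps the path with the "json" output format.
    assert parse_output_parameters(Path("test_path"), True) == FormattedFilePath(
        Path("test_path"), "json"
    )
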
diff --git a/tests/test_core_advice_generation.py b/tests/test_core_advice_generation.py
index 3d985eb..2e0038f 100644
--- a/tests/test_core_advice_generation.py
+++ b/tests/test_core_advice_generation.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for module advice_generation."""
from __future__ import annotations
@@ -35,17 +35,17 @@ def test_advice_generation() -> None:
"category, expected_advice",
[
[
- AdviceCategory.OPERATORS,
+ {AdviceCategory.COMPATIBILITY},
[Advice(["Good advice!"])],
],
[
- AdviceCategory.PERFORMANCE,
+ {AdviceCategory.PERFORMANCE},
[],
],
],
)
def test_advice_category_decorator(
- category: AdviceCategory,
+ category: set[AdviceCategory],
expected_advice: list[Advice],
sample_context: Context,
) -> None:
@@ -54,7 +54,7 @@ def test_advice_category_decorator(
class SampleAdviceProducer(FactBasedAdviceProducer):
"""Sample advice producer."""
- @advice_category(AdviceCategory.OPERATORS)
+ @advice_category(AdviceCategory.COMPATIBILITY)
def produce_advice(self, data_item: DataItem) -> None:
"""Produce the advice."""
self.add_advice(["Good advice!"])
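
Put differently, the decorator still names a single category, but the test now drives it with a set on the context, so advice is produced when the decorated category is a member of that set. A conceptual sketch (only AdviceCategory is taken from the patch; the rest is illustrative):

    from mlia.core.common import AdviceCategory

    # Categories requested on the context, e.g. by a combined compatibility
    # and performance check.
    requested = {AdviceCategory.COMPATIBILITY, AdviceCategory.PERFORMANCE}

    # A producer decorated with @advice_category(AdviceCategory.COMPATIBILITY)
    # fires because its category is a member of the requested set.
    assert AdviceCategory.COMPATIBILITY in requested
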
diff --git a/tests/test_core_context.py b/tests/test_core_context.py
index 44eb976..dcdbef3 100644
--- a/tests/test_core_context.py
+++ b/tests/test_core_context.py
@@ -1,17 +1,53 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the module context."""
+from __future__ import annotations
+
from pathlib import Path
+import pytest
+
from mlia.core.common import AdviceCategory
from mlia.core.context import ExecutionContext
from mlia.core.events import DefaultEventPublisher
+@pytest.mark.parametrize(
+ "context_advice_category, expected_enabled_categories",
+ [
+ [
+ {
+ AdviceCategory.COMPATIBILITY,
+ },
+ [AdviceCategory.COMPATIBILITY],
+ ],
+ [
+ {
+ AdviceCategory.PERFORMANCE,
+ },
+ [AdviceCategory.PERFORMANCE],
+ ],
+ [
+ {AdviceCategory.COMPATIBILITY, AdviceCategory.PERFORMANCE},
+ [AdviceCategory.PERFORMANCE, AdviceCategory.COMPATIBILITY],
+ ],
+ ],
+)
+def test_execution_context_category_enabled(
+ context_advice_category: set[AdviceCategory],
+ expected_enabled_categories: list[AdviceCategory],
+) -> None:
+ """Test category enabled method of execution context."""
+ for category in expected_enabled_categories:
+ assert ExecutionContext(
+ advice_category=context_advice_category
+ ).category_enabled(category)
+
+
def test_execution_context(tmpdir: str) -> None:
"""Test execution context."""
publisher = DefaultEventPublisher()
- category = AdviceCategory.OPERATORS
+ category = {AdviceCategory.COMPATIBILITY}
context = ExecutionContext(
advice_category=category,
@@ -35,13 +71,13 @@ def test_execution_context(tmpdir: str) -> None:
assert str(context) == (
f"ExecutionContext: "
f"working_dir={tmpdir}, "
- "advice_category=OPERATORS, "
+ "advice_category={'COMPATIBILITY'}, "
"config_parameters={'param': 'value'}, "
"verbose=True"
)
context_with_default_params = ExecutionContext(working_dir=tmpdir)
- assert context_with_default_params.advice_category is AdviceCategory.ALL
+ assert context_with_default_params.advice_category == {AdviceCategory.COMPATIBILITY}
assert context_with_default_params.config_parameters is None
assert context_with_default_params.event_handlers is None
assert isinstance(
@@ -55,7 +91,7 @@ def test_execution_context(tmpdir: str) -> None:
expected_str = (
f"ExecutionContext: working_dir={tmpdir}, "
- "advice_category=ALL, "
+ "advice_category={'COMPATIBILITY'}, "
"config_parameters=None, "
"verbose=False"
)
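
The new parametrized test above pins down the set-based behaviour of the context; a minimal usage sketch of the same API, with illustrative values:

    from mlia.core.common import AdviceCategory
    from mlia.core.context import ExecutionContext

    # Both checks requested, as `mlia check` with `--performance` would do.
    ctx = ExecutionContext(
        advice_category={AdviceCategory.COMPATIBILITY, AdviceCategory.PERFORMANCE}
    )
    assert ctx.category_enabled(AdviceCategory.COMPATIBILITY)
    assert ctx.category_enabled(AdviceCategory.PERFORMANCE)
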
diff --git a/tests/test_core_helpers.py b/tests/test_core_helpers.py
index 8577617..03ec3f0 100644
--- a/tests/test_core_helpers.py
+++ b/tests/test_core_helpers.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the helper classes."""
from mlia.core.helpers import APIActionResolver
@@ -10,7 +10,6 @@ def test_api_action_resolver() -> None:
# pylint: disable=use-implicit-booleaness-not-comparison
assert helper.apply_optimizations() == []
- assert helper.supported_operators_info() == []
assert helper.check_performance() == []
assert helper.check_operator_compatibility() == []
assert helper.operator_compatibility_details() == []
diff --git a/tests/test_core_mixins.py b/tests/test_core_mixins.py
index 3834fb3..47ed815 100644
--- a/tests/test_core_mixins.py
+++ b/tests/test_core_mixins.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the module mixins."""
import pytest
@@ -36,7 +36,7 @@ class TestParameterResolverMixin:
self.context = sample_context
self.context.update(
- advice_category=AdviceCategory.OPERATORS,
+ advice_category={AdviceCategory.COMPATIBILITY},
event_handlers=[],
config_parameters={"section": {"param": 123}},
)
@@ -83,7 +83,7 @@ class TestParameterResolverMixin:
"""Init sample object."""
self.context = sample_context
self.context.update(
- advice_category=AdviceCategory.OPERATORS,
+ advice_category={AdviceCategory.COMPATIBILITY},
event_handlers=[],
config_parameters={"section": ["param"]},
)
diff --git a/tests/test_core_reporting.py b/tests/test_core_reporting.py
index feff5cc..7b26173 100644
--- a/tests/test_core_reporting.py
+++ b/tests/test_core_reporting.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for reporting module."""
from __future__ import annotations
@@ -13,11 +13,8 @@ from mlia.core.reporting import CyclesCell
from mlia.core.reporting import Format
from mlia.core.reporting import NestedReport
from mlia.core.reporting import ReportItem
-from mlia.core.reporting import resolve_output_format
from mlia.core.reporting import SingleRow
from mlia.core.reporting import Table
-from mlia.core.typing import OutputFormat
-from mlia.core.typing import PathOrFileLike
from mlia.utils.console import remove_ascii_codes
@@ -338,20 +335,3 @@ Single row example:
alias="simple_row_example",
)
wrong_single_row.to_plain_text()
-
-
-@pytest.mark.parametrize(
- "output, expected_output_format",
- [
- [None, "plain_text"],
- ["", "plain_text"],
- ["some_file", "plain_text"],
- ["some_format.some_ext", "plain_text"],
- ["output.json", "json"],
- ],
-)
-def test_resolve_output_format(
- output: PathOrFileLike | None, expected_output_format: OutputFormat
-) -> None:
- """Test function resolve_output_format."""
- assert resolve_output_format(output) == expected_output_format
diff --git a/tests/test_target_config.py b/tests/test_target_config.py
index 66ebed6..48f0a58 100644
--- a/tests/test_target_config.py
+++ b/tests/test_target_config.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the backend config module."""
from __future__ import annotations
@@ -25,7 +25,7 @@ def test_ip_config() -> None:
(
(None, False, True),
(None, True, True),
- (AdviceCategory.OPERATORS, True, True),
+ (AdviceCategory.COMPATIBILITY, True, True),
(AdviceCategory.OPTIMIZATION, True, False),
),
)
@@ -42,7 +42,7 @@ def test_target_info(
backend_registry.register(
"backend",
BackendConfiguration(
- [AdviceCategory.OPERATORS],
+ [AdviceCategory.COMPATIBILITY],
[System.CURRENT],
BackendType.BUILTIN,
),
diff --git a/tests/test_target_cortex_a_advice_generation.py b/tests/test_target_cortex_a_advice_generation.py
index 6effe4c..1997c52 100644
--- a/tests/test_target_cortex_a_advice_generation.py
+++ b/tests/test_target_cortex_a_advice_generation.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for advice generation."""
from __future__ import annotations
@@ -31,7 +31,7 @@ BACKEND_INFO = (
[
[
ModelIsNotCortexACompatible(BACKEND_INFO, {"UNSUPPORTED_OP"}, {}),
- AdviceCategory.OPERATORS,
+ {AdviceCategory.COMPATIBILITY},
[
Advice(
[
@@ -61,7 +61,7 @@ BACKEND_INFO = (
)
},
),
- AdviceCategory.OPERATORS,
+ {AdviceCategory.COMPATIBILITY},
[
Advice(
[
@@ -93,7 +93,7 @@ BACKEND_INFO = (
],
[
ModelIsCortexACompatible(BACKEND_INFO),
- AdviceCategory.OPERATORS,
+ {AdviceCategory.COMPATIBILITY},
[
Advice(
[
@@ -108,7 +108,7 @@ BACKEND_INFO = (
flex_ops=["flex_op1", "flex_op2"],
custom_ops=["custom_op1", "custom_op2"],
),
- AdviceCategory.OPERATORS,
+ {AdviceCategory.COMPATIBILITY},
[
Advice(
[
@@ -142,7 +142,7 @@ BACKEND_INFO = (
],
[
ModelIsNotTFLiteCompatible(),
- AdviceCategory.OPERATORS,
+ {AdviceCategory.COMPATIBILITY},
[
Advice(
[
@@ -154,7 +154,7 @@ BACKEND_INFO = (
],
[
ModelHasCustomOperators(),
- AdviceCategory.OPERATORS,
+ {AdviceCategory.COMPATIBILITY},
[
Advice(
[
@@ -166,7 +166,7 @@ BACKEND_INFO = (
],
[
TFLiteCompatibilityCheckFailed(),
- AdviceCategory.OPERATORS,
+ {AdviceCategory.COMPATIBILITY},
[
Advice(
[
@@ -181,7 +181,7 @@ BACKEND_INFO = (
def test_cortex_a_advice_producer(
tmpdir: str,
input_data: DataItem,
- advice_category: AdviceCategory,
+ advice_category: set[AdviceCategory],
expected_advice: list[Advice],
) -> None:
"""Test Cortex-A advice producer."""
diff --git a/tests/test_target_ethos_u_advice_generation.py b/tests/test_target_ethos_u_advice_generation.py
index 1569592..e93eeba 100644
--- a/tests/test_target_ethos_u_advice_generation.py
+++ b/tests/test_target_ethos_u_advice_generation.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for Ethos-U advice generation."""
from __future__ import annotations
@@ -28,7 +28,7 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff
[
[
AllOperatorsSupportedOnNPU(),
- AdviceCategory.OPERATORS,
+ {AdviceCategory.COMPATIBILITY},
APIActionResolver(),
[
Advice(
@@ -41,7 +41,7 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff
],
[
AllOperatorsSupportedOnNPU(),
- AdviceCategory.OPERATORS,
+ {AdviceCategory.COMPATIBILITY},
CLIActionResolver(
{
"target_profile": "sample_target",
@@ -55,15 +55,15 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff
"run completely on NPU.",
"Check the estimated performance by running the "
"following command: ",
- "mlia performance --target-profile sample_target "
- "sample_model.tflite",
+ "mlia check sample_model.tflite --target-profile sample_target "
+ "--performance",
]
)
],
],
[
HasCPUOnlyOperators(cpu_only_ops=["OP1", "OP2", "OP3"]),
- AdviceCategory.OPERATORS,
+ {AdviceCategory.COMPATIBILITY},
APIActionResolver(),
[
Advice(
@@ -78,7 +78,7 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff
],
[
HasCPUOnlyOperators(cpu_only_ops=["OP1", "OP2", "OP3"]),
- AdviceCategory.OPERATORS,
+ {AdviceCategory.COMPATIBILITY},
CLIActionResolver({}),
[
Advice(
@@ -87,15 +87,13 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff
"OP1,OP2,OP3.",
"Using operators that are supported by the NPU will "
"improve performance.",
- "For guidance on supported operators, run: mlia operators "
- "--supported-ops-report",
]
)
],
],
[
HasUnsupportedOnNPUOperators(npu_unsupported_ratio=0.4),
- AdviceCategory.OPERATORS,
+ {AdviceCategory.COMPATIBILITY},
APIActionResolver(),
[
Advice(
@@ -110,7 +108,7 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff
],
[
HasUnsupportedOnNPUOperators(npu_unsupported_ratio=0.4),
- AdviceCategory.OPERATORS,
+ {AdviceCategory.COMPATIBILITY},
CLIActionResolver({}),
[
Advice(
@@ -138,7 +136,7 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff
),
]
),
- AdviceCategory.OPTIMIZATION,
+ {AdviceCategory.OPTIMIZATION},
APIActionResolver(),
[
Advice(
@@ -178,7 +176,7 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff
),
]
),
- AdviceCategory.OPTIMIZATION,
+ {AdviceCategory.OPTIMIZATION},
CLIActionResolver({"model": "sample_model.h5"}),
[
Advice(
@@ -192,10 +190,10 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff
"You can try to push the optimization target higher "
"(e.g. pruning: 0.6) "
"to check if those results can be further improved.",
- "For more info: mlia optimization --help",
+ "For more info: mlia optimize --help",
"Optimization command: "
- "mlia optimization --optimization-type pruning "
- "--optimization-target 0.6 sample_model.h5",
+ "mlia optimize sample_model.h5 --pruning "
+ "--pruning-target 0.6",
]
),
Advice(
@@ -225,7 +223,7 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff
),
]
),
- AdviceCategory.OPTIMIZATION,
+ {AdviceCategory.OPTIMIZATION},
APIActionResolver(),
[
Advice(
@@ -267,7 +265,7 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff
),
]
),
- AdviceCategory.OPTIMIZATION,
+ {AdviceCategory.OPTIMIZATION},
APIActionResolver(),
[
Advice(
@@ -304,7 +302,7 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff
),
]
),
- AdviceCategory.OPTIMIZATION,
+ {AdviceCategory.OPTIMIZATION},
APIActionResolver(),
[
Advice(
@@ -354,7 +352,7 @@ from mlia.target.ethos_u.data_analysis import PerfMetricDiff
),
]
),
- AdviceCategory.OPTIMIZATION,
+ {AdviceCategory.OPTIMIZATION},
APIActionResolver(),
[], # no advice for more than one optimization result
],
@@ -364,7 +362,7 @@ def test_ethosu_advice_producer(
tmpdir: str,
input_data: DataItem,
expected_advice: list[Advice],
- advice_category: AdviceCategory,
+ advice_category: set[AdviceCategory] | None,
action_resolver: ActionResolver,
) -> None:
"""Test Ethos-U Advice producer."""
@@ -386,17 +384,17 @@ def test_ethosu_advice_producer(
"advice_category, action_resolver, expected_advice",
[
[
- AdviceCategory.ALL,
+ {AdviceCategory.COMPATIBILITY, AdviceCategory.PERFORMANCE},
None,
[],
],
[
- AdviceCategory.OPERATORS,
+ {AdviceCategory.COMPATIBILITY},
None,
[],
],
[
- AdviceCategory.PERFORMANCE,
+ {AdviceCategory.PERFORMANCE},
APIActionResolver(),
[
Advice(
@@ -414,31 +412,33 @@ def test_ethosu_advice_producer(
],
],
[
- AdviceCategory.PERFORMANCE,
- CLIActionResolver({"model": "test_model.h5"}),
+ {AdviceCategory.PERFORMANCE},
+ CLIActionResolver(
+ {"model": "test_model.h5", "target_profile": "sample_target"}
+ ),
[
Advice(
[
"You can improve the inference time by using only operators "
"that are supported by the NPU.",
"Try running the following command to verify that:",
- "mlia operators test_model.h5",
+ "mlia check test_model.h5 --target-profile sample_target",
]
),
Advice(
[
"Check if you can improve the performance by applying "
"tooling techniques to your model.",
- "For example: mlia optimization --optimization-type "
- "pruning,clustering --optimization-target 0.5,32 "
- "test_model.h5",
- "For more info: mlia optimization --help",
+ "For example: mlia optimize test_model.h5 "
+ "--pruning --clustering "
+ "--pruning-target 0.5 --clustering-target 32",
+ "For more info: mlia optimize --help",
]
),
],
],
[
- AdviceCategory.OPTIMIZATION,
+ {AdviceCategory.OPTIMIZATION},
APIActionResolver(),
[
Advice(
@@ -450,14 +450,14 @@ def test_ethosu_advice_producer(
],
],
[
- AdviceCategory.OPTIMIZATION,
+ {AdviceCategory.OPTIMIZATION},
CLIActionResolver({"model": "test_model.h5"}),
[
Advice(
[
"For better performance, make sure that all the operators "
"of your final TensorFlow Lite model are supported by the NPU.",
- "For more details, run: mlia operators --help",
+ "For more details, run: mlia check --help",
]
)
],
@@ -466,7 +466,7 @@ def test_ethosu_advice_producer(
)
def test_ethosu_static_advice_producer(
tmpdir: str,
- advice_category: AdviceCategory,
+ advice_category: set[AdviceCategory] | None,
action_resolver: ActionResolver,
expected_advice: list[Advice],
) -> None:
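
The rewritten advice strings above map the old sub-commands onto the new ones. An illustrative way to drive the new CLI programmatically, mirroring those strings ("sample_model.tflite", "sample_model.h5" and "sample_target" are placeholders from the tests and would need to exist):

    from mlia.cli.main import main

    # Compatibility and performance checks in one command.
    main(["check", "sample_model.tflite", "--target-profile", "sample_target", "--performance"])

    # Pruning optimization with an explicit target.
    main(["optimize", "sample_model.h5", "--pruning", "--pruning-target", "0.6"])
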
diff --git a/tests/test_target_registry.py b/tests/test_target_registry.py
index e6ee296..e6028a9 100644
--- a/tests/test_target_registry.py
+++ b/tests/test_target_registry.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the target registry module."""
from __future__ import annotations
@@ -26,11 +26,11 @@ def test_target_registry(expected_target: str) -> None:
@pytest.mark.parametrize(
("target_name", "expected_advices"),
(
- ("Cortex-A", [AdviceCategory.OPERATORS]),
+ ("Cortex-A", [AdviceCategory.COMPATIBILITY]),
(
"Ethos-U55",
[
- AdviceCategory.OPERATORS,
+ AdviceCategory.COMPATIBILITY,
AdviceCategory.OPTIMIZATION,
AdviceCategory.PERFORMANCE,
],
@@ -38,12 +38,12 @@ def test_target_registry(expected_target: str) -> None:
(
"Ethos-U65",
[
- AdviceCategory.OPERATORS,
+ AdviceCategory.COMPATIBILITY,
AdviceCategory.OPTIMIZATION,
AdviceCategory.PERFORMANCE,
],
),
- ("TOSA", [AdviceCategory.OPERATORS]),
+ ("TOSA", [AdviceCategory.COMPATIBILITY]),
),
)
def test_supported_advice(
@@ -72,7 +72,7 @@ def test_supported_backends(target_name: str, expected_backends: list[str]) -> N
@pytest.mark.parametrize(
("advice", "expected_targets"),
(
- (AdviceCategory.OPERATORS, ["Cortex-A", "Ethos-U55", "Ethos-U65", "TOSA"]),
+ (AdviceCategory.COMPATIBILITY, ["Cortex-A", "Ethos-U55", "Ethos-U65", "TOSA"]),
(AdviceCategory.OPTIMIZATION, ["Ethos-U55", "Ethos-U65"]),
(AdviceCategory.PERFORMANCE, ["Ethos-U55", "Ethos-U65"]),
),
diff --git a/tests/test_target_tosa_advice_generation.py b/tests/test_target_tosa_advice_generation.py
index e8e06f8..d5ebbd7 100644
--- a/tests/test_target_tosa_advice_generation.py
+++ b/tests/test_target_tosa_advice_generation.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for advice generation."""
from __future__ import annotations
@@ -19,7 +19,7 @@ from mlia.target.tosa.data_analysis import ModelIsTOSACompatible
[
[
ModelIsNotTOSACompatible(),
- AdviceCategory.OPERATORS,
+ {AdviceCategory.COMPATIBILITY},
[
Advice(
[
@@ -31,7 +31,7 @@ from mlia.target.tosa.data_analysis import ModelIsTOSACompatible
],
[
ModelIsTOSACompatible(),
- AdviceCategory.OPERATORS,
+ {AdviceCategory.COMPATIBILITY},
[Advice(["Model is fully TOSA compatible."])],
],
],
@@ -39,7 +39,7 @@ from mlia.target.tosa.data_analysis import ModelIsTOSACompatible
def test_tosa_advice_producer(
tmpdir: str,
input_data: DataItem,
- advice_category: AdviceCategory,
+ advice_category: set[AdviceCategory],
expected_advice: list[Advice],
) -> None:
"""Test TOSA advice producer."""
diff --git a/tests_e2e/test_e2e.py b/tests_e2e/test_e2e.py
index fb40735..26f5d29 100644
--- a/tests_e2e/test_e2e.py
+++ b/tests_e2e/test_e2e.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""End to end tests for MLIA CLI."""
from __future__ import annotations
@@ -20,7 +20,6 @@ from typing import Iterable
import pytest
from mlia.cli.config import get_available_backends
-from mlia.cli.config import get_default_backends
from mlia.cli.main import get_commands
from mlia.cli.main import get_possible_command_names
from mlia.cli.main import init_commands
@@ -230,19 +229,19 @@ def check_args(args: list[str], no_skip: bool) -> None:
"""Check the arguments and skip/fail test cases based on that."""
parser = argparse.ArgumentParser()
parser.add_argument(
- "--evaluate-on",
- help="Backends to use for evaluation (default: %(default)s)",
- nargs="*",
- default=get_default_backends(),
+ "--backend",
+ help="Backends to use for evaluation.",
+ nargs="+",
)
parsed_args, _ = parser.parse_known_args(args)
- required_backends = set(parsed_args.evaluate_on)
- available_backends = set(get_available_backends())
- missing_backends = required_backends.difference(available_backends)
+ if parsed_args.backend:
+ required_backends = set(parsed_args.backend)
+ available_backends = set(get_available_backends())
+ missing_backends = required_backends.difference(available_backends)
- if missing_backends and not no_skip:
- pytest.skip(f"Missing backend(s): {','.join(missing_backends)}")
+ if missing_backends and not no_skip:
+ pytest.skip(f"Missing backend(s): {','.join(missing_backends)}")
def get_execution_definitions() -> Generator[list[str], None, None]: