aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorRaul Farkas <raul.farkas@arm.com>2022-11-29 13:29:04 +0000
committerRaul Farkas <raul.farkas@arm.com>2023-01-10 10:46:07 +0000
commit5800fc990ed1e36ce7d06670f911fbb12a0ec771 (patch)
tree294605295cd2624ba63e6ad3df335a2a4b2700ab /src
parentdcd0bd31985c27e1d07333351b26cf8ad12ad1fd (diff)
downloadmlia-5800fc990ed1e36ce7d06670f911fbb12a0ec771.tar.gz
MLIA-650 Implement new CLI changes
Breaking change in the CLI and API: Sub-commands "optimization", "operators", and "performance" were replaced by "check", which incorporates compatibility and performance checks, and "optimize" which is used for optimization. "get_advice" API was adapted to these CLI changes. API changes: * Remove previous advice category "all" that would perform all three operations (when possible). Replace them with the ability to pass a set of the advice categories. * Update api.get_advice method docstring to reflect new changes. * Set default advice category to COMPATIBILITY * Update core.common.AdviceCategory by changing the "OPERATORS" advice category to "COMPATIBILITY" and removing "ALL" enum type. Update all subsequent methods that previously used "OPERATORS" to use "COMPATIBILITY". * Update core.context.ExecutionContext to have "COMPATIBILITY" as default advice_category instead of "ALL". * Remove api.generate_supported_operators_report and all related functions from cli.commands, cli.helpers, cli.main, cli.options, core.helpers * Update tests to reflect new API changes. CLI changes: * Update README.md to contain information on the new CLI * Remove the ability to generate the supported operators report from the MLIA CLI * Replace `mlia ops` and `mlia perf` with the new `mlia check` command that can be used to perform both operations. * Replace `mlia opt` with the new `mlia optimize` command. * Replace `--evaluate-on` flag with `--backend` flag * Replace `--verbose` flag with `--debug` flag (no behaviour change). * Remove the ability for the user to select MLIA working directory. Create and use a temporary directory in /temp instead. * Change behaviour of `--output` flag to not format the content automatically based on file extension anymore. Instead it will simply redirect to a file. * Add the `--json` flag to specify that the format of the output should be json. * Add command validators that are used to validate inter-dependent flags (e.g. 
backend validation based on target_profile). * Add support for selecting built-in backends for both `check` and `optimize` commands. * Add new unit tests and update old ones to test the new CLI changes. * Update RELEASES.md * Update copyright notice Change-Id: Ia6340797c7bee3acbbd26601950e5a16ad5602db
Diffstat (limited to 'src')
-rw-r--r--src/mlia/api.py57
-rw-r--r--src/mlia/backend/armnn_tflite_delegate/__init__.py4
-rw-r--r--src/mlia/backend/tosa_checker/__init__.py4
-rw-r--r--src/mlia/backend/vela/__init__.py4
-rw-r--r--src/mlia/cli/command_validators.py113
-rw-r--r--src/mlia/cli/commands.py208
-rw-r--r--src/mlia/cli/config.py38
-rw-r--r--src/mlia/cli/helpers.py36
-rw-r--r--src/mlia/cli/main.py81
-rw-r--r--src/mlia/cli/options.py199
-rw-r--r--src/mlia/core/common.py53
-rw-r--r--src/mlia/core/context.py42
-rw-r--r--src/mlia/core/handlers.py10
-rw-r--r--src/mlia/core/helpers.py6
-rw-r--r--src/mlia/core/reporting.py13
-rw-r--r--src/mlia/target/cortex_a/advice_generation.py12
-rw-r--r--src/mlia/target/cortex_a/advisor.py8
-rw-r--r--src/mlia/target/cortex_a/handlers.py6
-rw-r--r--src/mlia/target/ethos_u/advice_generation.py34
-rw-r--r--src/mlia/target/ethos_u/advisor.py14
-rw-r--r--src/mlia/target/ethos_u/handlers.py6
-rw-r--r--src/mlia/target/tosa/advice_generation.py6
-rw-r--r--src/mlia/target/tosa/advisor.py8
-rw-r--r--src/mlia/target/tosa/handlers.py6
-rw-r--r--src/mlia/utils/types.py4
25 files changed, 523 insertions, 449 deletions
diff --git a/src/mlia/api.py b/src/mlia/api.py
index c7be9ec..2cabf37 100644
--- a/src/mlia/api.py
+++ b/src/mlia/api.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for the API functions."""
from __future__ import annotations
@@ -6,18 +6,14 @@ from __future__ import annotations
import logging
from pathlib import Path
from typing import Any
-from typing import Literal
from mlia.core.advisor import InferenceAdvisor
from mlia.core.common import AdviceCategory
+from mlia.core.common import FormattedFilePath
from mlia.core.context import ExecutionContext
-from mlia.core.typing import PathOrFileLike
from mlia.target.cortex_a.advisor import configure_and_get_cortexa_advisor
-from mlia.target.cortex_a.operators import report as cortex_a_report
from mlia.target.ethos_u.advisor import configure_and_get_ethosu_advisor
-from mlia.target.ethos_u.operators import report as ethos_u_report
from mlia.target.tosa.advisor import configure_and_get_tosa_advisor
-from mlia.target.tosa.operators import report as tosa_report
from mlia.utils.filesystem import get_target
logger = logging.getLogger(__name__)
@@ -26,10 +22,9 @@ logger = logging.getLogger(__name__)
def get_advice(
target_profile: str,
model: str | Path,
- category: Literal["all", "operators", "performance", "optimization"] = "all",
+ category: set[str],
optimization_targets: list[dict[str, Any]] | None = None,
- working_dir: str | Path = "mlia_output",
- output: PathOrFileLike | None = None,
+ output: FormattedFilePath | None = None,
context: ExecutionContext | None = None,
backends: list[str] | None = None,
) -> None:
@@ -42,17 +37,13 @@ def get_advice(
:param target_profile: target profile identifier
:param model: path to the NN model
- :param category: category of the advice. MLIA supports four categories:
- "all", "operators", "performance", "optimization". If not provided
- category "all" is used by default.
+ :param category: set of categories of the advice. MLIA supports three categories:
+ "compatibility", "performance", "optimization". If not provided
+ category "compatibility" is used by default.
:param optimization_targets: optional model optimization targets that
- could be used for generating advice in categories
- "all" and "optimization."
- :param working_dir: path to the directory that will be used for storing
- intermediate files during execution (e.g. converted models)
- :param output: path to the report file. If provided MLIA will save
- report in this location. Format of the report automatically
- detected based on file extension.
+ could be used for generating advice in "optimization" category.
+ :param output: path to the report file. If provided, MLIA will save
+ report in this location.
:param context: optional parameter which represents execution context,
could be used for advanced use cases
:param backends: A list of backends that should be used for the given
@@ -63,13 +54,14 @@ def get_advice(
Getting the advice for the provided target profile and the model
- >>> get_advice("ethos-u55-256", "path/to/the/model")
+ >>> get_advice("ethos-u55-256", "path/to/the/model",
+ {"optimization", "compatibility"})
Getting the advice for the category "performance" and save result report in file
"report.json"
- >>> get_advice("ethos-u55-256", "path/to/the/model", "performance",
- output="report.json")
+ >>> get_advice("ethos-u55-256", "path/to/the/model", {"performance"},
+ output=FormattedFilePath("report.json")
"""
advice_category = AdviceCategory.from_string(category)
@@ -78,10 +70,7 @@ def get_advice(
context.advice_category = advice_category
if context is None:
- context = ExecutionContext(
- advice_category=advice_category,
- working_dir=working_dir,
- )
+ context = ExecutionContext(advice_category=advice_category)
advisor = get_advisor(
context,
@@ -99,7 +88,7 @@ def get_advisor(
context: ExecutionContext,
target_profile: str,
model: str | Path,
- output: PathOrFileLike | None = None,
+ output: FormattedFilePath | None = None,
**extra_args: Any,
) -> InferenceAdvisor:
"""Find appropriate advisor for the target."""
@@ -123,17 +112,3 @@ def get_advisor(
output,
**extra_args,
)
-
-
-def generate_supported_operators_report(target_profile: str) -> None:
- """Generate a supported operators report based on given target profile."""
- generators_map = {
- "ethos-u55": ethos_u_report,
- "ethos-u65": ethos_u_report,
- "cortex-a": cortex_a_report,
- "tosa": tosa_report,
- }
-
- target = get_target(target_profile)
-
- generators_map[target]()
diff --git a/src/mlia/backend/armnn_tflite_delegate/__init__.py b/src/mlia/backend/armnn_tflite_delegate/__init__.py
index 6d5af42..ccb7e38 100644
--- a/src/mlia/backend/armnn_tflite_delegate/__init__.py
+++ b/src/mlia/backend/armnn_tflite_delegate/__init__.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Arm NN TensorFlow Lite delegate backend module."""
from mlia.backend.config import BackendConfiguration
@@ -9,7 +9,7 @@ from mlia.core.common import AdviceCategory
registry.register(
"ArmNNTFLiteDelegate",
BackendConfiguration(
- supported_advice=[AdviceCategory.OPERATORS],
+ supported_advice=[AdviceCategory.COMPATIBILITY],
supported_systems=None,
backend_type=BackendType.BUILTIN,
),
diff --git a/src/mlia/backend/tosa_checker/__init__.py b/src/mlia/backend/tosa_checker/__init__.py
index 19fc8be..c06a122 100644
--- a/src/mlia/backend/tosa_checker/__init__.py
+++ b/src/mlia/backend/tosa_checker/__init__.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""TOSA checker backend module."""
from mlia.backend.config import BackendConfiguration
@@ -10,7 +10,7 @@ from mlia.core.common import AdviceCategory
registry.register(
"TOSA-Checker",
BackendConfiguration(
- supported_advice=[AdviceCategory.OPERATORS],
+ supported_advice=[AdviceCategory.COMPATIBILITY],
supported_systems=[System.LINUX_AMD64],
backend_type=BackendType.WHEEL,
),
diff --git a/src/mlia/backend/vela/__init__.py b/src/mlia/backend/vela/__init__.py
index 38a623e..68fbcba 100644
--- a/src/mlia/backend/vela/__init__.py
+++ b/src/mlia/backend/vela/__init__.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Vela backend module."""
from mlia.backend.config import BackendConfiguration
@@ -11,7 +11,7 @@ registry.register(
"Vela",
BackendConfiguration(
supported_advice=[
- AdviceCategory.OPERATORS,
+ AdviceCategory.COMPATIBILITY,
AdviceCategory.PERFORMANCE,
AdviceCategory.OPTIMIZATION,
],
diff --git a/src/mlia/cli/command_validators.py b/src/mlia/cli/command_validators.py
new file mode 100644
index 0000000..1974a1d
--- /dev/null
+++ b/src/mlia/cli/command_validators.py
@@ -0,0 +1,113 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""CLI command validators module."""
+from __future__ import annotations
+
+import argparse
+import logging
+import sys
+
+from mlia.cli.config import get_default_backends
+from mlia.target.registry import supported_backends
+from mlia.utils.filesystem import get_target
+
+logger = logging.getLogger(__name__)
+
+
+def validate_backend(
+ target_profile: str, backend: list[str] | None
+) -> list[str] | None:
+ """Validate backend with given target profile.
+
+ This validator checks whether the given target-profile and backend are
+ compatible with each other.
+ It assumes that prior checks where made on the validity of the target-profile.
+ """
+ target_map = {
+ "ethos-u55": "Ethos-U55",
+ "ethos-u65": "Ethos-U65",
+ "cortex-a": "Cortex-A",
+ "tosa": "TOSA",
+ }
+ target = get_target(target_profile)
+
+ if not backend:
+ return get_default_backends()[target]
+
+ compatible_backends = supported_backends(target_map[target])
+
+ nor_backend = list(map(normalize_string, backend))
+ nor_compat_backend = list(map(normalize_string, compatible_backends))
+
+ incompatible_backends = [
+ backend[i] for i, x in enumerate(nor_backend) if x not in nor_compat_backend
+ ]
+ # Throw an error if any unsupported backends are used
+ if incompatible_backends:
+ raise argparse.ArgumentError(
+ None,
+ f"{', '.join(incompatible_backends)} backend not supported "
+ f"with target-profile {target_profile}.",
+ )
+ return backend
+
+
+def validate_check_target_profile(target_profile: str, category: set[str]) -> None:
+ """Validate whether advice category is compatible with the provided target_profile.
+
+ This validator function raises warnings if any desired advice category is not
+ compatible with the selected target profile. If no operation can be
+ performed as a result of the validation, MLIA exits with error code 0.
+ """
+ incompatible_targets_performance: list[str] = ["tosa", "cortex-a"]
+ incompatible_targets_compatibility: list[str] = []
+
+ # Check which check operation should be performed
+ try_performance = "performance" in category
+ try_compatibility = "compatibility" in category
+
+ # Cross check which of the desired operations can be performed on given
+ # target-profile
+ do_performance = (
+ try_performance and target_profile not in incompatible_targets_performance
+ )
+ do_compatibility = (
+ try_compatibility and target_profile not in incompatible_targets_compatibility
+ )
+
+ # Case: desired operations can be performed with given target profile
+ if (try_performance == do_performance) and (try_compatibility == do_compatibility):
+ return
+
+ warning_message = "\nWARNING: "
+ # Case: performance operation to be skipped
+ if try_performance and not do_performance:
+ warning_message += (
+ "Performance checks skipped as they cannot be "
+ f"performed with target profile {target_profile}."
+ )
+
+ # Case: compatibility operation to be skipped
+ if try_compatibility and not do_compatibility:
+ warning_message += (
+ "Compatibility checks skipped as they cannot be "
+ f"performed with target profile {target_profile}."
+ )
+
+ # Case: at least one operation will be performed
+ if do_compatibility or do_performance:
+ logger.warning(warning_message)
+ return
+
+ # Case: no operation will be performed
+ warning_message += " No operation was performed."
+ logger.warning(warning_message)
+ sys.exit(0)
+
+
+def normalize_string(value: str) -> str:
+ """Given a string return the normalized version.
+
+ E.g. Given "ToSa-cHecker" -> "tosachecker"
+ """
+ return value.lower().replace("-", "")
diff --git a/src/mlia/cli/commands.py b/src/mlia/cli/commands.py
index 09fe9de..d2242ba 100644
--- a/src/mlia/cli/commands.py
+++ b/src/mlia/cli/commands.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""CLI commands module.
@@ -13,7 +13,7 @@ be configured. Function 'setup_logging' from module
>>> from mlia.cli.logging import setup_logging
>>> setup_logging(verbose=True)
>>> import mlia.cli.commands as mlia
->>> mlia.all_tests(ExecutionContext(working_dir="mlia_output"), "ethos-u55-256",
+>>> mlia.check(ExecutionContext(), "ethos-u55-256",
"path/to/model")
"""
from __future__ import annotations
@@ -22,11 +22,12 @@ import logging
from pathlib import Path
from mlia.api import ExecutionContext
-from mlia.api import generate_supported_operators_report
from mlia.api import get_advice
-from mlia.api import PathOrFileLike
+from mlia.cli.command_validators import validate_backend
+from mlia.cli.command_validators import validate_check_target_profile
from mlia.cli.config import get_installation_manager
from mlia.cli.options import parse_optimization_parameters
+from mlia.cli.options import parse_output_parameters
from mlia.utils.console import create_section_header
logger = logging.getLogger(__name__)
@@ -34,14 +35,15 @@ logger = logging.getLogger(__name__)
CONFIG = create_section_header("ML Inference Advisor configuration")
-def all_tests(
+def check(
ctx: ExecutionContext,
target_profile: str,
- model: str,
- optimization_type: str = "pruning,clustering",
- optimization_target: str = "0.5,32",
- output: PathOrFileLike | None = None,
- evaluate_on: list[str] | None = None,
+ model: str | None = None,
+ compatibility: bool = False,
+ performance: bool = False,
+ output: Path | None = None,
+ json: bool = False,
+ backend: list[str] | None = None,
) -> None:
"""Generate a full report on the input model.
@@ -50,8 +52,6 @@ def all_tests(
- converts the input Keras model into TensorFlow Lite format
- checks the model for operator compatibility on the specified device
- - applies optimizations to the model and estimates the resulting performance
- on both the original and the optimized models
- generates a final report on the steps above
- provides advice on how to (possibly) improve the inference performance
@@ -59,140 +59,63 @@ def all_tests(
:param target_profile: target profile identifier. Will load appropriate parameters
from the profile.json file based on this argument.
:param model: path to the Keras model
- :param optimization_type: list of the optimization techniques separated
- by comma, e.g. 'pruning,clustering'
- :param optimization_target: list of the corresponding targets for
- the provided optimization techniques, e.g. '0.5,32'
+ :param compatibility: flag that identifies whether to run compatibility checks
+ :param performance: flag that identifies whether to run performance checks
:param output: path to the file where the report will be saved
- :param evaluate_on: list of the backends to use for evaluation
+ :param backend: list of the backends to use for evaluation
Example:
- Run command for the target profile ethos-u55-256 with two model optimizations
- and save report in json format locally in the file report.json
+ Run command for the target profile ethos-u55-256 to verify both performance
+ and operator compatibility.
>>> from mlia.api import ExecutionContext
>>> from mlia.cli.logging import setup_logging
>>> setup_logging()
- >>> from mlia.cli.commands import all_tests
- >>> all_tests(ExecutionContext(working_dir="mlia_output"), "ethos-u55-256",
- "model.h5", "pruning,clustering", "0.5,32",
+ >>> from mlia.cli.commands import check
+ >>> check(ExecutionContext(), "ethos-u55-256",
+ "model.h5", compatibility=True, performance=True,
output="report.json")
"""
- opt_params = parse_optimization_parameters(
- optimization_type,
- optimization_target,
- )
-
- get_advice(
- target_profile,
- model,
- "all",
- optimization_targets=opt_params,
- output=output,
- context=ctx,
- backends=evaluate_on,
- )
-
-
-def operators(
- ctx: ExecutionContext,
- target_profile: str,
- model: str | None = None,
- output: PathOrFileLike | None = None,
- supported_ops_report: bool = False,
-) -> None:
- """Print the model's operator list.
-
- This command checks the operator compatibility of the input model with
- the specific target profile. Generates a report of the operator placement
- (NPU or CPU fallback) and advice on how to improve it (if necessary).
-
- :param ctx: execution context
- :param target_profile: target profile identifier. Will load appropriate parameters
- from the profile.json file based on this argument.
- :param model: path to the model, which can be TensorFlow Lite or Keras
- :param output: path to the file where the report will be saved
- :param supported_ops_report: if True then generates supported operators
- report in current directory and exits
-
- Example:
- Run command for the target profile ethos-u55-256 and the provided
- TensorFlow Lite model and print report on the standard output
-
- >>> from mlia.api import ExecutionContext
- >>> from mlia.cli.logging import setup_logging
- >>> setup_logging()
- >>> from mlia.cli.commands import operators
- >>> operators(ExecutionContext(working_dir="mlia_output"), "ethos-u55-256",
- "model.tflite")
- """
- if supported_ops_report:
- generate_supported_operators_report(target_profile)
- logger.info("Report saved into SUPPORTED_OPS.md")
- return
-
if not model:
raise Exception("Model is not provided")
- get_advice(
- target_profile,
- model,
- "operators",
- output=output,
- context=ctx,
- )
-
+ formatted_output = parse_output_parameters(output, json)
-def performance(
- ctx: ExecutionContext,
- target_profile: str,
- model: str,
- output: PathOrFileLike | None = None,
- evaluate_on: list[str] | None = None,
-) -> None:
- """Print the model's performance stats.
-
- This command estimates the inference performance of the input model
- on the specified target profile, and generates a report with advice on how
- to improve it.
-
- :param ctx: execution context
- :param target_profile: target profile identifier. Will load appropriate parameters
- from the profile.json file based on this argument.
- :param model: path to the model, which can be TensorFlow Lite or Keras
- :param output: path to the file where the report will be saved
- :param evaluate_on: list of the backends to use for evaluation
+ # Set category based on checks to perform (i.e. "compatibility" and/or
+ # "performance").
+ # If no check type is specified, "compatibility" is the default category.
+ if compatibility and performance:
+ category = {"compatibility", "performance"}
+ elif performance:
+ category = {"performance"}
+ else:
+ category = {"compatibility"}
- Example:
- Run command for the target profile ethos-u55-256 and
- the provided TensorFlow Lite model and print report on the standard output
+ validate_check_target_profile(target_profile, category)
+ validated_backend = validate_backend(target_profile, backend)
- >>> from mlia.api import ExecutionContext
- >>> from mlia.cli.logging import setup_logging
- >>> setup_logging()
- >>> from mlia.cli.commands import performance
- >>> performance(ExecutionContext(working_dir="mlia_output"), "ethos-u55-256",
- "model.tflite")
- """
get_advice(
target_profile,
model,
- "performance",
- output=output,
+ category,
+ output=formatted_output,
context=ctx,
- backends=evaluate_on,
+ backends=validated_backend,
)
-def optimization(
+def optimize( # pylint: disable=too-many-arguments
ctx: ExecutionContext,
target_profile: str,
model: str,
- optimization_type: str,
- optimization_target: str,
+ pruning: bool,
+ clustering: bool,
+ pruning_target: float | None,
+ clustering_target: int | None,
layers_to_optimize: list[str] | None = None,
- output: PathOrFileLike | None = None,
- evaluate_on: list[str] | None = None,
+ output: Path | None = None,
+ json: bool = False,
+ backend: list[str] | None = None,
) -> None:
"""Show the performance improvements (if any) after applying the optimizations.
@@ -201,43 +124,54 @@ def optimization(
the inference performance (if possible).
:param ctx: execution context
- :param target: target profile identifier. Will load appropriate parameters
+ :param target_profile: target profile identifier. Will load appropriate parameters
from the profile.json file based on this argument.
:param model: path to the TensorFlow Lite model
- :param optimization_type: list of the optimization techniques separated
- by comma, e.g. 'pruning,clustering'
- :param optimization_target: list of the corresponding targets for
- the provided optimization techniques, e.g. '0.5,32'
+ :param pruning: perform pruning optimization (default if no option specified)
+ :param clustering: perform clustering optimization
+ :param clustering_target: clustering optimization target
+ :param pruning_target: pruning optimization target
:param layers_to_optimize: list of the layers of the model which should be
optimized, if None then all layers are used
:param output: path to the file where the report will be saved
- :param evaluate_on: list of the backends to use for evaluation
+ :param json: set the output format to json
+ :param backend: list of the backends to use for evaluation
Example:
Run command for the target profile ethos-u55-256 and
the provided TensorFlow Lite model and print report on the standard output
>>> from mlia.cli.logging import setup_logging
+ >>> from mlia.api import ExecutionContext
>>> setup_logging()
- >>> from mlia.cli.commands import optimization
- >>> optimization(ExecutionContext(working_dir="mlia_output"),
- target="ethos-u55-256",
- "model.tflite", "pruning", "0.5")
+ >>> from mlia.cli.commands import optimize
+ >>> optimize(ExecutionContext(),
+ target_profile="ethos-u55-256",
+ model="model.tflite", pruning=True,
+ clustering=False, pruning_target=0.5,
+ clustering_target=None)
"""
- opt_params = parse_optimization_parameters(
- optimization_type,
- optimization_target,
- layers_to_optimize=layers_to_optimize,
+ opt_params = (
+ parse_optimization_parameters( # pylint: disable=too-many-function-args
+ pruning,
+ clustering,
+ pruning_target,
+ clustering_target,
+ layers_to_optimize,
+ )
)
+ formatted_output = parse_output_parameters(output, json)
+ validated_backend = validate_backend(target_profile, backend)
+
get_advice(
target_profile,
model,
- "optimization",
+ {"optimization"},
optimization_targets=opt_params,
- output=output,
+ output=formatted_output,
context=ctx,
- backends=evaluate_on,
+ backends=validated_backend,
)
diff --git a/src/mlia/cli/config.py b/src/mlia/cli/config.py
index 2d694dc..680b4b6 100644
--- a/src/mlia/cli/config.py
+++ b/src/mlia/cli/config.py
@@ -1,10 +1,13 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Environment configuration functions."""
from __future__ import annotations
import logging
from functools import lru_cache
+from typing import List
+from typing import Optional
+from typing import TypedDict
from mlia.backend.corstone.install import get_corstone_installations
from mlia.backend.install import supported_backends
@@ -14,6 +17,9 @@ from mlia.backend.tosa_checker.install import get_tosa_backend_installation
logger = logging.getLogger(__name__)
+DEFAULT_PRUNING_TARGET = 0.5
+DEFAULT_CLUSTERING_TARGET = 32
+
def get_installation_manager(noninteractive: bool = False) -> InstallationManager:
"""Return installation manager."""
@@ -26,7 +32,7 @@ def get_installation_manager(noninteractive: bool = False) -> InstallationManage
@lru_cache
def get_available_backends() -> list[str]:
"""Return list of the available backends."""
- available_backends = ["Vela"]
+ available_backends = ["Vela", "tosa-checker", "armnn-tflitedelegate"]
# Add backends using backend manager
manager = get_installation_manager()
@@ -41,9 +47,10 @@ def get_available_backends() -> list[str]:
# List of mutually exclusive Corstone backends ordered by priority
_CORSTONE_EXCLUSIVE_PRIORITY = ("Corstone-310", "Corstone-300")
+_NON_ETHOS_U_BACKENDS = ("tosa-checker", "armnn-tflitedelegate")
-def get_default_backends() -> list[str]:
+def get_ethos_u_default_backends() -> list[str]:
"""Get default backends for evaluation."""
backends = get_available_backends()
@@ -57,9 +64,34 @@ def get_default_backends() -> list[str]:
]
break
+ # Filter out non ethos-u backends
+ backends = [x for x in backends if x not in _NON_ETHOS_U_BACKENDS]
return backends
def is_corstone_backend(backend: str) -> bool:
"""Check if the given backend is a Corstone backend."""
return backend in _CORSTONE_EXCLUSIVE_PRIORITY
+
+
+BackendCompatibility = TypedDict(
+ "BackendCompatibility",
+ {
+ "partial-match": bool,
+ "backends": List[str],
+ "default-return": Optional[List[str]],
+ "use-custom-return": bool,
+ "custom-return": Optional[List[str]],
+ },
+)
+
+
+def get_default_backends() -> dict[str, list[str]]:
+ """Return default backends for all targets."""
+ ethos_u_defaults = get_ethos_u_default_backends()
+ return {
+ "ethos-u55": ethos_u_defaults,
+ "ethos-u65": ethos_u_defaults,
+ "tosa": ["tosa-checker"],
+ "cortex-a": ["armnn-tflitedelegate"],
+ }
diff --git a/src/mlia/cli/helpers.py b/src/mlia/cli/helpers.py
index acec837..ac64581 100644
--- a/src/mlia/cli/helpers.py
+++ b/src/mlia/cli/helpers.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for various helper classes."""
from __future__ import annotations
@@ -29,9 +29,9 @@ class CLIActionResolver(ActionResolver):
return [
*keras_note,
- "For example: mlia optimization --optimization-type "
- f"pruning,clustering --optimization-target 0.5,32 {model_path}",
- "For more info: mlia optimization --help",
+ f"For example: mlia optimize {model_path} --pruning --clustering "
+ "--pruning-target 0.5 --clustering-target 32",
+ "For more info: mlia optimize --help",
]
@staticmethod
@@ -41,14 +41,17 @@ class CLIActionResolver(ActionResolver):
opt_settings: list[OptimizationSettings],
) -> list[str]:
"""Return specific optimization command description."""
- opt_types = ",".join(opt.optimization_type for opt in opt_settings)
- opt_targs = ",".join(str(opt.optimization_target) for opt in opt_settings)
+ opt_types = " ".join("--" + opt.optimization_type for opt in opt_settings)
+ opt_targs_strings = ["--pruning-target", "--clustering-target"]
+ opt_targs = ",".join(
+ f"{opt_targs_strings[i]} {opt.optimization_target}"
+ for i, opt in enumerate(opt_settings)
+ )
return [
- "For more info: mlia optimization --help",
+ "For more info: mlia optimize --help",
"Optimization command: "
- f"mlia optimization --optimization-type {opt_types} "
- f"--optimization-target {opt_targs}{device_opts} {model_path}",
+ f"mlia optimize {model_path}{device_opts} {opt_types} {opt_targs}",
]
def apply_optimizations(self, **kwargs: Any) -> list[str]:
@@ -65,13 +68,6 @@ class CLIActionResolver(ActionResolver):
return []
- def supported_operators_info(self) -> list[str]:
- """Return command details for generating supported ops report."""
- return [
- "For guidance on supported operators, run: mlia operators "
- "--supported-ops-report",
- ]
-
def check_performance(self) -> list[str]:
"""Return command details for checking performance."""
model_path, device_opts = self._get_model_and_device_opts()
@@ -80,7 +76,7 @@ class CLIActionResolver(ActionResolver):
return [
"Check the estimated performance by running the following command: ",
- f"mlia performance{device_opts} {model_path}",
+ f"mlia check {model_path}{device_opts} --performance",
]
def check_operator_compatibility(self) -> list[str]:
@@ -91,16 +87,16 @@ class CLIActionResolver(ActionResolver):
return [
"Try running the following command to verify that:",
- f"mlia operators{device_opts} {model_path}",
+ f"mlia check {model_path}{device_opts}",
]
def operator_compatibility_details(self) -> list[str]:
"""Return command details for op compatibility."""
- return ["For more details, run: mlia operators --help"]
+ return ["For more details, run: mlia check --help"]
def optimization_details(self) -> list[str]:
"""Return command details for optimization."""
- return ["For more info, see: mlia optimization --help"]
+ return ["For more info, see: mlia optimize --help"]
def _get_model_and_device_opts(
self, separate_device_opts: bool = True
diff --git a/src/mlia/cli/main.py b/src/mlia/cli/main.py
index ac60308..1102d45 100644
--- a/src/mlia/cli/main.py
+++ b/src/mlia/cli/main.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""CLI main entry point."""
from __future__ import annotations
@@ -8,32 +8,28 @@ import logging
import sys
from functools import partial
from inspect import signature
-from pathlib import Path
from mlia import __version__
from mlia.backend.errors import BackendUnavailableError
from mlia.backend.registry import registry as backend_registry
-from mlia.cli.commands import all_tests
from mlia.cli.commands import backend_install
from mlia.cli.commands import backend_list
from mlia.cli.commands import backend_uninstall
-from mlia.cli.commands import operators
-from mlia.cli.commands import optimization
-from mlia.cli.commands import performance
+from mlia.cli.commands import check
+from mlia.cli.commands import optimize
from mlia.cli.common import CommandInfo
from mlia.cli.helpers import CLIActionResolver
from mlia.cli.logging import setup_logging
from mlia.cli.options import add_backend_install_options
+from mlia.cli.options import add_backend_options
from mlia.cli.options import add_backend_uninstall_options
-from mlia.cli.options import add_custom_supported_operators_options
+from mlia.cli.options import add_check_category_options
from mlia.cli.options import add_debug_options
-from mlia.cli.options import add_evaluation_options
from mlia.cli.options import add_keras_model_options
+from mlia.cli.options import add_model_options
from mlia.cli.options import add_multi_optimization_options
-from mlia.cli.options import add_optional_tflite_model_options
from mlia.cli.options import add_output_options
from mlia.cli.options import add_target_options
-from mlia.cli.options import add_tflite_model_options
from mlia.core.context import ExecutionContext
from mlia.core.errors import ConfigurationError
from mlia.core.errors import InternalError
@@ -60,50 +56,30 @@ def get_commands() -> list[CommandInfo]:
"""Return commands configuration."""
return [
CommandInfo(
- all_tests,
- ["all"],
- [
- add_target_options,
- add_keras_model_options,
- add_multi_optimization_options,
- add_output_options,
- add_debug_options,
- add_evaluation_options,
- ],
- True,
- ),
- CommandInfo(
- operators,
- ["ops"],
+ check,
+ [],
[
+ add_model_options,
add_target_options,
- add_optional_tflite_model_options,
+ add_backend_options,
+ add_check_category_options,
add_output_options,
- add_custom_supported_operators_options,
add_debug_options,
],
),
CommandInfo(
- performance,
- ["perf"],
- [
- partial(add_target_options, profiles_to_skip=["tosa", "cortex-a"]),
- add_tflite_model_options,
- add_output_options,
- add_debug_options,
- add_evaluation_options,
- ],
- ),
- CommandInfo(
- optimization,
- ["opt"],
+ optimize,
+ [],
[
- partial(add_target_options, profiles_to_skip=["tosa", "cortex-a"]),
add_keras_model_options,
+ partial(add_target_options, profiles_to_skip=["tosa", "cortex-a"]),
+ partial(
+ add_backend_options,
+ backends_to_skip=["tosa-checker", "armnn-tflitedelegate"],
+ ),
add_multi_optimization_options,
add_output_options,
add_debug_options,
- add_evaluation_options,
],
),
]
@@ -184,13 +160,12 @@ def setup_context(
) -> tuple[ExecutionContext, dict]:
"""Set up context and resolve function parameters."""
ctx = ExecutionContext(
- working_dir=args.working_dir,
- verbose="verbose" in args and args.verbose,
+ verbose="debug" in args and args.debug,
action_resolver=CLIActionResolver(vars(args)),
)
# these parameters should not be passed into command function
- skipped_params = ["func", "command", "working_dir", "verbose"]
+ skipped_params = ["func", "command", "debug"]
# pass these parameters only if command expects them
expected_params = [context_var_name]
@@ -219,6 +194,9 @@ def run_command(args: argparse.Namespace) -> int:
try:
logger.info(INFO_MESSAGE)
+ logger.info(
+ "\nThis execution of MLIA uses working directory: %s", ctx.working_dir
+ )
args.func(**func_args)
return 0
except KeyboardInterrupt:
@@ -246,22 +224,19 @@ def run_command(args: argparse.Namespace) -> int:
f"Please check the log files in the {ctx.logs_path} for more details"
)
if not ctx.verbose:
- err_advice_message += ", or enable verbose mode (--verbose)"
+ err_advice_message += ", or enable debug mode (--debug)"
logger.error(err_advice_message)
-
+ finally:
+ logger.info(
+ "This execution of MLIA used working directory: %s", ctx.working_dir
+ )
return 1
def init_common_parser() -> argparse.ArgumentParser:
"""Init common parser."""
parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
- parser.add_argument(
- "--working-dir",
- default=f"{Path.cwd() / 'mlia_output'}",
- help="Path to the directory where MLIA will store logs, "
- "models, etc. (default: %(default)s)",
- )
return parser
diff --git a/src/mlia/cli/options.py b/src/mlia/cli/options.py
index 8ea4250..bae6219 100644
--- a/src/mlia/cli/options.py
+++ b/src/mlia/cli/options.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for the CLI options."""
from __future__ import annotations
@@ -8,37 +8,48 @@ from pathlib import Path
from typing import Any
from typing import Callable
+from mlia.cli.config import DEFAULT_CLUSTERING_TARGET
+from mlia.cli.config import DEFAULT_PRUNING_TARGET
from mlia.cli.config import get_available_backends
-from mlia.cli.config import get_default_backends
from mlia.cli.config import is_corstone_backend
-from mlia.core.reporting import OUTPUT_FORMATS
+from mlia.core.common import FormattedFilePath
from mlia.utils.filesystem import get_supported_profile_names
-from mlia.utils.types import is_number
+
+
+def add_check_category_options(parser: argparse.ArgumentParser) -> None:
+ """Add check category type options."""
+ parser.add_argument(
+ "--performance", action="store_true", help="Perform performance checks."
+ )
+
+ parser.add_argument(
+ "--compatibility",
+ action="store_true",
+ help="Perform compatibility checks. (default)",
+ )
def add_target_options(
- parser: argparse.ArgumentParser, profiles_to_skip: list[str] | None = None
+ parser: argparse.ArgumentParser,
+ profiles_to_skip: list[str] | None = None,
+ required: bool = True,
) -> None:
"""Add target specific options."""
target_profiles = get_supported_profile_names()
if profiles_to_skip:
target_profiles = [tp for tp in target_profiles if tp not in profiles_to_skip]
- default_target_profile = None
- default_help = ""
- if target_profiles:
- default_target_profile = target_profiles[0]
- default_help = " (default: %(default)s)"
-
target_group = parser.add_argument_group("target options")
target_group.add_argument(
+ "-t",
"--target-profile",
choices=target_profiles,
- default=default_target_profile,
+ required=required,
+ default="",
help="Target profile that will set the target options "
"such as target, mac value, memory mode, etc. "
- f"For the values associated with each target profile "
- f" please refer to the documentation {default_help}.",
+ "For the values associated with each target profile "
+ "please refer to the documentation.",
)
@@ -47,59 +58,47 @@ def add_multi_optimization_options(parser: argparse.ArgumentParser) -> None:
multi_optimization_group = parser.add_argument_group("optimization options")
multi_optimization_group.add_argument(
- "--optimization-type",
- default="pruning,clustering",
- help="List of the optimization types separated by comma (default: %(default)s)",
+ "--pruning", action="store_true", help="Apply pruning optimization."
)
+
multi_optimization_group.add_argument(
- "--optimization-target",
- default="0.5,32",
- help="""List of the optimization targets separated by comma,
- (for pruning this is sparsity between (0,1),
- for clustering this is the number of clusters (positive integer))
- (default: %(default)s)""",
+ "--clustering", action="store_true", help="Apply clustering optimization."
)
+ multi_optimization_group.add_argument(
+ "--pruning-target",
+ type=float,
+ help="Sparsity to be reached during optimization "
+ f"(default: {DEFAULT_PRUNING_TARGET})",
+ )
-def add_optional_tflite_model_options(parser: argparse.ArgumentParser) -> None:
- """Add optional model specific options."""
- model_group = parser.add_argument_group("TensorFlow Lite model options")
- # make model parameter optional
- model_group.add_argument(
- "model", nargs="?", help="TensorFlow Lite model (optional)"
+ multi_optimization_group.add_argument(
+ "--clustering-target",
+ type=int,
+ help="Number of clusters to reach during optimization "
+ f"(default: {DEFAULT_CLUSTERING_TARGET})",
)
-def add_tflite_model_options(parser: argparse.ArgumentParser) -> None:
+def add_model_options(parser: argparse.ArgumentParser) -> None:
"""Add model specific options."""
- model_group = parser.add_argument_group("TensorFlow Lite model options")
- model_group.add_argument("model", help="TensorFlow Lite model")
+ parser.add_argument("model", help="TensorFlow Lite model or Keras model")
def add_output_options(parser: argparse.ArgumentParser) -> None:
"""Add output specific options."""
- valid_extensions = OUTPUT_FORMATS
-
- def check_extension(filename: str) -> str:
- """Check extension of the provided file."""
- suffix = Path(filename).suffix
- if suffix.startswith("."):
- suffix = suffix[1:]
-
- if suffix.lower() not in valid_extensions:
- parser.error(f"Unsupported format '{suffix}'")
-
- return filename
-
output_group = parser.add_argument_group("output options")
output_group.add_argument(
+ "-o",
"--output",
- type=check_extension,
- help=(
- "Name of the file where report will be saved. "
- "Report format is automatically detected based on the file extension. "
- f"Supported formats are: {', '.join(valid_extensions)}"
- ),
+ type=Path,
+ help=("Name of the file where the report will be saved."),
+ )
+
+ output_group.add_argument(
+ "--json",
+ action="store_true",
+ help=("Format to use for the output (requires --output argument to be set)."),
)
@@ -107,7 +106,11 @@ def add_debug_options(parser: argparse.ArgumentParser) -> None:
"""Add debug options."""
debug_group = parser.add_argument_group("debug options")
debug_group.add_argument(
- "--verbose", default=False, action="store_true", help="Produce verbose output"
+ "-d",
+ "--debug",
+ default=False,
+ action="store_true",
+ help="Produce verbose output",
)
@@ -117,20 +120,6 @@ def add_keras_model_options(parser: argparse.ArgumentParser) -> None:
model_group.add_argument("model", help="Keras model")
-def add_custom_supported_operators_options(parser: argparse.ArgumentParser) -> None:
- """Add custom options for the command 'operators'."""
- parser.add_argument(
- "--supported-ops-report",
- action="store_true",
- default=False,
- help=(
- "Generate the SUPPORTED_OPS.md file in the "
- "current working directory and exit "
- "(Ethos-U target profiles only)"
- ),
- )
-
-
def add_backend_install_options(parser: argparse.ArgumentParser) -> None:
"""Add options for the backends configuration."""
@@ -176,10 +165,11 @@ def add_backend_uninstall_options(parser: argparse.ArgumentParser) -> None:
)
-def add_evaluation_options(parser: argparse.ArgumentParser) -> None:
+def add_backend_options(
+ parser: argparse.ArgumentParser, backends_to_skip: list[str] | None = None
+) -> None:
"""Add evaluation options."""
available_backends = get_available_backends()
- default_backends = get_default_backends()
def only_one_corstone_checker() -> Callable:
"""
@@ -202,41 +192,70 @@ def add_evaluation_options(parser: argparse.ArgumentParser) -> None:
return check
- evaluation_group = parser.add_argument_group("evaluation options")
+ # Remove backends to skip
+ if backends_to_skip:
+ available_backends = [
+ x for x in available_backends if x not in backends_to_skip
+ ]
+
+ evaluation_group = parser.add_argument_group("backend options")
evaluation_group.add_argument(
- "--evaluate-on",
- help="Backends to use for evaluation (default: %(default)s)",
- nargs="*",
+ "-b",
+ "--backend",
+ help="Backends to use for evaluation.",
+ nargs="+",
choices=available_backends,
- default=default_backends,
type=only_one_corstone_checker(),
)
+def parse_output_parameters(path: Path | None, json: bool) -> FormattedFilePath | None:
+ """Parse and return path and file format as FormattedFilePath."""
+ if not path and json:
+ raise argparse.ArgumentError(
+ None,
+ "To enable JSON output you need to specify the output path. "
+ "(e.g. --output out.json --json)",
+ )
+ if not path:
+ return None
+ if json:
+ return FormattedFilePath(path, "json")
+
+ return FormattedFilePath(path, "plain_text")
+
+
def parse_optimization_parameters(
- optimization_type: str,
- optimization_target: str,
- sep: str = ",",
+ pruning: bool = False,
+ clustering: bool = False,
+ pruning_target: float | None = None,
+ clustering_target: int | None = None,
layers_to_optimize: list[str] | None = None,
) -> list[dict[str, Any]]:
"""Parse provided optimization parameters."""
- if not optimization_type:
- raise Exception("Optimization type is not provided")
+ opt_types = []
+ opt_targets = []
- if not optimization_target:
- raise Exception("Optimization target is not provided")
+ if clustering_target and not clustering:
+ raise argparse.ArgumentError(
+ None,
+ "To enable clustering optimization you need to include the "
+ "`--clustering` flag in your command.",
+ )
- opt_types = optimization_type.split(sep)
- opt_targets = optimization_target.split(sep)
+ if not pruning_target:
+ pruning_target = DEFAULT_PRUNING_TARGET
- if len(opt_types) != len(opt_targets):
- raise Exception("Wrong number of optimization targets and types")
+ if not clustering_target:
+ clustering_target = DEFAULT_CLUSTERING_TARGET
- non_numeric_targets = [
- opt_target for opt_target in opt_targets if not is_number(opt_target)
- ]
- if len(non_numeric_targets) > 0:
- raise Exception("Non numeric value for the optimization target")
+ if (pruning is False and clustering is False) or pruning:
+ opt_types.append("pruning")
+ opt_targets.append(pruning_target)
+
+ if clustering:
+ opt_types.append("clustering")
+ opt_targets.append(clustering_target)
optimizer_params = [
{
@@ -256,7 +275,7 @@ def get_target_profile_opts(device_args: dict | None) -> list[str]:
return []
parser = argparse.ArgumentParser()
- add_target_options(parser)
+ add_target_options(parser, required=False)
args = parser.parse_args([])
params_name = {
diff --git a/src/mlia/core/common.py b/src/mlia/core/common.py
index 6c9dde1..53df001 100644
--- a/src/mlia/core/common.py
+++ b/src/mlia/core/common.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Common module.
@@ -13,6 +13,9 @@ from enum import auto
from enum import Flag
from typing import Any
+from mlia.core.typing import OutputFormat
+from mlia.core.typing import PathOrFileLike
+
# This type is used as type alias for the items which are being passed around
# in advisor workflow. There are no restrictions on the type of the
# object. This alias used only to emphasize the nature of the input/output
@@ -20,31 +23,55 @@ from typing import Any
DataItem = Any
+class FormattedFilePath:
+ """Class used to keep track of the format that a path points to."""
+
+ def __init__(self, path: PathOrFileLike, fmt: OutputFormat = "plain_text") -> None:
+ """Init FormattedFilePath."""
+ self._path = path
+ self._fmt = fmt
+
+ @property
+ def fmt(self) -> OutputFormat:
+ """Return file format."""
+ return self._fmt
+
+ @property
+ def path(self) -> PathOrFileLike:
+ """Return file path."""
+ return self._path
+
+ def __eq__(self, other: object) -> bool:
+ """Check for equality with other objects."""
+ if isinstance(other, FormattedFilePath):
+ return other.fmt == self.fmt and other.path == self.path
+
+ return False
+
+ def __repr__(self) -> str:
+ """Represent object."""
+ return f"FormattedFilePath {self.path=}, {self.fmt=}"
+
+
class AdviceCategory(Flag):
"""Advice category.
Enumeration of advice categories supported by ML Inference Advisor.
"""
- OPERATORS = auto()
+ COMPATIBILITY = auto()
PERFORMANCE = auto()
OPTIMIZATION = auto()
- ALL = (
- # pylint: disable=unsupported-binary-operation
- OPERATORS
- | PERFORMANCE
- | OPTIMIZATION
- # pylint: enable=unsupported-binary-operation
- )
@classmethod
- def from_string(cls, value: str) -> AdviceCategory:
+ def from_string(cls, values: set[str]) -> set[AdviceCategory]:
"""Resolve enum value from string value."""
category_names = [item.name for item in AdviceCategory]
- if not value or value.upper() not in category_names:
- raise Exception(f"Invalid advice category {value}")
+ for advice_value in values:
+ if advice_value.upper() not in category_names:
+ raise Exception(f"Invalid advice category {advice_value}")
- return AdviceCategory[value.upper()]
+ return {AdviceCategory[value.upper()] for value in values}
class NamedEntity(ABC):
diff --git a/src/mlia/core/context.py b/src/mlia/core/context.py
index a4737bb..94aa885 100644
--- a/src/mlia/core/context.py
+++ b/src/mlia/core/context.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Context module.
@@ -10,6 +10,7 @@ parameters).
from __future__ import annotations
import logging
+import tempfile
from abc import ABC
from abc import abstractmethod
from pathlib import Path
@@ -54,7 +55,7 @@ class Context(ABC):
@property
@abstractmethod
- def advice_category(self) -> AdviceCategory:
+ def advice_category(self) -> set[AdviceCategory]:
"""Return advice category."""
@property
@@ -71,7 +72,7 @@ class Context(ABC):
def update(
self,
*,
- advice_category: AdviceCategory,
+ advice_category: set[AdviceCategory],
event_handlers: list[EventHandler],
config_parameters: Mapping[str, Any],
) -> None:
@@ -79,11 +80,11 @@ class Context(ABC):
def category_enabled(self, category: AdviceCategory) -> bool:
"""Check if category enabled."""
- return category == self.advice_category
+ return category in self.advice_category
def any_category_enabled(self, *categories: AdviceCategory) -> bool:
"""Return true if any category is enabled."""
- return self.advice_category in categories
+ return all(category in self.advice_category for category in categories)
def register_event_handlers(self) -> None:
"""Register event handlers."""
@@ -96,7 +97,7 @@ class ExecutionContext(Context):
def __init__(
self,
*,
- advice_category: AdviceCategory = AdviceCategory.ALL,
+ advice_category: set[AdviceCategory] = None,
config_parameters: Mapping[str, Any] | None = None,
working_dir: str | Path | None = None,
event_handlers: list[EventHandler] | None = None,
@@ -108,7 +109,7 @@ class ExecutionContext(Context):
) -> None:
"""Init execution context.
- :param advice_category: requested advice category
+ :param advice_category: requested advice categories
:param config_parameters: dictionary like object with input parameters
:param working_dir: path to the directory that will be used as a place
to store temporary files, logs, models. If not provided then
@@ -124,13 +125,13 @@ class ExecutionContext(Context):
:param action_resolver: instance of the action resolver that could make
advice actionable
"""
- self._advice_category = advice_category
+ self._advice_category = advice_category or {AdviceCategory.COMPATIBILITY}
self._config_parameters = config_parameters
- self._working_dir_path = Path.cwd()
if working_dir:
self._working_dir_path = Path(working_dir)
- self._working_dir_path.mkdir(exist_ok=True)
+ else:
+ self._working_dir_path = generate_temp_workdir()
self._event_handlers = event_handlers
self._event_publisher = event_publisher or DefaultEventPublisher()
@@ -140,12 +141,17 @@ class ExecutionContext(Context):
self._action_resolver = action_resolver or APIActionResolver()
@property
- def advice_category(self) -> AdviceCategory:
+ def working_dir(self) -> Path:
+ """Return working dir path."""
+ return self._working_dir_path
+
+ @property
+ def advice_category(self) -> set[AdviceCategory]:
"""Return advice category."""
return self._advice_category
@advice_category.setter
- def advice_category(self, advice_category: AdviceCategory) -> None:
+ def advice_category(self, advice_category: set[AdviceCategory]) -> None:
"""Setter for the advice category."""
self._advice_category = advice_category
@@ -194,7 +200,7 @@ class ExecutionContext(Context):
def update(
self,
*,
- advice_category: AdviceCategory,
+ advice_category: set[AdviceCategory],
event_handlers: list[EventHandler],
config_parameters: Mapping[str, Any],
) -> None:
@@ -206,7 +212,9 @@ class ExecutionContext(Context):
def __str__(self) -> str:
"""Return string representation."""
category = (
- "<not set>" if self.advice_category is None else self.advice_category.name
+ "<not set>"
+ if self.advice_category is None
+ else {x.name for x in self.advice_category}
)
return (
@@ -215,3 +223,9 @@ class ExecutionContext(Context):
f"config_parameters={self.config_parameters}, "
f"verbose={self.verbose}"
)
+
+
+def generate_temp_workdir() -> Path:
+ """Generate a temporary working dir and returns the path."""
+ working_dir = tempfile.mkdtemp(suffix=None, prefix="mlia-", dir=None)
+ return Path(working_dir)
diff --git a/src/mlia/core/handlers.py b/src/mlia/core/handlers.py
index a3255ae..6e50934 100644
--- a/src/mlia/core/handlers.py
+++ b/src/mlia/core/handlers.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Event handlers module."""
from __future__ import annotations
@@ -9,6 +9,7 @@ from typing import Callable
from mlia.core.advice_generation import Advice
from mlia.core.advice_generation import AdviceEvent
+from mlia.core.common import FormattedFilePath
from mlia.core.events import ActionFinishedEvent
from mlia.core.events import ActionStartedEvent
from mlia.core.events import AdviceStageFinishedEvent
@@ -26,7 +27,6 @@ from mlia.core.events import ExecutionFinishedEvent
from mlia.core.events import ExecutionStartedEvent
from mlia.core.reporting import Report
from mlia.core.reporting import Reporter
-from mlia.core.reporting import resolve_output_format
from mlia.core.typing import PathOrFileLike
from mlia.utils.console import create_section_header
@@ -101,12 +101,12 @@ class WorkflowEventsHandler(SystemEventsHandler):
def __init__(
self,
formatter_resolver: Callable[[Any], Callable[[Any], Report]],
- output: PathOrFileLike | None = None,
+ output: FormattedFilePath | None = None,
) -> None:
"""Init event handler."""
- output_format = resolve_output_format(output)
+ output_format = output.fmt if output else "plain_text"
self.reporter = Reporter(formatter_resolver, output_format)
- self.output = output
+ self.output = output.path if output else None
self.advice: list[Advice] = []
diff --git a/src/mlia/core/helpers.py b/src/mlia/core/helpers.py
index f4a9df6..ed43d04 100644
--- a/src/mlia/core/helpers.py
+++ b/src/mlia/core/helpers.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for various helper classes."""
# pylint: disable=unused-argument
@@ -14,10 +14,6 @@ class ActionResolver:
"""Return action details for applying optimizations."""
return []
- def supported_operators_info(self) -> list[str]:
- """Return action details for generating supported ops report."""
- return []
-
def check_performance(self) -> list[str]:
"""Return action details for checking performance."""
return []
diff --git a/src/mlia/core/reporting.py b/src/mlia/core/reporting.py
index b96a6b5..19644b2 100644
--- a/src/mlia/core/reporting.py
+++ b/src/mlia/core/reporting.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Reporting module."""
from __future__ import annotations
@@ -639,14 +639,3 @@ def _apply_format_parameters(
return report
return wrapper
-
-
-def resolve_output_format(output: PathOrFileLike | None) -> OutputFormat:
- """Resolve output format based on the output name."""
- if isinstance(output, (str, Path)):
- format_from_filename = Path(output).suffix.lstrip(".")
-
- if format_from_filename in OUTPUT_FORMATS:
- return cast(OutputFormat, format_from_filename)
-
- return "plain_text"
diff --git a/src/mlia/target/cortex_a/advice_generation.py b/src/mlia/target/cortex_a/advice_generation.py
index b68106e..98e8c06 100644
--- a/src/mlia/target/cortex_a/advice_generation.py
+++ b/src/mlia/target/cortex_a/advice_generation.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Cortex-A advice generation."""
from functools import singledispatchmethod
@@ -29,7 +29,7 @@ class CortexAAdviceProducer(FactBasedAdviceProducer):
"""Produce advice."""
@produce_advice.register
- @advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS)
+ @advice_category(AdviceCategory.COMPATIBILITY)
def handle_model_is_cortex_a_compatible(
self, data_item: ModelIsCortexACompatible
) -> None:
@@ -43,7 +43,7 @@ class CortexAAdviceProducer(FactBasedAdviceProducer):
)
@produce_advice.register
- @advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS)
+ @advice_category(AdviceCategory.COMPATIBILITY)
def handle_model_is_not_cortex_a_compatible(
self, data_item: ModelIsNotCortexACompatible
) -> None:
@@ -83,7 +83,7 @@ class CortexAAdviceProducer(FactBasedAdviceProducer):
)
@produce_advice.register
- @advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS)
+ @advice_category(AdviceCategory.COMPATIBILITY)
def handle_model_is_not_tflite_compatible(
self, data_item: ModelIsNotTFLiteCompatible
) -> None:
@@ -127,7 +127,7 @@ class CortexAAdviceProducer(FactBasedAdviceProducer):
)
@produce_advice.register
- @advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS)
+ @advice_category(AdviceCategory.COMPATIBILITY)
def handle_tflite_check_failed(
self, _data_item: TFLiteCompatibilityCheckFailed
) -> None:
@@ -140,7 +140,7 @@ class CortexAAdviceProducer(FactBasedAdviceProducer):
)
@produce_advice.register
- @advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS)
+ @advice_category(AdviceCategory.COMPATIBILITY)
def handle_model_has_custom_operators(
self, _data_item: ModelHasCustomOperators
) -> None:
diff --git a/src/mlia/target/cortex_a/advisor.py b/src/mlia/target/cortex_a/advisor.py
index 5912e38..b649f0d 100644
--- a/src/mlia/target/cortex_a/advisor.py
+++ b/src/mlia/target/cortex_a/advisor.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Cortex-A MLIA module."""
from __future__ import annotations
@@ -10,12 +10,12 @@ from mlia.core.advice_generation import AdviceProducer
from mlia.core.advisor import DefaultInferenceAdvisor
from mlia.core.advisor import InferenceAdvisor
from mlia.core.common import AdviceCategory
+from mlia.core.common import FormattedFilePath
from mlia.core.context import Context
from mlia.core.context import ExecutionContext
from mlia.core.data_analysis import DataAnalyzer
from mlia.core.data_collection import DataCollector
from mlia.core.events import Event
-from mlia.core.typing import PathOrFileLike
from mlia.target.cortex_a.advice_generation import CortexAAdviceProducer
from mlia.target.cortex_a.config import CortexAConfiguration
from mlia.target.cortex_a.data_analysis import CortexADataAnalyzer
@@ -38,7 +38,7 @@ class CortexAInferenceAdvisor(DefaultInferenceAdvisor):
collectors: list[DataCollector] = []
- if AdviceCategory.OPERATORS in context.advice_category:
+ if context.category_enabled(AdviceCategory.COMPATIBILITY):
collectors.append(CortexAOperatorCompatibility(model))
return collectors
@@ -67,7 +67,7 @@ def configure_and_get_cortexa_advisor(
context: ExecutionContext,
target_profile: str,
model: str | Path,
- output: PathOrFileLike | None = None,
+ output: FormattedFilePath | None = None,
**_extra_args: Any,
) -> InferenceAdvisor:
"""Create and configure Cortex-A advisor."""
diff --git a/src/mlia/target/cortex_a/handlers.py b/src/mlia/target/cortex_a/handlers.py
index b2d5faa..d6acde5 100644
--- a/src/mlia/target/cortex_a/handlers.py
+++ b/src/mlia/target/cortex_a/handlers.py
@@ -1,13 +1,13 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Event handler."""
from __future__ import annotations
import logging
+from mlia.core.common import FormattedFilePath
from mlia.core.events import CollectedDataEvent
from mlia.core.handlers import WorkflowEventsHandler
-from mlia.core.typing import PathOrFileLike
from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
from mlia.target.cortex_a.events import CortexAAdvisorEventHandler
from mlia.target.cortex_a.events import CortexAAdvisorStartedEvent
@@ -20,7 +20,7 @@ logger = logging.getLogger(__name__)
class CortexAEventHandler(WorkflowEventsHandler, CortexAAdvisorEventHandler):
"""CLI event handler."""
- def __init__(self, output: PathOrFileLike | None = None) -> None:
+ def __init__(self, output: FormattedFilePath | None = None) -> None:
"""Init event handler."""
super().__init__(cortex_a_formatters, output)
diff --git a/src/mlia/target/ethos_u/advice_generation.py b/src/mlia/target/ethos_u/advice_generation.py
index edd78fd..daae4f4 100644
--- a/src/mlia/target/ethos_u/advice_generation.py
+++ b/src/mlia/target/ethos_u/advice_generation.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Ethos-U advice generation."""
from __future__ import annotations
@@ -26,7 +26,7 @@ class EthosUAdviceProducer(FactBasedAdviceProducer):
"""Produce advice."""
@produce_advice.register
- @advice_category(AdviceCategory.OPERATORS, AdviceCategory.ALL)
+ @advice_category(AdviceCategory.COMPATIBILITY)
def handle_cpu_only_ops(self, data_item: HasCPUOnlyOperators) -> None:
"""Advice for CPU only operators."""
cpu_only_ops = ",".join(sorted(set(data_item.cpu_only_ops)))
@@ -40,11 +40,10 @@ class EthosUAdviceProducer(FactBasedAdviceProducer):
"Using operators that are supported by the NPU will "
"improve performance.",
]
- + self.context.action_resolver.supported_operators_info()
)
@produce_advice.register
- @advice_category(AdviceCategory.OPERATORS, AdviceCategory.ALL)
+ @advice_category(AdviceCategory.COMPATIBILITY)
def handle_unsupported_operators(
self, data_item: HasUnsupportedOnNPUOperators
) -> None:
@@ -60,21 +59,25 @@ class EthosUAdviceProducer(FactBasedAdviceProducer):
)
@produce_advice.register
- @advice_category(AdviceCategory.OPERATORS, AdviceCategory.ALL)
+ @advice_category(AdviceCategory.COMPATIBILITY)
def handle_all_operators_supported(
self, _data_item: AllOperatorsSupportedOnNPU
) -> None:
"""Advice if all operators supported."""
- self.add_advice(
- [
- "You don't have any unsupported operators, your model will "
- "run completely on NPU."
- ]
- + self.context.action_resolver.check_performance()
- )
+ advice = [
+ "You don't have any unsupported operators, your model will "
+ "run completely on NPU."
+ ]
+ if self.context.advice_category != (
+ AdviceCategory.COMPATIBILITY,
+ AdviceCategory.PERFORMANCE,
+ ):
+ advice += self.context.action_resolver.check_performance()
+
+ self.add_advice(advice)
@produce_advice.register
- @advice_category(AdviceCategory.OPTIMIZATION, AdviceCategory.ALL)
+ @advice_category(AdviceCategory.OPTIMIZATION)
def handle_optimization_results(self, data_item: OptimizationResults) -> None:
"""Advice based on optimization results."""
if not data_item.diffs or len(data_item.diffs) != 1:
@@ -202,5 +205,6 @@ class EthosUStaticAdviceProducer(ContextAwareAdviceProducer):
)
],
}
-
- return advice_per_category.get(self.context.advice_category, [])
+ if len(self.context.advice_category) == 1:
+ return advice_per_category.get(list(self.context.advice_category)[0], [])
+ return []
diff --git a/src/mlia/target/ethos_u/advisor.py b/src/mlia/target/ethos_u/advisor.py
index b9d64ff..640c3e1 100644
--- a/src/mlia/target/ethos_u/advisor.py
+++ b/src/mlia/target/ethos_u/advisor.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Ethos-U MLIA module."""
from __future__ import annotations
@@ -10,12 +10,12 @@ from mlia.core.advice_generation import AdviceProducer
from mlia.core.advisor import DefaultInferenceAdvisor
from mlia.core.advisor import InferenceAdvisor
from mlia.core.common import AdviceCategory
+from mlia.core.common import FormattedFilePath
from mlia.core.context import Context
from mlia.core.context import ExecutionContext
from mlia.core.data_analysis import DataAnalyzer
from mlia.core.data_collection import DataCollector
from mlia.core.events import Event
-from mlia.core.typing import PathOrFileLike
from mlia.nn.tensorflow.utils import is_tflite_model
from mlia.target.ethos_u.advice_generation import EthosUAdviceProducer
from mlia.target.ethos_u.advice_generation import EthosUStaticAdviceProducer
@@ -46,7 +46,7 @@ class EthosUInferenceAdvisor(DefaultInferenceAdvisor):
collectors: list[DataCollector] = []
- if AdviceCategory.OPERATORS in context.advice_category:
+ if context.category_enabled(AdviceCategory.COMPATIBILITY):
collectors.append(EthosUOperatorCompatibility(model, device))
# Performance and optimization are mutually exclusive.
@@ -57,18 +57,18 @@ class EthosUInferenceAdvisor(DefaultInferenceAdvisor):
raise Exception(
"Command 'optimization' is not supported for TensorFlow Lite files."
)
- if AdviceCategory.PERFORMANCE in context.advice_category:
+ if context.category_enabled(AdviceCategory.PERFORMANCE):
collectors.append(EthosUPerformance(model, device, backends))
else:
# Keras/SavedModel: Prefer optimization
- if AdviceCategory.OPTIMIZATION in context.advice_category:
+ if context.category_enabled(AdviceCategory.OPTIMIZATION):
optimization_settings = self._get_optimization_settings(context)
collectors.append(
EthosUOptimizationPerformance(
model, device, optimization_settings, backends
)
)
- elif AdviceCategory.PERFORMANCE in context.advice_category:
+ elif context.category_enabled(AdviceCategory.PERFORMANCE):
collectors.append(EthosUPerformance(model, device, backends))
return collectors
@@ -126,7 +126,7 @@ def configure_and_get_ethosu_advisor(
context: ExecutionContext,
target_profile: str,
model: str | Path,
- output: PathOrFileLike | None = None,
+ output: FormattedFilePath | None = None,
**extra_args: Any,
) -> InferenceAdvisor:
"""Create and configure Ethos-U advisor."""
diff --git a/src/mlia/target/ethos_u/handlers.py b/src/mlia/target/ethos_u/handlers.py
index 84a9554..91f6015 100644
--- a/src/mlia/target/ethos_u/handlers.py
+++ b/src/mlia/target/ethos_u/handlers.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Event handler."""
from __future__ import annotations
@@ -6,9 +6,9 @@ from __future__ import annotations
import logging
from mlia.backend.vela.compat import Operators
+from mlia.core.common import FormattedFilePath
from mlia.core.events import CollectedDataEvent
from mlia.core.handlers import WorkflowEventsHandler
-from mlia.core.typing import PathOrFileLike
from mlia.target.ethos_u.events import EthosUAdvisorEventHandler
from mlia.target.ethos_u.events import EthosUAdvisorStartedEvent
from mlia.target.ethos_u.performance import OptimizationPerformanceMetrics
@@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
class EthosUEventHandler(WorkflowEventsHandler, EthosUAdvisorEventHandler):
"""CLI event handler."""
- def __init__(self, output: PathOrFileLike | None = None) -> None:
+ def __init__(self, output: FormattedFilePath | None = None) -> None:
"""Init event handler."""
super().__init__(ethos_u_formatters, output)
diff --git a/src/mlia/target/tosa/advice_generation.py b/src/mlia/target/tosa/advice_generation.py
index f531b84..b8b9abf 100644
--- a/src/mlia/target/tosa/advice_generation.py
+++ b/src/mlia/target/tosa/advice_generation.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""TOSA advice generation."""
from functools import singledispatchmethod
@@ -19,7 +19,7 @@ class TOSAAdviceProducer(FactBasedAdviceProducer):
"""Produce advice."""
@produce_advice.register
- @advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS)
+ @advice_category(AdviceCategory.COMPATIBILITY)
def handle_model_is_tosa_compatible(
self, _data_item: ModelIsTOSACompatible
) -> None:
@@ -27,7 +27,7 @@ class TOSAAdviceProducer(FactBasedAdviceProducer):
self.add_advice(["Model is fully TOSA compatible."])
@produce_advice.register
- @advice_category(AdviceCategory.ALL, AdviceCategory.OPERATORS)
+ @advice_category(AdviceCategory.COMPATIBILITY)
def handle_model_is_not_tosa_compatible(
self, _data_item: ModelIsNotTOSACompatible
) -> None:
diff --git a/src/mlia/target/tosa/advisor.py b/src/mlia/target/tosa/advisor.py
index 2739dfd..4851113 100644
--- a/src/mlia/target/tosa/advisor.py
+++ b/src/mlia/target/tosa/advisor.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""TOSA advisor."""
from __future__ import annotations
@@ -10,12 +10,12 @@ from mlia.core.advice_generation import AdviceCategory
from mlia.core.advice_generation import AdviceProducer
from mlia.core.advisor import DefaultInferenceAdvisor
from mlia.core.advisor import InferenceAdvisor
+from mlia.core.common import FormattedFilePath
from mlia.core.context import Context
from mlia.core.context import ExecutionContext
from mlia.core.data_analysis import DataAnalyzer
from mlia.core.data_collection import DataCollector
from mlia.core.events import Event
-from mlia.core.typing import PathOrFileLike
from mlia.target.tosa.advice_generation import TOSAAdviceProducer
from mlia.target.tosa.config import TOSAConfiguration
from mlia.target.tosa.data_analysis import TOSADataAnalyzer
@@ -38,7 +38,7 @@ class TOSAInferenceAdvisor(DefaultInferenceAdvisor):
collectors: list[DataCollector] = []
- if AdviceCategory.OPERATORS in context.advice_category:
+ if context.category_enabled(AdviceCategory.COMPATIBILITY):
collectors.append(TOSAOperatorCompatibility(model))
return collectors
@@ -69,7 +69,7 @@ def configure_and_get_tosa_advisor(
context: ExecutionContext,
target_profile: str,
model: str | Path,
- output: PathOrFileLike | None = None,
+ output: FormattedFilePath | None = None,
**_extra_args: Any,
) -> InferenceAdvisor:
"""Create and configure TOSA advisor."""
diff --git a/src/mlia/target/tosa/handlers.py b/src/mlia/target/tosa/handlers.py
index 863558c..1037ba1 100644
--- a/src/mlia/target/tosa/handlers.py
+++ b/src/mlia/target/tosa/handlers.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""TOSA Advisor event handlers."""
# pylint: disable=R0801
@@ -7,9 +7,9 @@ from __future__ import annotations
import logging
from mlia.backend.tosa_checker.compat import TOSACompatibilityInfo
+from mlia.core.common import FormattedFilePath
from mlia.core.events import CollectedDataEvent
from mlia.core.handlers import WorkflowEventsHandler
-from mlia.core.typing import PathOrFileLike
from mlia.target.tosa.events import TOSAAdvisorEventHandler
from mlia.target.tosa.events import TOSAAdvisorStartedEvent
from mlia.target.tosa.reporters import tosa_formatters
@@ -20,7 +20,7 @@ logger = logging.getLogger(__name__)
class TOSAEventHandler(WorkflowEventsHandler, TOSAAdvisorEventHandler):
"""Event handler for TOSA advisor."""
- def __init__(self, output: PathOrFileLike | None = None) -> None:
+ def __init__(self, output: FormattedFilePath | None = None) -> None:
"""Init event handler."""
super().__init__(tosa_formatters, output)
diff --git a/src/mlia/utils/types.py b/src/mlia/utils/types.py
index ea067b8..0769968 100644
--- a/src/mlia/utils/types.py
+++ b/src/mlia/utils/types.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Types related utility functions."""
from __future__ import annotations
@@ -19,7 +19,7 @@ def is_number(value: str) -> bool:
"""Return true if string contains a number."""
try:
float(value)
- except ValueError:
+ except (ValueError, TypeError):
return False
return True