Diffstat (limited to 'src/mlia')
-rw-r--r--  src/mlia/__init__.py  22
-rw-r--r--  src/mlia/api.py  162
-rw-r--r--  src/mlia/cli/__init__.py  3
-rw-r--r--  src/mlia/cli/commands.py  276
-rw-r--r--  src/mlia/cli/common.py  38
-rw-r--r--  src/mlia/cli/config.py  64
-rw-r--r--  src/mlia/cli/helpers.py  116
-rw-r--r--  src/mlia/cli/logging.py  117
-rw-r--r--  src/mlia/cli/main.py  280
-rw-r--r--  src/mlia/cli/options.py  280
-rw-r--r--  src/mlia/core/__init__.py  21
-rw-r--r--  src/mlia/core/_typing.py  12
-rw-r--r--  src/mlia/core/advice_generation.py  106
-rw-r--r--  src/mlia/core/advisor.py  21
-rw-r--r--  src/mlia/core/common.py  47
-rw-r--r--  src/mlia/core/context.py  218
-rw-r--r--  src/mlia/core/data_analysis.py  70
-rw-r--r--  src/mlia/core/data_collection.py  37
-rw-r--r--  src/mlia/core/errors.py  18
-rw-r--r--  src/mlia/core/events.py  455
-rw-r--r--  src/mlia/core/helpers.py  38
-rw-r--r--  src/mlia/core/mixins.py  54
-rw-r--r--  src/mlia/core/performance.py  47
-rw-r--r--  src/mlia/core/reporting.py  762
-rw-r--r--  src/mlia/core/workflow.py  216
-rw-r--r--  src/mlia/devices/__init__.py  3
-rw-r--r--  src/mlia/devices/config.py  11
-rw-r--r--  src/mlia/devices/ethosu/__init__.py  3
-rw-r--r--  src/mlia/devices/ethosu/advice_generation.py  209
-rw-r--r--  src/mlia/devices/ethosu/advisor.py  151
-rw-r--r--  src/mlia/devices/ethosu/config.py  89
-rw-r--r--  src/mlia/devices/ethosu/data_analysis.py  154
-rw-r--r--  src/mlia/devices/ethosu/data_collection.py  188
-rw-r--r--  src/mlia/devices/ethosu/events.py  24
-rw-r--r--  src/mlia/devices/ethosu/handlers.py  146
-rw-r--r--  src/mlia/devices/ethosu/operators.py  14
-rw-r--r--  src/mlia/devices/ethosu/performance.py  257
-rw-r--r--  src/mlia/devices/ethosu/reporters.py  398
-rw-r--r--  src/mlia/nn/__init__.py  3
-rw-r--r--  src/mlia/nn/tensorflow/__init__.py  3
-rw-r--r--  src/mlia/nn/tensorflow/config.py  134
-rw-r--r--  src/mlia/nn/tensorflow/optimizations/__init__.py  3
-rw-r--r--  src/mlia/nn/tensorflow/optimizations/clustering.py  109
-rw-r--r--  src/mlia/nn/tensorflow/optimizations/common.py  29
-rw-r--r--  src/mlia/nn/tensorflow/optimizations/pruning.py  168
-rw-r--r--  src/mlia/nn/tensorflow/optimizations/select.py  179
-rw-r--r--  src/mlia/nn/tensorflow/tflite_metrics.py  296
-rw-r--r--  src/mlia/nn/tensorflow/utils.py  149
-rw-r--r--  src/mlia/resources/aiet/applications/APPLICATIONS.txt  6
-rw-r--r--  src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json  18
-rw-r--r--  src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json.license  3
-rw-r--r--  src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf  bin 0 -> 426496 bytes
-rw-r--r--  src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf.license  31
-rw-r--r--  src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json  15
-rw-r--r--  src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json.license  3
-rw-r--r--  src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf  bin 0 -> 426544 bytes
-rw-r--r--  src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf.license  31
-rw-r--r--  src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/aiet-config.json  15
-rw-r--r--  src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/aiet-config.json.license  3
-rw-r--r--  src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/ethos-u-inference_runner.axf  bin 0 -> 2524028 bytes
-rw-r--r--  src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/ethos-u-inference_runner.axf.license  31
-rw-r--r--  src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json  15
-rw-r--r--  src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json.license  3
-rw-r--r--  src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf  bin 0 -> 426488 bytes
-rw-r--r--  src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf.license  31
-rw-r--r--  src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json  15
-rw-r--r--  src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json.license  3
-rw-r--r--  src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf  bin 0 -> 426536 bytes
-rw-r--r--  src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf.license  31
-rw-r--r--  src/mlia/resources/aiet/systems/SYSTEMS.txt  10
-rw-r--r--  src/mlia/resources/aiet/systems/corstone-300-vht/aiet-config.json  80
-rw-r--r--  src/mlia/resources/aiet/systems/corstone-300-vht/aiet-config.json.license  3
-rw-r--r--  src/mlia/resources/aiet/systems/corstone-300/aiet-config.json  80
-rw-r--r--  src/mlia/resources/aiet/systems/corstone-300/aiet-config.json.license  3
-rw-r--r--  src/mlia/resources/aiet/systems/corstone-310-vht/aiet-config.json  42
-rw-r--r--  src/mlia/resources/aiet/systems/corstone-310-vht/aiet-config.json.license  3
-rw-r--r--  src/mlia/resources/aiet/systems/corstone-310/aiet-config.json  42
-rw-r--r--  src/mlia/resources/aiet/systems/corstone-310/aiet-config.json.license  3
-rw-r--r--  src/mlia/resources/profiles.json  20
-rw-r--r--  src/mlia/resources/profiles.json.license  3
-rw-r--r--  src/mlia/resources/vela/vela.ini  75
-rw-r--r--  src/mlia/tools/__init__.py  3
-rw-r--r--  src/mlia/tools/aiet_wrapper.py  435
-rw-r--r--  src/mlia/tools/metadata/__init__.py  3
-rw-r--r--  src/mlia/tools/metadata/common.py  290
-rw-r--r--  src/mlia/tools/metadata/corstone.py  402
-rw-r--r--  src/mlia/tools/vela_wrapper.py  500
-rw-r--r--  src/mlia/utils/__init__.py  3
-rw-r--r--  src/mlia/utils/console.py  97
-rw-r--r--  src/mlia/utils/download.py  89
-rw-r--r--  src/mlia/utils/filesystem.py  124
-rw-r--r--  src/mlia/utils/logging.py  120
-rw-r--r--  src/mlia/utils/misc.py  9
-rw-r--r--  src/mlia/utils/proc.py  164
-rw-r--r--  src/mlia/utils/types.py  37
95 files changed, 9094 insertions, 0 deletions
diff --git a/src/mlia/__init__.py b/src/mlia/__init__.py
new file mode 100644
index 0000000..ed9ae87
--- /dev/null
+++ b/src/mlia/__init__.py
@@ -0,0 +1,22 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Init of MLIA."""
+import logging
+import os
+
+import pkg_resources
+
+# redirect warnings to logging
+logging.captureWarnings(True)
+
+
+# as TensorFlow tries to configure root logger
+# it should be configured before importing TensorFlow
+root_logger = logging.getLogger()
+root_logger.addHandler(logging.NullHandler())
+
+
+# disable TensorFlow warning messages
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+
+__version__ = pkg_resources.get_distribution("mlia").version
diff --git a/src/mlia/api.py b/src/mlia/api.py
new file mode 100644
index 0000000..53ea4c8
--- /dev/null
+++ b/src/mlia/api.py
@@ -0,0 +1,162 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for the API functions."""
+import logging
+from pathlib import Path
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Literal
+from typing import Optional
+from typing import Union
+
+from mlia.core._typing import PathOrFileLike
+from mlia.core.advisor import InferenceAdvisor
+from mlia.core.common import AdviceCategory
+from mlia.core.context import ExecutionContext
+from mlia.core.events import EventHandler
+from mlia.devices.ethosu.advisor import EthosUInferenceAdvisor
+from mlia.devices.ethosu.handlers import EthosUEventHandler
+
+
+logger = logging.getLogger(__name__)
+
+
+_DEFAULT_OPTIMIZATION_TARGETS = [
+ {
+ "optimization_type": "pruning",
+ "optimization_target": 0.5,
+ "layers_to_optimize": None,
+ },
+ {
+ "optimization_type": "clustering",
+ "optimization_target": 32,
+ "layers_to_optimize": None,
+ },
+]
+
+
+def get_advice(
+ target_profile: str,
+ model: Union[Path, str],
+ category: Literal["all", "operators", "performance", "optimization"] = "all",
+ optimization_targets: Optional[List[Dict[str, Any]]] = None,
+ working_dir: Union[str, Path] = "mlia_output",
+ output: Optional[PathOrFileLike] = None,
+ context: Optional[ExecutionContext] = None,
+ backends: Optional[List[str]] = None,
+) -> None:
+ """Get the advice.
+
+ This function represents an entry point to the library API.
+
+ Based on provided parameters it will collect and analyze the data
+ and produce the advice.
+
+ :param target_profile: target profile identifier
+ :param model: path to the NN model
+ :param category: category of the advice. MLIA supports four categories:
+ "all", "operators", "performance", "optimization". If not provided
+ category "all" is used by default.
+ :param optimization_targets: optional model optimization targets that
+ could be used for generating advice in categories
+ "all" and "optimization."
+ :param working_dir: path to the directory that will be used for storing
+ intermediate files during execution (e.g. converted models)
+    :param output: path to the report file. If provided, MLIA will save
+        the report in this location. The report format is automatically
+        detected based on the file extension.
+ :param context: optional parameter which represents execution context,
+ could be used for advanced use cases
+ :param backends: A list of backends that should be used for the given
+ target. Default settings will be used if None.
+
+
+ Examples:
+ NB: Before launching MLIA, the logging functionality should be configured!
+
+ Getting the advice for the provided target profile and the model
+
+ >>> get_advice("ethos-u55-256", "path/to/the/model")
+
+ Getting the advice for the category "performance" and save result report in file
+ "report.json"
+
+ >>> get_advice("ethos-u55-256", "path/to/the/model", "performance",
+ output="report.json")
+
+ """
+ advice_category = AdviceCategory.from_string(category)
+ config_parameters = _get_config_parameters(
+ model, target_profile, backends, optimization_targets
+ )
+ event_handlers = _get_event_handlers(output)
+
+ if context is not None:
+ if context.advice_category is None:
+ context.advice_category = advice_category
+
+ if context.config_parameters is None:
+ context.config_parameters = config_parameters
+
+ if context.event_handlers is None:
+ context.event_handlers = event_handlers
+
+ if context is None:
+ context = ExecutionContext(
+ advice_category=advice_category,
+ working_dir=working_dir,
+ config_parameters=config_parameters,
+ event_handlers=event_handlers,
+ )
+
+ advisor = _get_advisor(target_profile)
+ advisor.run(context)
+
+
+def _get_advisor(target: Optional[str]) -> InferenceAdvisor:
+ """Find appropriate advisor for the target."""
+ if not target:
+ raise Exception("Target is not provided")
+
+ return EthosUInferenceAdvisor()
+
+
+def _get_config_parameters(
+ model: Union[Path, str],
+ target_profile: str,
+ backends: Optional[List[str]],
+ optimization_targets: Optional[List[Dict[str, Any]]],
+) -> Dict[str, Any]:
+ """Get configuration parameters for the advisor."""
+ advisor_parameters: Dict[str, Any] = {
+ "ethos_u_inference_advisor": {
+ "model": model,
+ "device": {
+ "target_profile": target_profile,
+ },
+ },
+ }
+ # Specifying backends is optional (default is used)
+ if backends is not None:
+ advisor_parameters["ethos_u_inference_advisor"]["backends"] = backends
+
+ if not optimization_targets:
+ optimization_targets = _DEFAULT_OPTIMIZATION_TARGETS
+
+ advisor_parameters.update(
+ {
+ "ethos_u_model_optimizations": {
+ "optimizations": [
+ optimization_targets,
+ ],
+ },
+ }
+ )
+
+ return advisor_parameters
+
+
+def _get_event_handlers(output: Optional[PathOrFileLike]) -> List[EventHandler]:
+ """Return list of the event handlers."""
+ return [EthosUEventHandler(output)]
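
As a quick orientation, here is a minimal usage sketch of the get_advice() entry point added above. It is only a sketch: it assumes the package is installed, that a "model.tflite" file exists and that the "ethos-u55-256" profile from profiles.json is available.

    from mlia.api import get_advice
    from mlia.cli.logging import setup_logging

    # logging should be configured before calling the library API
    setup_logging(verbose=False)

    # collect and analyze data for the "performance" category and save the
    # report as JSON (the format is detected from the file extension)
    get_advice(
        "ethos-u55-256",
        "model.tflite",
        category="performance",
        output="report.json",
    )
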
diff --git a/src/mlia/cli/__init__.py b/src/mlia/cli/__init__.py
new file mode 100644
index 0000000..f50778e
--- /dev/null
+++ b/src/mlia/cli/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""CLI module."""
diff --git a/src/mlia/cli/commands.py b/src/mlia/cli/commands.py
new file mode 100644
index 0000000..45c7c32
--- /dev/null
+++ b/src/mlia/cli/commands.py
@@ -0,0 +1,276 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""CLI commands module.
+
+This module contains functions which implement main app
+functionality.
+
+Before running them from scripts, the 'logging' module should
+be configured. Function 'setup_logging' from the module
+'mlia.cli.logging' can be used for that, e.g.
+
+>>> from mlia.api import ExecutionContext
+>>> from mlia.cli.logging import setup_logging
+>>> setup_logging(verbose=True)
+>>> import mlia.cli.commands as mlia
+>>> mlia.all_tests(ExecutionContext(working_dir="mlia_output"), "ethos-u55-256",
+ "path/to/model")
+"""
+import logging
+from pathlib import Path
+from typing import cast
+from typing import List
+from typing import Optional
+
+from mlia.api import ExecutionContext
+from mlia.api import get_advice
+from mlia.api import PathOrFileLike
+from mlia.cli.config import get_installation_manager
+from mlia.cli.options import parse_optimization_parameters
+from mlia.devices.ethosu.operators import generate_supported_operators_report
+from mlia.utils.console import create_section_header
+from mlia.utils.types import only_one_selected
+
+logger = logging.getLogger(__name__)
+
+CONFIG = create_section_header("ML Inference Advisor configuration")
+
+
+def all_tests(
+ ctx: ExecutionContext,
+ target_profile: str,
+ model: str,
+ optimization_type: str = "pruning,clustering",
+ optimization_target: str = "0.5,32",
+ output: Optional[PathOrFileLike] = None,
+ evaluate_on: Optional[List[str]] = None,
+) -> None:
+ """Generate a full report on the input model.
+
+ This command runs a series of tests in order to generate a
+ comprehensive report/advice:
+
+ - converts the input Keras model into TFLite format
+ - checks the model for operator compatibility on the specified device
+ - applies optimizations to the model and estimates the resulting performance
+ on both the original and the optimized models
+ - generates a final report on the steps above
+ - provides advice on how to (possibly) improve the inference performance
+
+ :param ctx: execution context
+ :param target_profile: target profile identifier. Will load appropriate parameters
+        from the profiles.json file based on this argument.
+ :param model: path to the Keras model
+ :param optimization_type: list of the optimization techniques separated
+ by comma, e.g. 'pruning,clustering'
+ :param optimization_target: list of the corresponding targets for
+ the provided optimization techniques, e.g. '0.5,32'
+ :param output: path to the file where the report will be saved
+ :param evaluate_on: list of the backends to use for evaluation
+
+ Example:
+ Run command for the target profile ethos-u55-256 with two model optimizations
+ and save report in json format locally in the file report.json
+
+ >>> from mlia.api import ExecutionContext
+ >>> from mlia.cli.logging import setup_logging
+ >>> setup_logging()
+ >>> from mlia.cli.commands import all_tests
+ >>> all_tests(ExecutionContext(working_dir="mlia_output"), "ethos-u55-256",
+ "model.h5", "pruning,clustering", "0.5,32",
+ output="report.json")
+ """
+ opt_params = parse_optimization_parameters(
+ optimization_type,
+ optimization_target,
+ )
+
+ get_advice(
+ target_profile,
+ model,
+ "all",
+ optimization_targets=opt_params,
+ output=output,
+ context=ctx,
+ backends=evaluate_on,
+ )
+
+
+def operators(
+ ctx: ExecutionContext,
+ target_profile: str,
+ model: Optional[str] = None,
+ output: Optional[PathOrFileLike] = None,
+ supported_ops_report: bool = False,
+) -> None:
+ """Print the model's operator list.
+
+ This command checks the operator compatibility of the input model with
+ the specific target profile. Generates a report of the operator placement
+ (NPU or CPU fallback) and advice on how to improve it (if necessary).
+
+ :param ctx: execution context
+ :param target_profile: target profile identifier. Will load appropriate parameters
+        from the profiles.json file based on this argument.
+ :param model: path to the model, which can be TFLite or Keras
+ :param output: path to the file where the report will be saved
+ :param supported_ops_report: if True then generates supported operators
+ report in current directory and exits
+
+ Example:
+ Run command for the target profile ethos-u55-256 and the provided
+ TFLite model and print report on the standard output
+
+ >>> from mlia.api import ExecutionContext
+ >>> from mlia.cli.logging import setup_logging
+ >>> setup_logging()
+ >>> from mlia.cli.commands import operators
+ >>> operators(ExecutionContext(working_dir="mlia_output"), "ethos-u55-256",
+ "model.tflite")
+ """
+ if supported_ops_report:
+ generate_supported_operators_report()
+ logger.info("Report saved into SUPPORTED_OPS.md")
+ return
+
+ if not model:
+ raise Exception("Model is not provided")
+
+ get_advice(
+ target_profile,
+ model,
+ "operators",
+ output=output,
+ context=ctx,
+ )
+
+
+def performance(
+ ctx: ExecutionContext,
+ target_profile: str,
+ model: str,
+ output: Optional[PathOrFileLike] = None,
+ evaluate_on: Optional[List[str]] = None,
+) -> None:
+ """Print the model's performance stats.
+
+ This command estimates the inference performance of the input model
+ on the specified target profile, and generates a report with advice on how
+ to improve it.
+
+ :param ctx: execution context
+ :param target_profile: target profile identifier. Will load appropriate parameters
+        from the profiles.json file based on this argument.
+ :param model: path to the model, which can be TFLite or Keras
+ :param output: path to the file where the report will be saved
+ :param evaluate_on: list of the backends to use for evaluation
+
+ Example:
+ Run command for the target profile ethos-u55-256 and
+ the provided TFLite model and print report on the standard output
+
+ >>> from mlia.api import ExecutionContext
+ >>> from mlia.cli.logging import setup_logging
+ >>> setup_logging()
+ >>> from mlia.cli.commands import performance
+ >>> performance(ExecutionContext(working_dir="mlia_output"), "ethos-u55-256",
+ "model.tflite")
+ """
+ get_advice(
+ target_profile,
+ model,
+ "performance",
+ output=output,
+ context=ctx,
+ backends=evaluate_on,
+ )
+
+
+def optimization(
+ ctx: ExecutionContext,
+ target_profile: str,
+ model: str,
+ optimization_type: str,
+ optimization_target: str,
+ layers_to_optimize: Optional[List[str]] = None,
+ output: Optional[PathOrFileLike] = None,
+ evaluate_on: Optional[List[str]] = None,
+) -> None:
+ """Show the performance improvements (if any) after applying the optimizations.
+
+ This command applies the selected optimization techniques (up to the
+ indicated targets) and generates a report with advice on how to improve
+ the inference performance (if possible).
+
+ :param ctx: execution context
+    :param target_profile: target profile identifier. Will load appropriate
+        parameters from the profiles.json file based on this argument.
+ :param model: path to the TFLite model
+ :param optimization_type: list of the optimization techniques separated
+ by comma, e.g. 'pruning,clustering'
+ :param optimization_target: list of the corresponding targets for
+ the provided optimization techniques, e.g. '0.5,32'
+ :param layers_to_optimize: list of the layers of the model which should be
+ optimized, if None then all layers are used
+ :param output: path to the file where the report will be saved
+ :param evaluate_on: list of the backends to use for evaluation
+
+ Example:
+ Run command for the target profile ethos-u55-256 and
+ the provided TFLite model and print report on the standard output
+
+        >>> from mlia.api import ExecutionContext
+        >>> from mlia.cli.logging import setup_logging
+        >>> setup_logging()
+        >>> from mlia.cli.commands import optimization
+        >>> optimization(ExecutionContext(working_dir="mlia_output"), "ethos-u55-256",
+                         "model.tflite", "pruning", "0.5")
+ """
+ opt_params = parse_optimization_parameters(
+ optimization_type,
+ optimization_target,
+ layers_to_optimize=layers_to_optimize,
+ )
+
+ get_advice(
+ target_profile,
+ model,
+ "optimization",
+ optimization_targets=opt_params,
+ output=output,
+ context=ctx,
+ backends=evaluate_on,
+ )
+
+
+def backend(
+ backend_action: str,
+ path: Optional[Path] = None,
+ download: bool = False,
+ name: Optional[str] = None,
+ i_agree_to_the_contained_eula: bool = False,
+ noninteractive: bool = False,
+) -> None:
+ """Backends configuration."""
+ logger.info(CONFIG)
+
+ manager = get_installation_manager(noninteractive)
+
+ if backend_action == "status":
+ manager.show_env_details()
+
+ if backend_action == "install":
+ install_from_path = path is not None
+
+ if not only_one_selected(install_from_path, download):
+ raise Exception(
+ "Please select only one action: download or "
+ "provide path to the backend installation"
+ )
+
+ if install_from_path:
+ manager.install_from(cast(Path, path), name)
+
+ if download:
+ eula_agreement = not i_agree_to_the_contained_eula
+ manager.download_and_install(name, eula_agreement)
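
The backend command above has no doctest example; a small sketch of driving it programmatically follows. The backend name "Corstone-300" is an assumption based on the systems shipped under src/mlia/resources/aiet/systems.

    from mlia.cli.commands import backend

    # print the installation status of the known backends
    backend("status")

    # download and install a backend without interactive prompts
    backend("install", download=True, name="Corstone-300", noninteractive=True)
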
diff --git a/src/mlia/cli/common.py b/src/mlia/cli/common.py
new file mode 100644
index 0000000..54bd457
--- /dev/null
+++ b/src/mlia/cli/common.py
@@ -0,0 +1,38 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""CLI common module."""
+import argparse
+from dataclasses import dataclass
+from typing import Callable
+from typing import List
+
+
+@dataclass
+class CommandInfo:
+ """Command description."""
+
+ func: Callable
+ aliases: List[str]
+ opt_groups: List[Callable[[argparse.ArgumentParser], None]]
+ is_default: bool = False
+
+ @property
+ def command_name(self) -> str:
+ """Return command name."""
+ return self.func.__name__
+
+ @property
+ def command_name_and_aliases(self) -> List[str]:
+ """Return list of command name and aliases."""
+ return [self.command_name, *self.aliases]
+
+ @property
+ def command_help(self) -> str:
+ """Return help message for the command."""
+ assert self.func.__doc__, "Command function does not have a docstring"
+ func_help = self.func.__doc__.splitlines()[0].rstrip(".")
+
+ if self.is_default:
+ func_help = f"{func_help} [default]"
+
+ return func_help
diff --git a/src/mlia/cli/config.py b/src/mlia/cli/config.py
new file mode 100644
index 0000000..838b051
--- /dev/null
+++ b/src/mlia/cli/config.py
@@ -0,0 +1,64 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Environment configuration functions."""
+import logging
+from functools import lru_cache
+from typing import List
+
+import mlia.tools.aiet_wrapper as aiet
+from mlia.tools.metadata.common import DefaultInstallationManager
+from mlia.tools.metadata.common import InstallationManager
+from mlia.tools.metadata.corstone import get_corstone_installations
+
+logger = logging.getLogger(__name__)
+
+
+def get_installation_manager(noninteractive: bool = False) -> InstallationManager:
+ """Return installation manager."""
+ backends = get_corstone_installations()
+
+ return DefaultInstallationManager(backends, noninteractive=noninteractive)
+
+
+@lru_cache
+def get_available_backends() -> List[str]:
+ """Return list of the available backends."""
+ available_backends = ["Vela"]
+
+ # Add backends using AIET
+ manager = get_installation_manager()
+ available_backends.extend(
+ (
+ backend
+ for backend in aiet.supported_backends()
+ if manager.backend_installed(backend)
+ )
+ )
+
+ return available_backends
+
+
+# List of mutually exclusive Corstone backends ordered by priority
+_CORSTONE_EXCLUSIVE_PRIORITY = ("Corstone-310", "Corstone-300")
+
+
+def get_default_backends() -> List[str]:
+ """Get default backends for evaluation."""
+ backends = get_available_backends()
+
+ # Filter backends to only include one Corstone backend
+ for corstone in _CORSTONE_EXCLUSIVE_PRIORITY:
+ if corstone in backends:
+ backends = [
+ backend
+ for backend in backends
+ if backend == corstone or backend not in _CORSTONE_EXCLUSIVE_PRIORITY
+ ]
+ break
+
+ return backends
+
+
+def is_corstone_backend(backend: str) -> bool:
+ """Check if the given backend is a Corstone backend."""
+ return backend in _CORSTONE_EXCLUSIVE_PRIORITY
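
A quick illustration of the backend selection above; the printed values are hypothetical and depend on which backends are actually installed on the machine.

    from mlia.cli.config import get_default_backends
    from mlia.cli.config import is_corstone_backend

    # with Vela, Corstone-300 and Corstone-310 all available, only the
    # highest-priority Corstone backend is kept for evaluation
    print(get_default_backends())        # e.g. ['Vela', 'Corstone-310']
    print(is_corstone_backend("Vela"))   # False
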
diff --git a/src/mlia/cli/helpers.py b/src/mlia/cli/helpers.py
new file mode 100644
index 0000000..81d5a15
--- /dev/null
+++ b/src/mlia/cli/helpers.py
@@ -0,0 +1,116 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for various helper classes."""
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+from mlia.cli.options import get_target_profile_opts
+from mlia.core.helpers import ActionResolver
+from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+from mlia.nn.tensorflow.utils import is_keras_model
+from mlia.utils.types import is_list_of
+
+
+class CLIActionResolver(ActionResolver):
+ """Helper class for generating cli commands."""
+
+ def __init__(self, args: Dict[str, Any]) -> None:
+ """Init action resolver."""
+ self.args = args
+
+ @staticmethod
+ def _general_optimization_command(model_path: Optional[str]) -> List[str]:
+ """Return general optimization command description."""
+ keras_note = []
+ if model_path is None or not is_keras_model(model_path):
+ model_path = "/path/to/keras_model"
+ keras_note = ["Note: you will need a Keras model for that."]
+
+ return [
+ *keras_note,
+ "For example: mlia optimization --optimization-type "
+ f"pruning,clustering --optimization-target 0.5,32 {model_path}",
+ "For more info: mlia optimization --help",
+ ]
+
+ @staticmethod
+ def _specific_optimization_command(
+ model_path: str,
+ device_opts: str,
+ opt_settings: List[OptimizationSettings],
+ ) -> List[str]:
+ """Return specific optimization command description."""
+ opt_types = ",".join(opt.optimization_type for opt in opt_settings)
+ opt_targs = ",".join(str(opt.optimization_target) for opt in opt_settings)
+
+ return [
+ "For more info: mlia optimization --help",
+ "Optimization command: "
+ f"mlia optimization --optimization-type {opt_types} "
+ f"--optimization-target {opt_targs}{device_opts} {model_path}",
+ ]
+
+ def apply_optimizations(self, **kwargs: Any) -> List[str]:
+ """Return command details for applying optimizations."""
+ model_path, device_opts = self._get_model_and_device_opts()
+
+ if (opt_settings := kwargs.pop("opt_settings", None)) is None:
+ return self._general_optimization_command(model_path)
+
+ if is_list_of(opt_settings, OptimizationSettings) and model_path:
+ return self._specific_optimization_command(
+ model_path, device_opts, opt_settings
+ )
+
+ return []
+
+ def supported_operators_info(self) -> List[str]:
+ """Return command details for generating supported ops report."""
+ return [
+ "For guidance on supported operators, run: mlia operators "
+ "--supported-ops-report",
+ ]
+
+ def check_performance(self) -> List[str]:
+ """Return command details for checking performance."""
+ model_path, device_opts = self._get_model_and_device_opts()
+ if not model_path:
+ return []
+
+ return [
+ "Check the estimated performance by running the following command: ",
+ f"mlia performance{device_opts} {model_path}",
+ ]
+
+ def check_operator_compatibility(self) -> List[str]:
+ """Return command details for op compatibility."""
+ model_path, device_opts = self._get_model_and_device_opts()
+ if not model_path:
+ return []
+
+ return [
+ "Try running the following command to verify that:",
+ f"mlia operators{device_opts} {model_path}",
+ ]
+
+ def operator_compatibility_details(self) -> List[str]:
+ """Return command details for op compatibility."""
+ return ["For more details, run: mlia operators --help"]
+
+ def optimization_details(self) -> List[str]:
+ """Return command details for optimization."""
+ return ["For more info, see: mlia optimization --help"]
+
+ def _get_model_and_device_opts(
+ self, separate_device_opts: bool = True
+ ) -> Tuple[Optional[str], str]:
+ """Get model and device options."""
+ device_opts = " ".join(get_target_profile_opts(self.args))
+ if separate_device_opts and device_opts:
+ device_opts = f" {device_opts}"
+
+ model_path = self.args.get("model")
+ return model_path, device_opts
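
A small sketch of how CLIActionResolver turns parsed CLI arguments back into follow-up command hints; the dictionary below mirrors vars() of the parsed argparse namespace and its values are illustrative.

    from mlia.cli.helpers import CLIActionResolver

    resolver = CLIActionResolver(
        {"model": "model.tflite", "target_profile": "ethos-u55-256"}
    )

    # advice producers can embed lines like these into the generated advice
    for line in resolver.check_performance():
        print(line)
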
diff --git a/src/mlia/cli/logging.py b/src/mlia/cli/logging.py
new file mode 100644
index 0000000..c5fc7bd
--- /dev/null
+++ b/src/mlia/cli/logging.py
@@ -0,0 +1,117 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""CLI logging configuration."""
+import logging
+import sys
+from pathlib import Path
+from typing import List
+from typing import Optional
+from typing import Union
+
+from mlia.utils.logging import attach_handlers
+from mlia.utils.logging import create_log_handler
+from mlia.utils.logging import LogFilter
+
+
+_CONSOLE_DEBUG_FORMAT = "%(name)s - %(message)s"
+_FILE_DEBUG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+
+
+def setup_logging(
+ logs_dir: Optional[Union[str, Path]] = None,
+ verbose: bool = False,
+ log_filename: str = "mlia.log",
+) -> None:
+ """Set up logging.
+
+ MLIA uses module 'logging' when it needs to produce output.
+
+ :param logs_dir: path to the directory where application will save logs with
+ debug information. If the path is not provided then no log files will
+ be created during execution
+ :param verbose: enable extended logging for the tools loggers
+ :param log_filename: name of the log file in the logs directory
+ """
+ mlia_logger, *tools_loggers = [
+ logging.getLogger(logger_name)
+ for logger_name in ["mlia", "tensorflow", "py.warnings"]
+ ]
+
+ # enable debug output, actual message filtering depends on
+ # the provided parameters and being done on the handlers level
+ mlia_logger.setLevel(logging.DEBUG)
+
+ mlia_handlers = _get_mlia_handlers(logs_dir, log_filename, verbose)
+ attach_handlers(mlia_handlers, [mlia_logger])
+
+ tools_handlers = _get_tools_handlers(logs_dir, log_filename, verbose)
+ attach_handlers(tools_handlers, tools_loggers)
+
+
+def _get_mlia_handlers(
+ logs_dir: Optional[Union[str, Path]],
+ log_filename: str,
+ verbose: bool,
+) -> List[logging.Handler]:
+ """Get handlers for the MLIA loggers."""
+ result = []
+ stdout_handler = create_log_handler(
+ stream=sys.stdout,
+ log_level=logging.INFO,
+ )
+ result.append(stdout_handler)
+
+ if verbose:
+ mlia_verbose_handler = create_log_handler(
+ stream=sys.stdout,
+ log_level=logging.DEBUG,
+ log_format=_CONSOLE_DEBUG_FORMAT,
+ log_filter=LogFilter.equals(logging.DEBUG),
+ )
+ result.append(mlia_verbose_handler)
+
+ if logs_dir:
+ mlia_file_handler = create_log_handler(
+ file_path=_get_log_file(logs_dir, log_filename),
+ log_level=logging.DEBUG,
+ log_format=_FILE_DEBUG_FORMAT,
+ log_filter=LogFilter.skip(logging.INFO),
+ delay=True,
+ )
+ result.append(mlia_file_handler)
+
+ return result
+
+
+def _get_tools_handlers(
+ logs_dir: Optional[Union[str, Path]],
+ log_filename: str,
+ verbose: bool,
+) -> List[logging.Handler]:
+ """Get handler for the tools loggers."""
+ result = []
+ if verbose:
+ verbose_stdout_handler = create_log_handler(
+ stream=sys.stdout,
+ log_level=logging.DEBUG,
+ log_format=_CONSOLE_DEBUG_FORMAT,
+ )
+ result.append(verbose_stdout_handler)
+
+ if logs_dir:
+ file_handler = create_log_handler(
+ file_path=_get_log_file(logs_dir, log_filename),
+ log_level=logging.DEBUG,
+ log_format=_FILE_DEBUG_FORMAT,
+ delay=True,
+ )
+ result.append(file_handler)
+
+ return result
+
+
+def _get_log_file(logs_dir: Union[str, Path], log_filename: str) -> Path:
+ """Get the log file path."""
+ logs_dir_path = Path(logs_dir)
+ logs_dir_path.mkdir(exist_ok=True)
+ return logs_dir_path / log_filename
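
A minimal sketch of how the logging setup above is typically used; the directory name is illustrative.

    from mlia.cli.logging import setup_logging

    # console output at INFO level plus a DEBUG log file at logs/mlia.log
    setup_logging(logs_dir="logs", verbose=False)

    # with verbose=True, DEBUG records from the mlia, tensorflow and
    # py.warnings loggers are also echoed to stdout
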
diff --git a/src/mlia/cli/main.py b/src/mlia/cli/main.py
new file mode 100644
index 0000000..33fcdeb
--- /dev/null
+++ b/src/mlia/cli/main.py
@@ -0,0 +1,280 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""CLI main entry point."""
+import argparse
+import logging
+import sys
+from inspect import signature
+from pathlib import Path
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+from mlia import __version__
+from mlia.cli.commands import all_tests
+from mlia.cli.commands import backend
+from mlia.cli.commands import operators
+from mlia.cli.commands import optimization
+from mlia.cli.commands import performance
+from mlia.cli.common import CommandInfo
+from mlia.cli.helpers import CLIActionResolver
+from mlia.cli.logging import setup_logging
+from mlia.cli.options import add_backend_options
+from mlia.cli.options import add_custom_supported_operators_options
+from mlia.cli.options import add_debug_options
+from mlia.cli.options import add_evaluation_options
+from mlia.cli.options import add_keras_model_options
+from mlia.cli.options import add_multi_optimization_options
+from mlia.cli.options import add_optional_tflite_model_options
+from mlia.cli.options import add_output_options
+from mlia.cli.options import add_target_options
+from mlia.cli.options import add_tflite_model_options
+from mlia.core.context import ExecutionContext
+
+
+logger = logging.getLogger(__name__)
+
+INFO_MESSAGE = f"""
+ML Inference Advisor {__version__}
+
+Help the design and optimization of neural network models for efficient inference on a target CPU, GPU and NPU
+
+Supported targets:
+
+ - Ethos-U55 <op compatibility, perf estimation, model opt>
+ - Ethos-U65 <op compatibility, perf estimation, model opt>
+
+""".strip()
+
+
+def get_commands() -> List[CommandInfo]:
+ """Return commands configuration."""
+ return [
+ CommandInfo(
+ all_tests,
+ ["all"],
+ [
+ add_target_options,
+ add_keras_model_options,
+ add_multi_optimization_options,
+ add_output_options,
+ add_debug_options,
+ add_evaluation_options,
+ ],
+ True,
+ ),
+ CommandInfo(
+ operators,
+ ["ops"],
+ [
+ add_target_options,
+ add_optional_tflite_model_options,
+ add_output_options,
+ add_custom_supported_operators_options,
+ add_debug_options,
+ ],
+ ),
+ CommandInfo(
+ performance,
+ ["perf"],
+ [
+ add_target_options,
+ add_tflite_model_options,
+ add_output_options,
+ add_debug_options,
+ add_evaluation_options,
+ ],
+ ),
+ CommandInfo(
+ optimization,
+ ["opt"],
+ [
+ add_target_options,
+ add_keras_model_options,
+ add_multi_optimization_options,
+ add_output_options,
+ add_debug_options,
+ add_evaluation_options,
+ ],
+ ),
+ CommandInfo(
+ backend,
+ [],
+ [
+ add_backend_options,
+ add_debug_options,
+ ],
+ ),
+ ]
+
+
+def get_default_command() -> Optional[str]:
+ """Get name of the default command."""
+ commands = get_commands()
+
+ marked_as_default = [cmd.command_name for cmd in commands if cmd.is_default]
+    assert len(marked_as_default) <= 1, "Only one command can be marked as default"
+
+ return next(iter(marked_as_default), None)
+
+
+def get_possible_command_names() -> List[str]:
+ """Get all possible command names including aliases."""
+ return [
+ name_or_alias
+ for cmd in get_commands()
+ for name_or_alias in cmd.command_name_and_aliases
+ ]
+
+
+def init_commands(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
+ """Init cli subcommands."""
+ subparsers = parser.add_subparsers(title="Commands", dest="command")
+ subparsers.required = True
+
+ for command in get_commands():
+ command_parser = subparsers.add_parser(
+ command.command_name,
+ aliases=command.aliases,
+ help=command.command_help,
+ allow_abbrev=False,
+ )
+ command_parser.set_defaults(func=command.func)
+ for opt_group in command.opt_groups:
+ opt_group(command_parser)
+
+ return parser
+
+
+def setup_context(
+ args: argparse.Namespace, context_var_name: str = "ctx"
+) -> Tuple[ExecutionContext, Dict]:
+ """Set up context and resolve function parameters."""
+ ctx = ExecutionContext(
+ working_dir=args.working_dir,
+ verbose="verbose" in args and args.verbose,
+ action_resolver=CLIActionResolver(vars(args)),
+ )
+
+ # these parameters should not be passed into command function
+ skipped_params = ["func", "command", "working_dir", "verbose"]
+
+ # pass these parameters only if command expects them
+ expected_params = [context_var_name]
+ func_params = signature(args.func).parameters
+
+ params = {context_var_name: ctx, **vars(args)}
+
+ func_args = {
+ param_name: param_value
+ for param_name, param_value in params.items()
+ if param_name not in skipped_params
+ and (param_name not in expected_params or param_name in func_params)
+ }
+
+ return (ctx, func_args)
+
+
+def run_command(args: argparse.Namespace) -> int:
+ """Run command."""
+ ctx, func_args = setup_context(args)
+ setup_logging(ctx.logs_path, ctx.verbose)
+
+ logger.debug(
+ "*** This is the beginning of the command '%s' execution ***", args.command
+ )
+
+ try:
+ logger.info(INFO_MESSAGE)
+
+ args.func(**func_args)
+ return 0
+ except KeyboardInterrupt:
+ logger.error("Execution has been interrupted")
+ except Exception as err: # pylint: disable=broad-except
+ logger.error(
+ "\nExecution finished with error: %s",
+ err,
+ exc_info=err if ctx.verbose else None,
+ )
+
+ err_advice_message = (
+ f"Please check the log files in the {ctx.logs_path} for more details"
+ )
+ if not ctx.verbose:
+ err_advice_message += ", or enable verbose mode"
+
+ logger.error(err_advice_message)
+
+ return 1
+
+
+def init_common_parser() -> argparse.ArgumentParser:
+ """Init common parser."""
+ parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
+ parser.add_argument(
+ "--working-dir",
+ default=f"{Path.cwd() / 'mlia_output'}",
+ help="Path to the directory where MLIA will store logs, "
+ "models, etc. (default: %(default)s)",
+ )
+
+ return parser
+
+
+def init_subcommand_parser(parent: argparse.ArgumentParser) -> argparse.ArgumentParser:
+ """Init subcommand parser."""
+ parser = argparse.ArgumentParser(
+ description=INFO_MESSAGE,
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ parents=[parent],
+ add_help=False,
+ allow_abbrev=False,
+ )
+ parser.add_argument(
+ "-h",
+ "--help",
+ action="help",
+ default=argparse.SUPPRESS,
+ help="Show this help message and exit",
+ )
+ parser.add_argument(
+ "-v",
+ "--version",
+ action="version",
+ version=f"%(prog)s {__version__}",
+ help="Show program's version number and exit",
+ )
+
+ return parser
+
+
+def add_default_command_if_needed(args: List[str]) -> None:
+ """Add default command to the list of the arguments if needed."""
+ default_command = get_default_command()
+
+ if default_command and len(args) > 0:
+ commands = get_possible_command_names()
+ help_or_version = ["-h", "--help", "-v", "--version"]
+
+ command_is_missing = args[0] not in [*commands, *help_or_version]
+ if command_is_missing:
+ args.insert(0, default_command)
+
+
+def main(argv: Optional[List[str]] = None) -> int:
+ """Entry point of the application."""
+ common_parser = init_common_parser()
+ subcommand_parser = init_subcommand_parser(common_parser)
+ init_commands(subcommand_parser)
+
+ common_args, subcommand_args = common_parser.parse_known_args(argv)
+ add_default_command_if_needed(subcommand_args)
+
+ args = subcommand_parser.parse_args(subcommand_args, common_args)
+ return run_command(args)
+
+
+if __name__ == "__main__":
+ sys.exit(main())
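
Because all_tests is registered as the default command, the two invocations below are treated identically once add_default_command_if_needed() has run; this is a sketch only and the model path is illustrative.

    from mlia.cli.main import main

    # explicit command
    main(["all", "--target-profile", "ethos-u55-256", "model.h5"])

    # no command given: the default command "all" is inserted automatically
    main(["--target-profile", "ethos-u55-256", "model.h5"])
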
diff --git a/src/mlia/cli/options.py b/src/mlia/cli/options.py
new file mode 100644
index 0000000..dc5cb73
--- /dev/null
+++ b/src/mlia/cli/options.py
@@ -0,0 +1,280 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for the CLI options."""
+import argparse
+from pathlib import Path
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Optional
+
+from mlia.cli.config import get_available_backends
+from mlia.cli.config import get_default_backends
+from mlia.cli.config import is_corstone_backend
+from mlia.utils.filesystem import get_supported_profile_names
+from mlia.utils.types import is_number
+
+
+def add_target_options(parser: argparse.ArgumentParser) -> None:
+ """Add target specific options."""
+ target_profiles = get_supported_profile_names()
+
+ default_target_profile = None
+ default_help = ""
+ if target_profiles:
+ default_target_profile = target_profiles[0]
+ default_help = " (default: %(default)s)"
+
+ target_group = parser.add_argument_group("target options")
+ target_group.add_argument(
+ "--target-profile",
+ choices=target_profiles,
+ default=default_target_profile,
+ help="Target profile that will set the target options "
+ "such as target, mac value, memory mode, etc. "
+ f"For the values associated with each target profile "
+ f" please refer to the documenation {default_help}.",
+ )
+
+
+def add_multi_optimization_options(parser: argparse.ArgumentParser) -> None:
+ """Add optimization specific options."""
+ multi_optimization_group = parser.add_argument_group("optimization options")
+
+ multi_optimization_group.add_argument(
+ "--optimization-type",
+ default="pruning,clustering",
+ help="List of the optimization types separated by comma (default: %(default)s)",
+ )
+ multi_optimization_group.add_argument(
+ "--optimization-target",
+ default="0.5,32",
+ help="""List of the optimization targets separated by comma,
+ (for pruning this is sparsity between (0,1),
+ for clustering this is the number of clusters (positive integer))
+ (default: %(default)s)""",
+ )
+
+
+def add_optional_tflite_model_options(parser: argparse.ArgumentParser) -> None:
+ """Add optional model specific options."""
+ model_group = parser.add_argument_group("TFLite model options")
+ # make model parameter optional
+ model_group.add_argument("model", nargs="?", help="TFLite model (optional)")
+
+
+def add_tflite_model_options(parser: argparse.ArgumentParser) -> None:
+ """Add model specific options."""
+ model_group = parser.add_argument_group("TFLite model options")
+ model_group.add_argument("model", help="TFLite model")
+
+
+def add_output_options(parser: argparse.ArgumentParser) -> None:
+ """Add output specific options."""
+ valid_extensions = ["csv", "json"]
+
+ def check_extension(filename: str) -> str:
+ """Check extension of the provided file."""
+ suffix = Path(filename).suffix
+ if suffix.startswith("."):
+ suffix = suffix[1:]
+
+ if suffix.lower() not in valid_extensions:
+ parser.error(f"Unsupported format '{suffix}'")
+
+ return filename
+
+ output_group = parser.add_argument_group("output options")
+ output_group.add_argument(
+ "--output",
+ type=check_extension,
+ help=(
+ "Name of the file where report will be saved. "
+ "Report format is automatically detected based on the file extension. "
+ f"Supported formats are: {', '.join(valid_extensions)}"
+ ),
+ )
+
+
+def add_debug_options(parser: argparse.ArgumentParser) -> None:
+ """Add debug options."""
+ debug_group = parser.add_argument_group("debug options")
+ debug_group.add_argument(
+ "--verbose", default=False, action="store_true", help="Produce verbose output"
+ )
+
+
+def add_keras_model_options(parser: argparse.ArgumentParser) -> None:
+ """Add model specific options."""
+ model_group = parser.add_argument_group("Keras model options")
+ model_group.add_argument("model", help="Keras model")
+
+
+def add_custom_supported_operators_options(parser: argparse.ArgumentParser) -> None:
+ """Add custom options for the command 'operators'."""
+ parser.add_argument(
+ "--supported-ops-report",
+ action="store_true",
+ default=False,
+ help=(
+ "Generate the SUPPORTED_OPS.md file in the "
+ "current working directory and exit"
+ ),
+ )
+
+
+def add_backend_options(parser: argparse.ArgumentParser) -> None:
+ """Add options for the backends configuration."""
+
+ def valid_directory(param: str) -> Path:
+ """Check if passed string is a valid directory path."""
+ if not (dir_path := Path(param)).is_dir():
+ parser.error(f"Invalid directory path {param}")
+
+ return dir_path
+
+ subparsers = parser.add_subparsers(title="Backend actions", dest="backend_action")
+ subparsers.required = True
+
+ install_subparser = subparsers.add_parser(
+ "install", help="Install backend", allow_abbrev=False
+ )
+ install_type_group = install_subparser.add_mutually_exclusive_group()
+ install_type_group.required = True
+ install_type_group.add_argument(
+ "--path", type=valid_directory, help="Path to the installed backend"
+ )
+ install_type_group.add_argument(
+ "--download",
+ default=False,
+ action="store_true",
+ help="Download and install backend",
+ )
+ install_subparser.add_argument(
+ "--i-agree-to-the-contained-eula",
+ default=False,
+ action="store_true",
+ help=argparse.SUPPRESS,
+ )
+ install_subparser.add_argument(
+ "--noninteractive",
+ default=False,
+ action="store_true",
+ help="Non interactive mode with automatic confirmation of every action",
+ )
+ install_subparser.add_argument(
+ "name",
+ nargs="?",
+ help="Name of the backend to install",
+ )
+
+ subparsers.add_parser("status", help="Show backends status")
+
+
+def add_evaluation_options(parser: argparse.ArgumentParser) -> None:
+ """Add evaluation options."""
+ available_backends = get_available_backends()
+ default_backends = get_default_backends()
+
+ def only_one_corstone_checker() -> Callable:
+ """
+ Return a callable to check that only one Corstone backend is passed.
+
+ Raises an exception when more than one Corstone backend is passed.
+ """
+ num_corstones = 0
+
+ def check(backend: str) -> str:
+ """Count Corstone backends and raise an exception if more than one."""
+ nonlocal num_corstones
+ if is_corstone_backend(backend):
+ num_corstones = num_corstones + 1
+ if num_corstones > 1:
+ raise argparse.ArgumentTypeError(
+ "There must be only one Corstone backend in the argument list."
+ )
+ return backend
+
+ return check
+
+ evaluation_group = parser.add_argument_group("evaluation options")
+ evaluation_group.add_argument(
+ "--evaluate-on",
+ help="Backends to use for evaluation (default: %(default)s)",
+ nargs="*",
+ choices=available_backends,
+ default=default_backends,
+ type=only_one_corstone_checker(),
+ )
+
+
+def parse_optimization_parameters(
+ optimization_type: str,
+ optimization_target: str,
+ sep: str = ",",
+ layers_to_optimize: Optional[List[str]] = None,
+) -> List[Dict[str, Any]]:
+ """Parse provided optimization parameters."""
+ if not optimization_type:
+ raise Exception("Optimization type is not provided")
+
+ if not optimization_target:
+ raise Exception("Optimization target is not provided")
+
+ opt_types = optimization_type.split(sep)
+ opt_targets = optimization_target.split(sep)
+
+ if len(opt_types) != len(opt_targets):
+ raise Exception("Wrong number of optimization targets and types")
+
+ non_numeric_targets = [
+ opt_target for opt_target in opt_targets if not is_number(opt_target)
+ ]
+ if len(non_numeric_targets) > 0:
+ raise Exception("Non numeric value for the optimization target")
+
+ optimizer_params = [
+ {
+ "optimization_type": opt_type.strip(),
+ "optimization_target": float(opt_target),
+ "layers_to_optimize": layers_to_optimize,
+ }
+ for opt_type, opt_target in zip(opt_types, opt_targets)
+ ]
+
+ return optimizer_params
+
+
+def get_target_profile_opts(device_args: Optional[Dict]) -> List[str]:
+ """Get non default values passed as parameters for the target profile."""
+ if not device_args:
+ return []
+
+ dummy_parser = argparse.ArgumentParser()
+ add_target_options(dummy_parser)
+ args = dummy_parser.parse_args([])
+
+ params_name = {
+ action.dest: param_name
+ for param_name, action in dummy_parser._option_string_actions.items() # pylint: disable=protected-access
+ }
+
+ non_default = [
+ arg_name
+ for arg_name, arg_value in device_args.items()
+ if arg_name in args and vars(args)[arg_name] != arg_value
+ ]
+
+ def construct_param(name: str, value: Any) -> List[str]:
+ """Construct parameter."""
+ if isinstance(value, list):
+ return [str(item) for v in value for item in [name, v]]
+
+ return [name, str(value)]
+
+ return [
+ item
+ for name in non_default
+ for item in construct_param(params_name[name], device_args[name])
+ ]
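
For reference, the shape of the data produced by parse_optimization_parameters() for the default CLI values:

    from mlia.cli.options import parse_optimization_parameters

    params = parse_optimization_parameters("pruning,clustering", "0.5,32")
    # [
    #     {"optimization_type": "pruning", "optimization_target": 0.5,
    #      "layers_to_optimize": None},
    #     {"optimization_type": "clustering", "optimization_target": 32.0,
    #      "layers_to_optimize": None},
    # ]
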
diff --git a/src/mlia/core/__init__.py b/src/mlia/core/__init__.py
new file mode 100644
index 0000000..49b1830
--- /dev/null
+++ b/src/mlia/core/__init__.py
@@ -0,0 +1,21 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Core module.
+
+Core module contains the main components that are used in the workflow of
+ML Inference Advisor:
+ - data collectors
+ - data analyzers
+ - advice producers
+ - event publishers
+ - event handlers
+
+The workflow of ML Inference Advisor consists of 3 stages:
+ - data collection
+ - data analysis
+ - advice generation
+
+Data is passed from one stage to another via the workflow executor.
+Results (collected data, analyzed data, advice, etc.) are published via
+a publish/subscribe mechanism.
+"""
diff --git a/src/mlia/core/_typing.py b/src/mlia/core/_typing.py
new file mode 100644
index 0000000..bda995c
--- /dev/null
+++ b/src/mlia/core/_typing.py
@@ -0,0 +1,12 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for custom type hints."""
+from pathlib import Path
+from typing import Literal
+from typing import TextIO
+from typing import Union
+
+
+FileLike = TextIO
+PathOrFileLike = Union[str, Path, FileLike]
+OutputFormat = Literal["plain_text", "csv", "json"]
diff --git a/src/mlia/core/advice_generation.py b/src/mlia/core/advice_generation.py
new file mode 100644
index 0000000..76cc1f2
--- /dev/null
+++ b/src/mlia/core/advice_generation.py
@@ -0,0 +1,106 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for advice generation."""
+from abc import ABC
+from abc import abstractmethod
+from dataclasses import dataclass
+from functools import wraps
+from typing import Any
+from typing import Callable
+from typing import List
+from typing import Union
+
+from mlia.core.common import AdviceCategory
+from mlia.core.common import DataItem
+from mlia.core.events import SystemEvent
+from mlia.core.mixins import ContextMixin
+
+
+@dataclass
+class Advice:
+ """Base class for the advice."""
+
+ messages: List[str]
+
+
+@dataclass
+class AdviceEvent(SystemEvent):
+ """Advice event.
+
+ This event is published for every produced advice.
+
+ :param advice: Advice instance
+ """
+
+ advice: Advice
+
+
+class AdviceProducer(ABC):
+ """Base class for the advice producer.
+
+ Producer has two methods for advice generation:
+ - produce_advice - used to generate advice based on provided
+ data (analyzed data item from analyze stage)
+ - get_advice - used for getting generated advice
+
+    Advice producers that have predefined advice can skip
+    the implementation of the produce_advice method.
+ """
+
+ @abstractmethod
+ def produce_advice(self, data_item: DataItem) -> None:
+ """Process data item and produce advice.
+
+ :param data_item: piece of data that could be used
+ for advice generation
+ """
+
+ @abstractmethod
+ def get_advice(self) -> Union[Advice, List[Advice]]:
+ """Get produced advice."""
+
+
+class ContextAwareAdviceProducer(AdviceProducer, ContextMixin):
+ """Context aware advice producer.
+
+ This class makes easier access to the Context object. Context object could
+ be automatically injected during workflow configuration.
+ """
+
+
+class FactBasedAdviceProducer(ContextAwareAdviceProducer):
+ """Advice producer based on provided facts.
+
+    This is a utility class that maintains a list of generated Advice instances.
+ """
+
+ def __init__(self) -> None:
+ """Init advice producer."""
+ self.advice: List[Advice] = []
+
+ def get_advice(self) -> Union[Advice, List[Advice]]:
+ """Get produced advice."""
+ return self.advice
+
+ def add_advice(self, messages: List[str]) -> None:
+ """Add advice."""
+ self.advice.append(Advice(messages))
+
+
+def advice_category(*categories: AdviceCategory) -> Callable:
+ """Filter advice generation handler by advice category."""
+
+ def wrapper(handler: Callable) -> Callable:
+ """Wrap data handler."""
+
+ @wraps(handler)
+ def check_category(self: Any, *args: Any, **kwargs: Any) -> Any:
+ """Check if handler can produce advice for the requested category."""
+ if not self.context.any_category_enabled(*categories):
+ return
+
+ handler(self, *args, **kwargs)
+
+ return check_category
+
+ return wrapper
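
A sketch of a producer built on the helpers above. The class and the data item are hypothetical (the real Ethos-U producers live in src/mlia/devices/ethosu/advice_generation.py); the context used by the advice_category decorator is injected during workflow configuration.

    from mlia.core.advice_generation import advice_category
    from mlia.core.advice_generation import FactBasedAdviceProducer
    from mlia.core.common import AdviceCategory
    from mlia.core.common import DataItem

    class ExampleAdviceProducer(FactBasedAdviceProducer):
        """Produce advice from a hypothetical analyzed data item."""

        @advice_category(AdviceCategory.PERFORMANCE, AdviceCategory.ALL)
        def produce_advice(self, data_item: DataItem) -> None:
            """Turn the analyzed data item into a piece of advice."""
            self.add_advice([f"Observed during analysis: {data_item}"])
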
diff --git a/src/mlia/core/advisor.py b/src/mlia/core/advisor.py
new file mode 100644
index 0000000..868d0c7
--- /dev/null
+++ b/src/mlia/core/advisor.py
@@ -0,0 +1,21 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Inference advisor module."""
+from abc import abstractmethod
+
+from mlia.core.common import NamedEntity
+from mlia.core.context import Context
+from mlia.core.workflow import WorkflowExecutor
+
+
+class InferenceAdvisor(NamedEntity):
+ """Base class for inference advisors."""
+
+ @abstractmethod
+ def configure(self, context: Context) -> WorkflowExecutor:
+ """Configure advisor execution."""
+
+ def run(self, context: Context) -> None:
+ """Run inference advisor."""
+ executor = self.configure(context)
+ executor.run()
diff --git a/src/mlia/core/common.py b/src/mlia/core/common.py
new file mode 100644
index 0000000..5fbad42
--- /dev/null
+++ b/src/mlia/core/common.py
@@ -0,0 +1,47 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Common module.
+
+This module contains common interfaces/classes shared across
+the core module.
+"""
+from abc import ABC
+from abc import abstractmethod
+from enum import Enum
+from typing import Any
+
+# This type is used as a type alias for the items that are passed around
+# in the advisor workflow. There are no restrictions on the type of the
+# object. This alias is used only to emphasize the nature of the input/output
+# arguments.
+DataItem = Any
+
+
+class AdviceCategory(Enum):
+ """Advice category.
+
+ Enumeration of advice categories supported by ML Inference Advisor.
+ """
+
+ OPERATORS = 1
+ PERFORMANCE = 2
+ OPTIMIZATION = 3
+ ALL = 4
+
+ @classmethod
+ def from_string(cls, value: str) -> "AdviceCategory":
+ """Resolve enum value from string value."""
+ category_names = [item.name for item in AdviceCategory]
+ if not value or value.upper() not in category_names:
+ raise Exception(f"Invalid advice category {value}")
+
+ return AdviceCategory[value.upper()]
+
+
+class NamedEntity(ABC):
+ """Entity with a name and description."""
+
+ @classmethod
+ @abstractmethod
+ def name(cls) -> str:
+ """Return name of the entity."""
diff --git a/src/mlia/core/context.py b/src/mlia/core/context.py
new file mode 100644
index 0000000..8b3dd2c
--- /dev/null
+++ b/src/mlia/core/context.py
@@ -0,0 +1,218 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Context module.
+
+This module contains functionality related to the Context.
+Context is an object that describes advisor working environment
+and requested behavior (advice categories, input configuration
+parameters).
+"""
+import logging
+from abc import ABC
+from abc import abstractmethod
+from pathlib import Path
+from typing import Any
+from typing import List
+from typing import Mapping
+from typing import Optional
+from typing import Union
+
+from mlia.core.common import AdviceCategory
+from mlia.core.events import DefaultEventPublisher
+from mlia.core.events import EventHandler
+from mlia.core.events import EventPublisher
+from mlia.core.helpers import ActionResolver
+from mlia.core.helpers import APIActionResolver
+
+logger = logging.getLogger(__name__)
+
+
+class Context(ABC):
+ """Abstract class for the execution context."""
+
+ @abstractmethod
+ def get_model_path(self, model_filename: str) -> Path:
+ """Return path for the intermediate/optimized models.
+
+ During workflow execution different parts of the advisor
+ require creating intermediate files for models.
+
+        This method provides the paths where those models
+        can be saved.
+
+ :param model_filename: filename of the model
+ """
+
+ @property
+ @abstractmethod
+ def event_publisher(self) -> EventPublisher:
+ """Return event publisher."""
+
+ @property
+ @abstractmethod
+ def event_handlers(self) -> Optional[List[EventHandler]]:
+ """Return list of the event_handlers."""
+
+ @property
+ @abstractmethod
+ def advice_category(self) -> Optional[AdviceCategory]:
+ """Return advice category."""
+
+ @property
+ @abstractmethod
+ def config_parameters(self) -> Optional[Mapping[str, Any]]:
+ """Return configuration parameters."""
+
+ @property
+ @abstractmethod
+ def action_resolver(self) -> ActionResolver:
+ """Return action resolver."""
+
+ @abstractmethod
+ def update(
+ self,
+ *,
+ advice_category: AdviceCategory,
+ event_handlers: List[EventHandler],
+ config_parameters: Mapping[str, Any],
+ ) -> None:
+ """Update context parameters."""
+
+ def category_enabled(self, category: AdviceCategory) -> bool:
+ """Check if category enabled."""
+ return category == self.advice_category
+
+ def any_category_enabled(self, *categories: AdviceCategory) -> bool:
+ """Return true if any category is enabled."""
+ return self.advice_category in categories
+
+ def register_event_handlers(self) -> None:
+ """Register event handlers."""
+ self.event_publisher.register_event_handlers(self.event_handlers)
+
+
+class ExecutionContext(Context):
+ """Execution context."""
+
+ def __init__(
+ self,
+ *,
+ advice_category: Optional[AdviceCategory] = None,
+ config_parameters: Optional[Mapping[str, Any]] = None,
+ working_dir: Optional[Union[str, Path]] = None,
+ event_handlers: Optional[List[EventHandler]] = None,
+ event_publisher: Optional[EventPublisher] = None,
+ verbose: bool = False,
+ logs_dir: str = "logs",
+ models_dir: str = "models",
+ action_resolver: Optional[ActionResolver] = None,
+ ) -> None:
+ """Init execution context.
+
+ :param advice_category: requested advice category
+        :param config_parameters: dictionary-like object with input parameters
+        :param working_dir: path to the directory that will be used as a place
+            to store temporary files, logs, models. If not provided then the
+            current working directory will be used instead
+        :param event_handlers: optional list of event handlers
+        :param event_publisher: optional event publisher instance. If not provided
+            then the default event publisher implementation will be used
+        :param verbose: enable verbose output
+        :param logs_dir: name of the directory inside the working directory where
+            log files will be stored
+        :param models_dir: name of the directory inside the working directory where
+            temporary models will be stored
+        :param action_resolver: instance of the action resolver that can make
+            the advice actionable
+ """
+ self._advice_category = advice_category
+ self._config_parameters = config_parameters
+
+ self._working_dir_path = Path.cwd()
+ if working_dir:
+ self._working_dir_path = Path(working_dir)
+ self._working_dir_path.mkdir(exist_ok=True)
+
+ self._event_handlers = event_handlers
+ self._event_publisher = event_publisher or DefaultEventPublisher()
+ self.verbose = verbose
+ self.logs_dir = logs_dir
+ self.models_dir = models_dir
+ self._action_resolver = action_resolver or APIActionResolver()
+
+ @property
+ def advice_category(self) -> Optional[AdviceCategory]:
+ """Return advice category."""
+ return self._advice_category
+
+ @advice_category.setter
+ def advice_category(self, advice_category: AdviceCategory) -> None:
+ """Setter for the advice category."""
+ self._advice_category = advice_category
+
+ @property
+ def config_parameters(self) -> Optional[Mapping[str, Any]]:
+ """Return configuration parameters."""
+ return self._config_parameters
+
+ @config_parameters.setter
+ def config_parameters(self, config_parameters: Optional[Mapping[str, Any]]) -> None:
+ """Setter for the configuration parameters."""
+ self._config_parameters = config_parameters
+
+ @property
+ def event_handlers(self) -> Optional[List[EventHandler]]:
+ """Return list of the event handlers."""
+ return self._event_handlers
+
+ @event_handlers.setter
+ def event_handlers(self, event_handlers: List[EventHandler]) -> None:
+ """Setter for the event handlers."""
+ self._event_handlers = event_handlers
+
+ @property
+ def event_publisher(self) -> EventPublisher:
+ """Return event publisher."""
+ return self._event_publisher
+
+ @property
+ def action_resolver(self) -> ActionResolver:
+ """Return action resolver."""
+ return self._action_resolver
+
+ def get_model_path(self, model_filename: str) -> Path:
+ """Return path for the model."""
+ models_dir_path = self._working_dir_path / self.models_dir
+ models_dir_path.mkdir(exist_ok=True)
+
+ return models_dir_path / model_filename
+
+ @property
+ def logs_path(self) -> Path:
+ """Return path to the logs directory."""
+ return self._working_dir_path / self.logs_dir
+
+ def update(
+ self,
+ *,
+ advice_category: AdviceCategory,
+ event_handlers: List[EventHandler],
+ config_parameters: Mapping[str, Any],
+ ) -> None:
+ """Update context parameters."""
+ self._advice_category = advice_category
+ self._event_handlers = event_handlers
+ self._config_parameters = config_parameters
+
+ def __str__(self) -> str:
+ """Return string representation."""
+ category = (
+ "<not set>" if self.advice_category is None else self.advice_category.name
+ )
+
+ return (
+ f"ExecutionContext: working_dir={self._working_dir_path}, "
+ f"advice_category={category}, "
+ f"config_parameters={self.config_parameters}, "
+ f"verbose={self.verbose}"
+ )
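A brief usage sketch of ExecutionContext (the directory and parameter names are illustrative only):

from mlia.core.common import AdviceCategory
from mlia.core.context import ExecutionContext

context = ExecutionContext(
    advice_category=AdviceCategory.OPERATORS,
    working_dir="mlia_output",  # created if it does not exist
    verbose=True,
)

# Intermediate models are stored under <working_dir>/models
model_path = context.get_model_path("optimized_model.tflite")

# Later stages may refine the context in a single call
context.update(
    advice_category=AdviceCategory.ALL,
    event_handlers=[],
    config_parameters={"example_section": {"example_param": 1}},
)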
diff --git a/src/mlia/core/data_analysis.py b/src/mlia/core/data_analysis.py
new file mode 100644
index 0000000..6adb41e
--- /dev/null
+++ b/src/mlia/core/data_analysis.py
@@ -0,0 +1,70 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for data analysis."""
+from abc import ABC
+from abc import abstractmethod
+from dataclasses import dataclass
+from typing import List
+
+from mlia.core.common import DataItem
+from mlia.core.mixins import ContextMixin
+
+
+class DataAnalyzer(ABC):
+ """Base class for the data analysis.
+
+    The purpose of this class is to extract valuable data out of the
+    collected data, which can then be used for advice generation.
+
+ This process consists of two steps:
+ - analyze every item of the collected data
+ - get analyzed data
+ """
+
+ @abstractmethod
+ def analyze_data(self, data_item: DataItem) -> None:
+ """Analyze data.
+
+ :param data_item: item of the collected data
+ """
+
+ @abstractmethod
+ def get_analyzed_data(self) -> List[DataItem]:
+ """Get analyzed data."""
+
+
+class ContextAwareDataAnalyzer(DataAnalyzer, ContextMixin):
+ """Context aware data analyzer.
+
+    This class makes it easier to access the Context object. The Context object
+    can be injected automatically during workflow configuration.
+ """
+
+
+@dataclass
+class Fact:
+ """Base class for the facts.
+
+    A fact represents some piece of knowledge about the collected
+    data.
+ """
+
+
+class FactExtractor(ContextAwareDataAnalyzer):
+ """Data analyzer based on extracting facts.
+
+ Utility class that makes fact extraction easier.
+    The class maintains a list of the extracted facts.
+ """
+
+ def __init__(self) -> None:
+ """Init fact extractor."""
+ self.facts: List[Fact] = []
+
+ def get_analyzed_data(self) -> List[DataItem]:
+ """Return list of the collected facts."""
+ return self.facts
+
+ def add_fact(self, fact: Fact) -> None:
+ """Add fact."""
+ self.facts.append(fact)
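A hypothetical FactExtractor subclass might look like the sketch below; ModelIsLarge, the data item shape and the 10 MB threshold are invented for illustration.

from dataclasses import dataclass

from mlia.core.common import DataItem
from mlia.core.data_analysis import Fact
from mlia.core.data_analysis import FactExtractor


@dataclass
class ModelIsLarge(Fact):
    """Fact: the model exceeds a size threshold."""

    size_bytes: int


class ModelSizeAnalyzer(FactExtractor):
    """Emit a fact when the collected model size exceeds 10 MB."""

    def analyze_data(self, data_item: DataItem) -> None:
        size = data_item.get("model_size_bytes", 0)
        if size > 10 * 1024 * 1024:
            self.add_fact(ModelIsLarge(size))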
diff --git a/src/mlia/core/data_collection.py b/src/mlia/core/data_collection.py
new file mode 100644
index 0000000..43b6d1c
--- /dev/null
+++ b/src/mlia/core/data_collection.py
@@ -0,0 +1,37 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for data collection.
+
+This module contains base classes for the first stage
+of the ML Inference Advisor workflow - data collection.
+"""
+from abc import abstractmethod
+
+from mlia.core.common import DataItem
+from mlia.core.common import NamedEntity
+from mlia.core.mixins import ContextMixin
+from mlia.core.mixins import ParameterResolverMixin
+
+
+class DataCollector(NamedEntity):
+ """Base class for the data collection.
+
+    Data collection is the first step in the process of advice
+    generation.
+
+    Different implementations of this class can provide various
+    information about the model or device. This information is used
+    at later stages.
+ """
+
+ @abstractmethod
+ def collect_data(self) -> DataItem:
+ """Collect data."""
+
+
+class ContextAwareDataCollector(DataCollector, ContextMixin, ParameterResolverMixin):
+ """Context aware data collector.
+
+    This class makes it easier to access the Context object. The Context object
+    can be injected automatically during workflow configuration.
+ """
diff --git a/src/mlia/core/errors.py b/src/mlia/core/errors.py
new file mode 100644
index 0000000..7d6beb1
--- /dev/null
+++ b/src/mlia/core/errors.py
@@ -0,0 +1,18 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""MLIA exceptions module."""
+
+
+class ConfigurationError(Exception):
+ """Configuration error."""
+
+
+class FunctionalityNotSupportedError(Exception):
+ """Functionality is not supported error."""
+
+ def __init__(self, reason: str, description: str) -> None:
+ """Init exception."""
+ super().__init__(f"{reason}: {description}")
+
+ self.reason = reason
+ self.description = description
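FunctionalityNotSupportedError is how a data collector opts out gracefully: the workflow executor (workflow.py below) catches it and publishes a DataCollectorSkippedEvent instead of aborting. A hedged sketch; the target check is made up:

from mlia.core.errors import FunctionalityNotSupportedError


def check_target(target: str) -> None:
    """Hypothetical guard used inside a data collector."""
    if target != "ethos-u55":
        raise FunctionalityNotSupportedError(
            reason="Unsupported target",
            description=f"Data collection is not available for {target}",
        )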
diff --git a/src/mlia/core/events.py b/src/mlia/core/events.py
new file mode 100644
index 0000000..10aec86
--- /dev/null
+++ b/src/mlia/core/events.py
@@ -0,0 +1,455 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for the events and related functionality.
+
+This module represents one of the main components of the workflow -
+event publishing - and provides a way to deliver results to the
+calling application.
+
+Each component of the workflow can generate events of a specific type.
+The application can subscribe and react to those events.
+"""
+import traceback
+import uuid
+from abc import ABC
+from abc import abstractmethod
+from contextlib import contextmanager
+from dataclasses import asdict
+from dataclasses import dataclass
+from dataclasses import field
+from functools import singledispatchmethod
+from typing import Any
+from typing import Dict
+from typing import Generator
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+from mlia.core.common import DataItem
+
+
+@dataclass
+class Event:
+ """Base class for the events.
+
+ This class is used as a root node of the events class hierarchy.
+ """
+
+ event_id: str = field(init=False)
+
+ def __post_init__(self) -> None:
+ """Generate unique ID for the event."""
+ self.event_id = str(uuid.uuid4())
+
+ def compare_without_id(self, other: "Event") -> bool:
+ """Compare two events without event_id field."""
+ if not isinstance(other, Event) or self.__class__ != other.__class__:
+ return False
+
+ self_as_dict = asdict(self)
+ self_as_dict.pop("event_id")
+
+ other_as_dict = asdict(other)
+ other_as_dict.pop("event_id")
+
+ return self_as_dict == other_as_dict
+
+
+@dataclass
+class ChildEvent(Event):
+ """Child event.
+
+    This class can be used to link an event with its parent event.
+ """
+
+ parent_event_id: str
+
+
+@dataclass
+class ActionStartedEvent(Event):
+ """Action started event.
+
+ This event is published when some action has been started.
+ """
+
+ action_type: str
+ params: Optional[Dict] = None
+
+
+@dataclass
+class SubActionEvent(ChildEvent):
+ """SubAction event.
+
+    This event can be used to represent a sub-action within a parent action.
+ """
+
+ action_type: str
+ params: Optional[Dict] = None
+
+
+@dataclass
+class ActionFinishedEvent(ChildEvent):
+ """Action finished event.
+
+ This event is published when some action has been finished.
+ """
+
+
+@dataclass
+class SystemEvent(Event):
+ """System event.
+
+    System event class represents events that are published by components
+    of the core module. The most common example is the workflow executor,
+    which publishes a number of system events for the start/completion
+    of the different stages and of the workflow itself.
+
+    Events that are published by components outside of the core module
+    should not use this class as a base class.
+ """
+
+
+@dataclass
+class ExecutionStartedEvent(SystemEvent):
+ """Execution started event.
+
+ This event is published when workflow execution started.
+ """
+
+
+@dataclass
+class ExecutionFinishedEvent(SystemEvent):
+ """Execution finished event.
+
+ This event is published when workflow execution finished.
+ """
+
+
+@dataclass
+class ExecutionFailedEvent(SystemEvent):
+ """Execution failed event."""
+
+ err: Exception
+
+
+@dataclass
+class DataCollectionStageStartedEvent(SystemEvent):
+ """Data collection stage started.
+
+ This event is published when data collection stage started.
+ """
+
+
+@dataclass
+class DataCollectorSkippedEvent(SystemEvent):
+ """Data collector skipped event.
+
+    This event is published when a particular data collector cannot
+    provide data for the provided parameters.
+ """
+
+ data_collector: str
+ reason: str
+
+
+@dataclass
+class DataCollectionStageFinishedEvent(SystemEvent):
+ """Data collection stage finished.
+
+ This event is published when data collection stage finished.
+ """
+
+
+@dataclass
+class DataAnalysisStageStartedEvent(SystemEvent):
+ """Data analysis stage started.
+
+ This event is published when data analysis stage started.
+ """
+
+
+@dataclass
+class DataAnalysisStageFinishedEvent(SystemEvent):
+ """Data analysis stage finished.
+
+ This event is published when data analysis stage finished.
+ """
+
+
+@dataclass
+class AdviceStageStartedEvent(SystemEvent):
+ """Advace producing stage started.
+
+ This event is published when advice generation stage started.
+ """
+
+
+@dataclass
+class AdviceStageFinishedEvent(SystemEvent):
+ """Advace producing stage finished.
+
+ This event is published when advice generation stage finished.
+ """
+
+
+@dataclass
+class CollectedDataEvent(SystemEvent):
+ """Collected data event.
+
+ This event is published for every collected data item.
+
+ :param data_item: collected data item
+ """
+
+ data_item: DataItem
+
+
+@dataclass
+class AnalyzedDataEvent(SystemEvent):
+ """Analyzed data event.
+
+ This event is published for every analyzed data item.
+
+ :param data_item: analyzed data item
+ """
+
+ data_item: DataItem
+
+
+class EventHandler:
+ """Base class for the event handlers.
+
+ Each event handler should derive from this base class.
+ """
+
+ def handle_event(self, event: Event) -> None:
+ """Handle the event.
+
+        By default all published events are passed to each
+        registered event handler. It is the handler's responsibility
+        to filter the events that it is interested in.
+ """
+
+
+class DebugEventHandler(EventHandler):
+ """Event handler for debugging purposes.
+
+    This handler prints every published event to the
+    standard output.
+ """
+
+ def __init__(self, with_stacktrace: bool = False) -> None:
+ """Init event handler.
+
+ :param with_stacktrace: enable printing stacktrace of the
+ place where event publishing occurred.
+ """
+ self.with_stacktrace = with_stacktrace
+
+ def handle_event(self, event: Event) -> None:
+ """Handle event."""
+ print(f"Got event {event}")
+
+ if self.with_stacktrace:
+ traceback.print_stack()
+
+
+class EventDispatcherMetaclass(type):
+ """Metaclass for event dispatching.
+
+    It could be tedious to check the type of the published event
+    inside an event handler. Instead, the following convention can be
+    established: if a method name of the class starts with some
+    prefix then it is considered to be an event handler for a particular
+    event type.
+
+ This metaclass goes through the list of class methods and
+ links all methods with the prefix "on_" to the common dispatcher
+ method.
+ """
+
+ def __new__(
+ cls,
+ clsname: str,
+ bases: Tuple,
+ namespace: Dict[str, Any],
+ event_handler_method_prefix: str = "on_",
+ ) -> Any:
+ """Create event dispatcher and link event handlers."""
+ new_class = super().__new__(cls, clsname, bases, namespace)
+
+ @singledispatchmethod
+ def dispatcher(_self: Any, _event: Event) -> Any:
+ """Event dispatcher."""
+
+        # get all class methods that start with the particular prefix
+ event_handler_methods = (
+ (item_name, item)
+ for item_name in dir(new_class)
+ if callable((item := getattr(new_class, item_name)))
+ and item_name.startswith(event_handler_method_prefix)
+ )
+
+ # link all collected event handlers to one dispatcher method
+ for method_name, method in event_handler_methods:
+ event_handler = dispatcher.register(method)
+ setattr(new_class, method_name, event_handler)
+
+ # override default handle_event method, replace it with the
+ # dispatcher
+ setattr(new_class, "handle_event", dispatcher)
+
+ return new_class
+
+
+class EventDispatcher(EventHandler, metaclass=EventDispatcherMetaclass):
+ """Event dispatcher."""
+
+
+class EventPublisher(ABC):
+ """Base class for the event publisher.
+
+    Event publisher is an intermediate component between the event emitter
+    and the event consumer.
+ """
+
+ @abstractmethod
+ def register_event_handler(self, event_handler: EventHandler) -> None:
+ """Register event handler.
+
+ :param event_handler: instance of the event handler
+ """
+
+ def register_event_handlers(
+ self, event_handlers: Optional[List[EventHandler]]
+ ) -> None:
+ """Register event handlers.
+
+ Can be used for batch registration of the event handlers:
+
+ :param event_handlers: list of the event handler instances
+ """
+ if not event_handlers:
+ return
+
+ for handler in event_handlers:
+ self.register_event_handler(handler)
+
+ @abstractmethod
+ def publish_event(self, event: Event) -> None:
+ """Publish the event.
+
+ Deliver the event to the all registered event handlers.
+
+ :param event: event instance
+ """
+
+
+class DefaultEventPublisher(EventPublisher):
+ """Default event publishing implementation.
+
+    Simple implementation that maintains a list of the registered event
+    handlers.
+ """
+
+ def __init__(self) -> None:
+ """Init the event publisher."""
+ self.handlers: List[EventHandler] = []
+
+ def register_event_handler(self, event_handler: EventHandler) -> None:
+ """Register the event handler.
+
+ :param event_handler: instance of the event handler
+ """
+ self.handlers.append(event_handler)
+
+ def publish_event(self, event: Event) -> None:
+ """Publish the event.
+
+ Publisher does not catch exceptions that could be raised by event handlers.
+ """
+ for handler in self.handlers:
+ handler.handle_event(event)
+
+
+@contextmanager
+def stage(
+ publisher: EventPublisher, events: Tuple[Event, Event]
+) -> Generator[None, None, None]:
+ """Generate events before and after stage.
+
+    This context manager can be used to mark the start/finish of the
+    execution of a particular logical part of the workflow.
+ """
+ started, finished = events
+
+ publisher.publish_event(started)
+ yield
+ publisher.publish_event(finished)
+
+
+@contextmanager
+def action(
+ publisher: EventPublisher, action_type: str, params: Optional[Dict] = None
+) -> Generator[None, None, None]:
+ """Generate events before and after action."""
+ action_started = ActionStartedEvent(action_type, params)
+ action_finished = ActionFinishedEvent(action_started.event_id)
+
+ publisher.publish_event(action_started)
+ yield
+ publisher.publish_event(action_finished)
+
+
+class SystemEventsHandler(EventDispatcher):
+ """System events handler."""
+
+ def on_execution_started(self, event: ExecutionStartedEvent) -> None:
+ """Handle ExecutionStarted event."""
+
+ def on_execution_finished(self, event: ExecutionFinishedEvent) -> None:
+ """Handle ExecutionFinished event."""
+
+ def on_execution_failed(self, event: ExecutionFailedEvent) -> None:
+ """Handle ExecutionFailed event."""
+
+ def on_data_collection_stage_started(
+ self, event: DataCollectionStageStartedEvent
+ ) -> None:
+ """Handle DataCollectionStageStarted event."""
+
+ def on_data_collection_stage_finished(
+ self, event: DataCollectionStageFinishedEvent
+ ) -> None:
+ """Handle DataCollectionStageFinished event."""
+
+ def on_data_collector_skipped(self, event: DataCollectorSkippedEvent) -> None:
+ """Handle DataCollectorSkipped event."""
+
+ def on_data_analysis_stage_started(
+ self, event: DataAnalysisStageStartedEvent
+ ) -> None:
+ """Handle DataAnalysisStageStartedEvent event."""
+
+ def on_data_analysis_stage_finished(
+ self, event: DataAnalysisStageFinishedEvent
+ ) -> None:
+ """Handle DataAnalysisStageFinishedEvent event."""
+
+ def on_advice_stage_started(self, event: AdviceStageStartedEvent) -> None:
+ """Handle AdviceStageStarted event."""
+
+ def on_advice_stage_finished(self, event: AdviceStageFinishedEvent) -> None:
+ """Handle AdviceStageFinished event."""
+
+ def on_collected_data(self, event: CollectedDataEvent) -> None:
+ """Handle CollectedData event."""
+
+ def on_analyzed_data(self, event: AnalyzedDataEvent) -> None:
+ """Handle AnalyzedData event."""
+
+ def on_action_started(self, event: ActionStartedEvent) -> None:
+ """Handle ActionStarted event."""
+
+ def on_action_finished(self, event: ActionFinishedEvent) -> None:
+ """Handle ActionFinished event."""
diff --git a/src/mlia/core/helpers.py b/src/mlia/core/helpers.py
new file mode 100644
index 0000000..d10ea5d
--- /dev/null
+++ b/src/mlia/core/helpers.py
@@ -0,0 +1,38 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for various helper classes."""
+# pylint: disable=no-self-use, unused-argument
+from typing import Any
+from typing import List
+
+
+class ActionResolver:
+ """Helper class for generating actions (e.g. commands with parameters)."""
+
+ def apply_optimizations(self, **kwargs: Any) -> List[str]:
+ """Return action details for applying optimizations."""
+ return []
+
+ def supported_operators_info(self) -> List[str]:
+ """Return action details for generating supported ops report."""
+ return []
+
+ def check_performance(self) -> List[str]:
+ """Return action details for checking performance."""
+ return []
+
+ def check_operator_compatibility(self) -> List[str]:
+ """Return action details for checking op compatibility."""
+ return []
+
+ def operator_compatibility_details(self) -> List[str]:
+ """Return action details for getting more information about op compatibility."""
+ return []
+
+ def optimization_details(self) -> List[str]:
+ """Return action detail for getting information about optimizations."""
+ return []
+
+
+class APIActionResolver(ActionResolver):
+ """Helper class for the actions performed through API."""
diff --git a/src/mlia/core/mixins.py b/src/mlia/core/mixins.py
new file mode 100644
index 0000000..ee03100
--- /dev/null
+++ b/src/mlia/core/mixins.py
@@ -0,0 +1,54 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Mixins module."""
+from typing import Any
+from typing import Optional
+
+from mlia.core.context import Context
+
+
+class ContextMixin:
+ """Mixin for injecting context object."""
+
+ context: Context
+
+ def set_context(self, context: Context) -> None:
+ """Context setter."""
+ self.context = context
+
+
+class ParameterResolverMixin:
+ """Mixin for parameter resolving."""
+
+ context: Context
+
+ def get_parameter(
+ self,
+ section: str,
+ name: str,
+ expected: bool = True,
+ expected_type: Optional[type] = None,
+ context: Optional[Context] = None,
+ ) -> Any:
+ """Get parameter value."""
+ ctx = context or self.context
+
+ if ctx.config_parameters is None:
+ raise Exception("Configuration parameters are not set")
+
+ section_params = ctx.config_parameters.get(section)
+ if section_params is None or not isinstance(section_params, dict):
+ raise Exception(
+ f"Parameter section {section} has wrong format, "
+ "expected to be a dictionary"
+ )
+
+ value = section_params.get(name)
+
+ if not value and expected:
+ raise Exception(f"Parameter {name} is not set")
+
+ if value and expected_type is not None and not isinstance(value, expected_type):
+ raise Exception(f"Parameter {name} expected to have type {expected_type}")
+
+ return value
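get_parameter expects config_parameters to be a two-level mapping (section -> parameter -> value). A sketch of a component reading its settings from the context; the section and parameter names are made up:

from mlia.core.common import DataItem
from mlia.core.context import ExecutionContext
from mlia.core.data_collection import ContextAwareDataCollector


class ConfiguredCollector(ContextAwareDataCollector):
    """Read settings from the context configuration."""

    def collect_data(self) -> DataItem:
        backends = self.get_parameter("my_section", "backends", expected_type=list)
        return {"backends": backends}

    @classmethod
    def name(cls) -> str:
        return "configured_collector"


collector = ConfiguredCollector()
collector.set_context(
    ExecutionContext(config_parameters={"my_section": {"backends": ["vela"]}})
)
collector.collect_data()  # {"backends": ["vela"]}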
diff --git a/src/mlia/core/performance.py b/src/mlia/core/performance.py
new file mode 100644
index 0000000..5433d5c
--- /dev/null
+++ b/src/mlia/core/performance.py
@@ -0,0 +1,47 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for performance estimation."""
+from abc import abstractmethod
+from typing import Callable
+from typing import Generic
+from typing import List
+from typing import TypeVar
+
+
+ModelType = TypeVar("ModelType") # pylint: disable=invalid-name
+PerfMetricsType = TypeVar("PerfMetricsType") # pylint: disable=invalid-name
+
+
+class PerformanceEstimator(Generic[ModelType, PerfMetricsType]):
+ """Base class for the performance estimation."""
+
+ @abstractmethod
+ def estimate(self, model: ModelType) -> PerfMetricsType:
+ """Estimate performance."""
+
+
+def estimate_performance(
+ original_model: ModelType,
+ estimator: PerformanceEstimator[ModelType, PerfMetricsType],
+ model_transformations: List[Callable[[ModelType], ModelType]],
+) -> List[PerfMetricsType]:
+ """Estimate performance impact.
+
+    This function estimates the impact on model performance of
+    applying the provided transformations/optimizations.
+
+    :param original_model: object that represents a model. It could be an
+        instance of the model or a path to the model, depending on the
+        provided performance estimator.
+    :param estimator: performance estimator
+    :param model_transformations: list of callables, each of which
+        returns an object that represents an optimized model
+ """
+ original_metrics = estimator.estimate(original_model)
+
+ optimized_metrics = [
+ estimator.estimate(transform(original_model))
+ for transform in model_transformations
+ ]
+
+ return [original_metrics, *optimized_metrics]
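estimate_performance runs the estimator on the original model and then on every transformed variant, returning the metrics in that order. A toy sketch where the "model" is a list of layer sizes and the metric is their sum (all names invented):

from typing import List

from mlia.core.performance import PerformanceEstimator
from mlia.core.performance import estimate_performance


class SizeEstimator(PerformanceEstimator[List[int], int]):
    """Toy estimator: the 'metric' is the total number of parameters."""

    def estimate(self, model: List[int]) -> int:
        return sum(model)


def prune_half(model: List[int]) -> List[int]:
    """Toy transformation: halve every layer."""
    return [size // 2 for size in model]


metrics = estimate_performance([100, 200], SizeEstimator(), [prune_half])
# metrics == [300, 150]: original model first, then one entry per transformation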
diff --git a/src/mlia/core/reporting.py b/src/mlia/core/reporting.py
new file mode 100644
index 0000000..1b75bb4
--- /dev/null
+++ b/src/mlia/core/reporting.py
@@ -0,0 +1,762 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Reporting module."""
+import csv
+import json
+import logging
+from abc import ABC
+from abc import abstractmethod
+from collections import defaultdict
+from contextlib import contextmanager
+from contextlib import ExitStack
+from dataclasses import dataclass
+from functools import partial
+from io import TextIOWrapper
+from pathlib import Path
+from textwrap import fill
+from textwrap import indent
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Collection
+from typing import Dict
+from typing import Generator
+from typing import Iterable
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import numpy as np
+
+from mlia.core._typing import FileLike
+from mlia.core._typing import OutputFormat
+from mlia.core._typing import PathOrFileLike
+from mlia.utils.console import apply_style
+from mlia.utils.console import produce_table
+from mlia.utils.logging import LoggerWriter
+from mlia.utils.types import is_list_of
+
+logger = logging.getLogger(__name__)
+
+
+class Report(ABC):
+ """Abstract class for the report."""
+
+ @abstractmethod
+ def to_json(self, **kwargs: Any) -> Any:
+ """Convert to json serializible format."""
+
+ @abstractmethod
+ def to_csv(self, **kwargs: Any) -> List[Any]:
+ """Convert to csv serializible format."""
+
+ @abstractmethod
+ def to_plain_text(self, **kwargs: Any) -> str:
+ """Convert to human readable format."""
+
+
+class ReportItem:
+ """Item of the report."""
+
+ def __init__(
+ self,
+ name: str,
+ alias: Optional[str] = None,
+ value: Optional[Union[str, int, "Cell"]] = None,
+ nested_items: Optional[List["ReportItem"]] = None,
+ ) -> None:
+ """Init the report item."""
+ self.name = name
+ self.alias = alias
+ self.value = value
+ self.nested_items = nested_items or []
+
+ @property
+ def compound(self) -> bool:
+ """Return true if item has nested items."""
+ return self.nested_items is not None and len(self.nested_items) > 0
+
+ @property
+ def raw_value(self) -> Any:
+ """Get actual item value."""
+ val = self.value
+ if isinstance(val, Cell):
+ return val.value
+
+ return val
+
+
+@dataclass
+class Format:
+ """Column or cell format.
+
+ Format could be applied either to a column or an individual cell.
+
+ :param wrap_width: width of the wrapped text value
+ :param str_fmt: string format to be applied to the value
+ :param style: text style
+ """
+
+ wrap_width: Optional[int] = None
+ str_fmt: Optional[Union[str, Callable[[Any], str]]] = None
+ style: Optional[str] = None
+
+
+@dataclass
+class Cell:
+ """Cell definition.
+
+    This is a wrapper class for a particular value in the table. It can be
+    used to apply a specific format to this value.
+ """
+
+ value: Any
+ fmt: Optional[Format] = None
+
+ def _apply_style(self, value: str) -> str:
+ """Apply style to the value."""
+ if self.fmt and self.fmt.style:
+ value = apply_style(value, self.fmt.style)
+
+ return value
+
+ def _get_value(self) -> str:
+ """Return cell value."""
+ if self.fmt:
+ if isinstance(self.fmt.str_fmt, str):
+ return "{:{fmt}}".format(self.value, fmt=self.fmt.str_fmt)
+
+ if callable(self.fmt.str_fmt):
+ return self.fmt.str_fmt(self.value)
+
+ return str(self.value)
+
+ def __str__(self) -> str:
+ """Return string representation."""
+ val = self._get_value()
+ return self._apply_style(val)
+
+ def to_csv(self) -> Any:
+ """Cell definition for csv."""
+ return self.value
+
+ def to_json(self) -> Any:
+ """Cell definition for json."""
+ return self.value
+
+
+class CountAwareCell(Cell):
+ """Count aware cell."""
+
+ def __init__(
+ self,
+ value: Optional[Union[int, float]],
+ singular: str,
+ plural: str,
+ format_string: str = ",d",
+ ):
+ """Init cell instance."""
+ self.unit = singular if value == 1 else plural
+
+ def format_value(val: Optional[Union[int, float]]) -> str:
+ """Provide string representation for the value."""
+ if val is None:
+ return ""
+
+ if val == 1:
+ return f"1 {singular}"
+
+ return f"{val:{format_string}} {plural}"
+
+ super().__init__(value, Format(str_fmt=format_value))
+
+ def to_csv(self) -> Any:
+ """Cell definition for csv."""
+ return {"value": self.value, "unit": self.unit}
+
+ def to_json(self) -> Any:
+ """Cell definition for json."""
+ return {"value": self.value, "unit": self.unit}
+
+
+class BytesCell(CountAwareCell):
+ """Cell that represents memory size."""
+
+ def __init__(self, value: Optional[int]) -> None:
+ """Init cell instance."""
+ super().__init__(value, "byte", "bytes")
+
+
+class CyclesCell(CountAwareCell):
+ """Cell that represents cycles."""
+
+ def __init__(self, value: Optional[Union[int, float]]) -> None:
+ """Init cell instance."""
+ super().__init__(value, "cycle", "cycles", ",.0f")
+
+
+class ClockCell(CountAwareCell):
+ """Cell that represents clock value."""
+
+ def __init__(self, value: Optional[Union[int, float]]) -> None:
+ """Init cell instance."""
+ super().__init__(value, "Hz", "Hz", ",.0f")
+
+
+class Column:
+ """Column definition."""
+
+ def __init__(
+ self,
+ header: str,
+ alias: Optional[str] = None,
+ fmt: Optional[Format] = None,
+ only_for: Optional[List[str]] = None,
+ ) -> None:
+ """Init column definition.
+
+ :param header: column's header
+        :param alias: column's alias, could be used as the column's name
+        :param fmt: format that will be applied to all of the column's values
+ :param only_for: list of the formats where this column should be
+ represented. May be used to differentiate data representation in
+ different formats
+ """
+ self.header = header
+ self.alias = alias
+ self.fmt = fmt
+ self.only_for = only_for
+
+ def supports_format(self, fmt: str) -> bool:
+ """Return true if column should be shown."""
+ return not self.only_for or fmt in self.only_for
+
+
+class NestedReport(Report):
+ """Report with nested items."""
+
+ def __init__(self, name: str, alias: str, items: List[ReportItem]) -> None:
+ """Init nested report."""
+ self.name = name
+ self.alias = alias
+ self.items = items
+
+ def to_csv(self, **kwargs: Any) -> List[Any]:
+ """Convert to csv serializible format."""
+ result = {}
+
+ def collect_item_values(
+ item: ReportItem,
+ _parent: Optional[ReportItem],
+ _prev: Optional[ReportItem],
+ _level: int,
+ ) -> None:
+ """Collect item values into a dictionary.."""
+ if item.value is None:
+ return
+
+ if not isinstance(item.value, Cell):
+ result[item.alias] = item.raw_value
+ return
+
+ csv_value = item.value.to_csv()
+ if isinstance(csv_value, dict):
+ csv_value = {
+ f"{item.alias}_{key}": value for key, value in csv_value.items()
+ }
+ else:
+ csv_value = {item.alias: csv_value}
+
+ result.update(csv_value)
+
+ self._traverse(self.items, collect_item_values)
+
+ # make list out of the result dictionary
+ # first element - keys of the dictionary as headers
+ # second element - list of the dictionary values
+ return list(zip(*result.items()))
+
+ def to_json(self, **kwargs: Any) -> Any:
+ """Convert to json serializible format."""
+ per_parent: Dict[Optional[ReportItem], Dict] = defaultdict(dict)
+ result = per_parent[None]
+
+ def collect_as_dicts(
+ item: ReportItem,
+ parent: Optional[ReportItem],
+ _prev: Optional[ReportItem],
+ _level: int,
+ ) -> None:
+ """Collect item values as nested dictionaries."""
+ parent_dict = per_parent[parent]
+
+ if item.compound:
+ item_dict = per_parent[item]
+ parent_dict[item.alias] = item_dict
+ else:
+                out_value = (
+                    item.value.to_json()
+                    if isinstance(item.value, Cell)
+                    else item.raw_value
+                )
+                parent_dict[item.alias] = out_value
+
+ self._traverse(self.items, collect_as_dicts)
+
+ return {self.alias: result}
+
+ def to_plain_text(self, **kwargs: Any) -> str:
+ """Convert to human readable format."""
+ header = f"{self.name}:\n"
+ processed_items = []
+
+ def convert_to_text(
+ item: ReportItem,
+ _parent: Optional[ReportItem],
+ prev: Optional[ReportItem],
+ level: int,
+ ) -> None:
+ """Convert item to text representation."""
+ if level >= 1 and prev is not None and (item.compound or prev.compound):
+ processed_items.append("")
+
+ val = self._item_value(item, level)
+ processed_items.append(val)
+
+ self._traverse(self.items, convert_to_text)
+ body = "\n".join(processed_items)
+
+ return header + body
+
+ @staticmethod
+ def _item_value(
+ item: ReportItem, level: int, tab_size: int = 2, column_width: int = 35
+ ) -> str:
+ """Get report item value."""
+ shift = " " * tab_size * level
+ if item.value is None:
+ return f"{shift}{item.name}:"
+
+ col1 = f"{shift}{item.name}".ljust(column_width)
+ col2 = f"{item.value}".rjust(column_width)
+
+ return col1 + col2
+
+ def _traverse(
+ self,
+ items: List[ReportItem],
+ visit_item: Callable[
+ [ReportItem, Optional[ReportItem], Optional[ReportItem], int], None
+ ],
+ level: int = 1,
+ parent: Optional[ReportItem] = None,
+ ) -> None:
+ """Traverse through items."""
+ prev = None
+ for item in items:
+ visit_item(item, parent, prev, level)
+
+ self._traverse(item.nested_items, visit_item, level + 1, item)
+ prev = item
+
+
+class Table(Report):
+ """Table definition.
+
+ This class could be used for representing tabular data.
+ """
+
+ def __init__(
+ self,
+ columns: List[Column],
+ rows: Collection,
+ name: str,
+ alias: Optional[str] = None,
+ notes: Optional[str] = None,
+ ) -> None:
+ """Init table definition.
+
+ :param columns: list of the table's columns
+ :param rows: list of the table's rows
+ :param name: name of the table
+ :param alias: alias for the table
+ """
+ self.columns = columns
+ self.rows = rows
+ self.name = name
+ self.alias = alias
+ self.notes = notes
+
+ def to_json(self, **kwargs: Any) -> Iterable:
+ """Convert table to dict object."""
+
+ def item_to_json(item: Any) -> Any:
+ value = item
+ if isinstance(item, Cell):
+ value = item.value
+
+ if isinstance(value, Table):
+ return value.to_json()
+
+ return value
+
+ json_data = [
+ {
+ col.alias or col.header: item_to_json(item)
+ for (item, col) in zip(row, self.columns)
+ if col.supports_format("json")
+ }
+ for row in self.rows
+ ]
+
+ if not self.alias:
+ return json_data
+
+ return {self.alias: json_data}
+
+ def to_plain_text(self, **kwargs: Any) -> str:
+ """Produce report in human readable format."""
+ nested = kwargs.get("nested", False)
+ show_headers = kwargs.get("show_headers", True)
+ show_title = kwargs.get("show_title", True)
+ table_style = kwargs.get("table_style", "default")
+ space = kwargs.get("space", False)
+
+ headers = (
+ [] if (nested or not show_headers) else [c.header for c in self.columns]
+ )
+
+ def item_to_plain_text(item: Any, col: Column) -> str:
+ """Convert item to text."""
+ if isinstance(item, Table):
+ return item.to_plain_text(nested=True, **kwargs)
+
+ if is_list_of(item, str):
+ as_text = "\n".join(item)
+ else:
+ as_text = str(item)
+
+ if col.fmt and col.fmt.wrap_width:
+ as_text = fill(as_text, col.fmt.wrap_width)
+
+ return as_text
+
+ title = ""
+ if show_title and not nested:
+ title = f"{self.name}:\n"
+
+ if space in (True, "top"):
+ title = "\n" + title
+
+ footer = ""
+ if space in (True, "bottom"):
+ footer = "\n"
+ if self.notes:
+ footer = "\n" + self.notes
+
+ formatted_rows = (
+ (
+ item_to_plain_text(item, col)
+ for item, col in zip(row, self.columns)
+ if col.supports_format("plain_text")
+ )
+ for row in self.rows
+ )
+
+ if space == "between":
+ formatted_table = "\n\n".join(
+ produce_table([row], table_style=table_style) for row in formatted_rows
+ )
+ else:
+ formatted_table = produce_table(
+ formatted_rows,
+ headers=headers,
+ table_style="nested" if nested else table_style,
+ )
+
+ return title + formatted_table + footer
+
+ def to_csv(self, **kwargs: Any) -> List[Any]:
+ """Convert table to csv format."""
+ headers = [[c.header for c in self.columns if c.supports_format("csv")]]
+
+ def item_data(item: Any) -> Any:
+ if isinstance(item, Cell):
+ return item.value
+
+ if isinstance(item, Table):
+ return ";".join(
+ str(item_data(cell)) for row in item.rows for cell in row
+ )
+
+ return item
+
+ rows = [
+ [
+ item_data(item)
+ for (item, col) in zip(row, self.columns)
+ if col.supports_format("csv")
+ ]
+ for row in self.rows
+ ]
+
+ return headers + rows
+
+
+class SingleRow(Table):
+ """Table with a single row."""
+
+ def to_plain_text(self, **kwargs: Any) -> str:
+ """Produce report in human readable format."""
+ if len(self.rows) != 1:
+ raise Exception("Table should have only one row")
+
+ items = "\n".join(
+ column.header.ljust(35) + str(item).rjust(25)
+ for row in self.rows
+ for item, column in zip(row, self.columns)
+ if column.supports_format("plain_text")
+ )
+
+ return "\n".join([f"{self.name}:", indent(items, " ")])
+
+
+class CompoundReport(Report):
+ """Compound report.
+
+ This class could be used for producing multiple reports at once.
+ """
+
+ def __init__(self, reports: List[Report]) -> None:
+ """Init compound report instance."""
+ self.reports = reports
+
+ def to_json(self, **kwargs: Any) -> Any:
+ """Convert to json serializible format.
+
+ Method attempts to create compound dictionary based on provided
+ parts.
+ """
+ result: Dict[str, Any] = {}
+ for item in self.reports:
+ result.update(item.to_json(**kwargs))
+
+ return result
+
+ def to_csv(self, **kwargs: Any) -> List[Any]:
+ """Convert to csv serializible format.
+
+ CSV format does support only one table. In order to be able to export
+ multiply tables they should be merged before that. This method tries to
+ do next:
+
+ - if all tables have the same length then just concatenate them
+ - if one table has many rows and other just one (two with headers), then
+ for each row in table with many rows duplicate values from other tables
+ """
+ csv_data = [item.to_csv() for item in self.reports]
+ lengths = [len(csv_item_data) for csv_item_data in csv_data]
+
+ same_length = len(set(lengths)) == 1
+ if same_length:
+ # all lists are of the same length, merge them into one
+ return [[cell for item in row for cell in item] for row in zip(*csv_data)]
+
+ main_obj_indexes = [i for i, item in enumerate(csv_data) if len(item) > 2]
+ one_main_obj = len(main_obj_indexes) == 1
+
+ reference_obj_indexes = [i for i, item in enumerate(csv_data) if len(item) == 2]
+ other_only_ref_objs = len(reference_obj_indexes) == len(csv_data) - 1
+
+ if one_main_obj and other_only_ref_objs:
+ main_obj = csv_data[main_obj_indexes[0]]
+ return [
+ item
+ + [
+ ref_item
+ for ref_table_index in reference_obj_indexes
+ for ref_item in csv_data[ref_table_index][0 if i == 0 else 1]
+ ]
+ for i, item in enumerate(main_obj)
+ ]
+
+ # write tables one after another if there is no other options
+ return [row for item in csv_data for row in item]
+
+ def to_plain_text(self, **kwargs: Any) -> str:
+ """Convert to human readable format."""
+ return "\n".join(item.to_plain_text(**kwargs) for item in self.reports)
+
+
+class CompoundFormatter:
+ """Compound data formatter."""
+
+ def __init__(self, formatters: List[Callable]) -> None:
+ """Init compound formatter."""
+ self.formatters = formatters
+
+ def __call__(self, data: Any) -> Report:
+ """Produce report."""
+ reports = [formatter(item) for item, formatter in zip(data, self.formatters)]
+ return CompoundReport(reports)
+
+
+class CustomJSONEncoder(json.JSONEncoder):
+ """Custom JSON encoder."""
+
+ def default(self, o: Any) -> Any:
+ """Support numpy types."""
+ if isinstance(o, np.integer):
+ return int(o)
+
+ if isinstance(o, np.floating):
+ return float(o)
+
+ return json.JSONEncoder.default(self, o)
+
+
+def json_reporter(report: Report, output: FileLike, **kwargs: Any) -> None:
+ """Produce report in json format."""
+ json_str = json.dumps(report.to_json(**kwargs), indent=4, cls=CustomJSONEncoder)
+ print(json_str, file=output)
+
+
+def text_reporter(report: Report, output: FileLike, **kwargs: Any) -> None:
+ """Produce report in text format."""
+ print(report.to_plain_text(**kwargs), file=output)
+
+
+def csv_reporter(report: Report, output: FileLike, **kwargs: Any) -> None:
+ """Produce report in csv format."""
+ csv_writer = csv.writer(output)
+ csv_writer.writerows(report.to_csv(**kwargs))
+
+
+def produce_report(
+ data: Any,
+ formatter: Callable[[Any], Report],
+ fmt: OutputFormat = "plain_text",
+ output: Optional[PathOrFileLike] = None,
+ **kwargs: Any,
+) -> None:
+ """Produce report based on provided data."""
+ # check if provided format value is supported
+ formats = {"json": json_reporter, "plain_text": text_reporter, "csv": csv_reporter}
+ if fmt not in formats:
+ raise Exception(f"Unknown format {fmt}")
+
+ if output is None:
+ output = cast(TextIOWrapper, LoggerWriter(logger, logging.INFO))
+
+ with ExitStack() as exit_stack:
+ if isinstance(output, (str, Path)):
+ # open file and add it to the ExitStack context manager
+ # in that case it will be automatically closed
+ stream = exit_stack.enter_context(open(output, "w", encoding="utf-8"))
+ else:
+ stream = cast(TextIOWrapper, output)
+
+ # convert data into serializable form
+ formatted_data = formatter(data)
+ # find handler for the format
+ format_handler = formats[fmt]
+ # produce report in requested format
+ format_handler(formatted_data, stream, **kwargs)
+
+
+class Reporter:
+ """Reporter class."""
+
+ def __init__(
+ self,
+ formatter_resolver: Callable[[Any], Callable[[Any], Report]],
+ output_format: OutputFormat = "plain_text",
+ print_as_submitted: bool = True,
+ ) -> None:
+ """Init reporter instance."""
+ self.formatter_resolver = formatter_resolver
+ self.output_format = output_format
+ self.print_as_submitted = print_as_submitted
+
+ self.data: List[Tuple[Any, Callable[[Any], Report]]] = []
+ self.delayed: List[Tuple[Any, Callable[[Any], Report]]] = []
+
+ def submit(self, data_item: Any, delay_print: bool = False, **kwargs: Any) -> None:
+ """Submit data for the report."""
+ if self.print_as_submitted and not delay_print:
+ produce_report(
+ data_item,
+ self.formatter_resolver(data_item),
+ fmt="plain_text",
+ **kwargs,
+ )
+
+ formatter = _apply_format_parameters(
+ self.formatter_resolver(data_item), self.output_format, **kwargs
+ )
+ self.data.append((data_item, formatter))
+
+ if delay_print:
+ self.delayed.append((data_item, formatter))
+
+ def print_delayed(self) -> None:
+ """Print delayed reports."""
+ if not self.delayed:
+ return
+
+ data, formatters = zip(*self.delayed)
+ produce_report(
+ data,
+ formatter=CompoundFormatter(formatters),
+ fmt="plain_text",
+ )
+ self.delayed = []
+
+ def generate_report(self, output: Optional[PathOrFileLike]) -> None:
+ """Generate report."""
+ already_printed = (
+ self.print_as_submitted
+ and self.output_format == "plain_text"
+ and output is None
+ )
+ if not self.data or already_printed:
+ return
+
+ data, formatters = zip(*self.data)
+ produce_report(
+ data,
+ formatter=CompoundFormatter(formatters),
+ fmt=self.output_format,
+ output=output,
+ )
+
+
+@contextmanager
+def get_reporter(
+ output_format: OutputFormat,
+ output: Optional[PathOrFileLike],
+ formatter_resolver: Callable[[Any], Callable[[Any], Report]],
+) -> Generator[Reporter, None, None]:
+ """Get reporter and generate report."""
+ reporter = Reporter(formatter_resolver, output_format)
+
+ yield reporter
+
+ reporter.generate_report(output)
+
+
+def _apply_format_parameters(
+ formatter: Callable[[Any], Report], output_format: OutputFormat, **kwargs: Any
+) -> Callable[[Any], Report]:
+ """Wrap report method."""
+
+ def wrapper(data: Any) -> Report:
+ report = formatter(data)
+ method_name = f"to_{output_format}"
+ method = getattr(report, method_name)
+ setattr(report, method_name, partial(method, **kwargs))
+
+ return report
+
+ return wrapper
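A short sketch of the reporting primitives: build a Table out of Column and Cell objects and render it through produce_report. The table contents are invented.

from mlia.core.reporting import BytesCell
from mlia.core.reporting import Column
from mlia.core.reporting import Table
from mlia.core.reporting import produce_report


def format_memory(rows) -> Table:
    """Format rows of (area, used) pairs as a table."""
    return Table(
        columns=[Column("Memory area", alias="area"), Column("Used", alias="used")],
        rows=rows,
        name="Memory usage",
        alias="memory_usage",
    )


data = [("SRAM", BytesCell(512)), ("DRAM", BytesCell(2048))]
produce_report(data, format_memory, fmt="plain_text")            # log/stdout
produce_report(data, format_memory, fmt="json", output="memory.json")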
diff --git a/src/mlia/core/workflow.py b/src/mlia/core/workflow.py
new file mode 100644
index 0000000..0245087
--- /dev/null
+++ b/src/mlia/core/workflow.py
@@ -0,0 +1,216 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for executors.
+
+This module contains implementation of the workflow
+executors.
+"""
+import itertools
+from abc import ABC
+from abc import abstractmethod
+from functools import wraps
+from typing import Any
+from typing import Callable
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+
+from mlia.core.advice_generation import Advice
+from mlia.core.advice_generation import AdviceEvent
+from mlia.core.advice_generation import AdviceProducer
+from mlia.core.common import DataItem
+from mlia.core.context import Context
+from mlia.core.data_analysis import DataAnalyzer
+from mlia.core.data_collection import DataCollector
+from mlia.core.errors import FunctionalityNotSupportedError
+from mlia.core.events import AdviceStageFinishedEvent
+from mlia.core.events import AdviceStageStartedEvent
+from mlia.core.events import AnalyzedDataEvent
+from mlia.core.events import CollectedDataEvent
+from mlia.core.events import DataAnalysisStageFinishedEvent
+from mlia.core.events import DataAnalysisStageStartedEvent
+from mlia.core.events import DataCollectionStageFinishedEvent
+from mlia.core.events import DataCollectionStageStartedEvent
+from mlia.core.events import DataCollectorSkippedEvent
+from mlia.core.events import Event
+from mlia.core.events import ExecutionFailedEvent
+from mlia.core.events import ExecutionFinishedEvent
+from mlia.core.events import ExecutionStartedEvent
+from mlia.core.events import stage
+from mlia.core.mixins import ContextMixin
+
+
+class WorkflowExecutor(ABC):
+ """Base workflow executor."""
+
+ @abstractmethod
+ def run(self) -> None:
+ """Run the module."""
+
+
+STAGE_COLLECTION = (
+ DataCollectionStageStartedEvent(),
+ DataCollectionStageFinishedEvent(),
+)
+STAGE_ANALYSIS = (DataAnalysisStageStartedEvent(), DataAnalysisStageFinishedEvent())
+STAGE_ADVICE = (AdviceStageStartedEvent(), AdviceStageFinishedEvent())
+
+
+def on_stage(stage_events: Tuple[Event, Event]) -> Callable:
+ """Mark start/finish of the stage with appropriate events."""
+
+ def wrapper(method: Callable) -> Callable:
+ """Wrap method."""
+
+ @wraps(method)
+ def publish_events(self: Any, *args: Any, **kwargs: Any) -> Any:
+ """Publish events before and after execution."""
+ with stage(self.context.event_publisher, stage_events):
+ return method(self, *args, **kwargs)
+
+ return publish_events
+
+ return wrapper
+
+
+class DefaultWorkflowExecutor(WorkflowExecutor):
+ """Default module executor.
+
+ This is a default implementation of the workflow executor.
+ All components are launched sequentually in the same process.
+ """
+
+ def __init__(
+ self,
+ context: Context,
+ collectors: Sequence[DataCollector],
+ analyzers: Sequence[DataAnalyzer],
+ producers: Sequence[AdviceProducer],
+ before_start_events: Optional[Sequence[Event]] = None,
+ ):
+ """Init default workflow executor.
+
+ :param context: Context instance
+ :param collectors: List of the data collectors
+ :param analyzers: List of the data analyzers
+ :param producers: List of the advice producers
+ :param before_start_events: Optional list of the custom events that
+            should be published before the start of the workflow execution.
+ """
+ self.context = context
+ self.collectors = collectors
+ self.analyzers = analyzers
+ self.producers = producers
+ self.before_start_events = before_start_events
+
+ def run(self) -> None:
+ """Run the workflow."""
+ self.inject_context()
+ self.context.register_event_handlers()
+
+ try:
+ self.publish(ExecutionStartedEvent())
+
+ self.before_start()
+
+ collected_data = self.collect_data()
+ analyzed_data = self.analyze_data(collected_data)
+
+ self.produce_advice(analyzed_data)
+ except Exception as err: # pylint: disable=broad-except
+ self.publish(ExecutionFailedEvent(err))
+ else:
+ self.publish(ExecutionFinishedEvent())
+
+ def before_start(self) -> None:
+ """Run actions before start of the workflow execution."""
+ events = self.before_start_events or []
+ for event in events:
+ self.publish(event)
+
+ @on_stage(STAGE_COLLECTION)
+ def collect_data(self) -> List[DataItem]:
+ """Collect data.
+
+        Run each of the data collector components and return a list of
+ the collected data items.
+ """
+ collected_data = []
+ for collector in self.collectors:
+ try:
+ if (data_item := collector.collect_data()) is not None:
+ collected_data.append(data_item)
+ self.publish(CollectedDataEvent(data_item))
+ except FunctionalityNotSupportedError as err:
+ self.publish(DataCollectorSkippedEvent(collector.name(), str(err)))
+
+ return collected_data
+
+ @on_stage(STAGE_ANALYSIS)
+ def analyze_data(self, collected_data: List[DataItem]) -> List[DataItem]:
+ """Analyze data.
+
+ Pass each collected data item into each data analyzer and
+ return analyzed data.
+
+ :param collected_data: list of collected data items
+ """
+ analyzed_data = []
+ for analyzer in self.analyzers:
+ for item in collected_data:
+ analyzer.analyze_data(item)
+
+ for data_item in analyzer.get_analyzed_data():
+ analyzed_data.append(data_item)
+
+ self.publish(AnalyzedDataEvent(data_item))
+ return analyzed_data
+
+ @on_stage(STAGE_ADVICE)
+ def produce_advice(self, analyzed_data: List[DataItem]) -> None:
+ """Produce advice.
+
+ Pass each analyzed data item into each advice producer and
+ publish generated advice.
+
+ :param analyzed_data: list of analyzed data items
+ """
+ for producer in self.producers:
+ for data_item in analyzed_data:
+ producer.produce_advice(data_item)
+
+ advice = producer.get_advice()
+ if isinstance(advice, Advice):
+ advice = [advice]
+
+ for item in advice:
+ self.publish(AdviceEvent(item))
+
+ def inject_context(self) -> None:
+ """Inject context object into components.
+
+        Inject context object into components that support context
+ injection.
+ """
+ context_aware_components = (
+ comp
+ for comp in itertools.chain(
+ self.collectors,
+ self.analyzers,
+ self.producers,
+ )
+ if isinstance(comp, ContextMixin)
+ )
+
+ for component in context_aware_components:
+ component.set_context(self.context)
+
+ def publish(self, event: Event) -> None:
+ """Publish event.
+
+        Helper method for event publishing.
+
+ :param event: event instance
+ """
+ self.context.event_publisher.publish_event(event)
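Putting the pieces together, an advisor's configure() typically returns an executor wired like the sketch below; the commented component names refer to the hypothetical examples from the earlier sketches.

from mlia.core.context import ExecutionContext
from mlia.core.workflow import DefaultWorkflowExecutor

context = ExecutionContext(working_dir="mlia_output")

executor = DefaultWorkflowExecutor(
    context,
    collectors=[],  # e.g. ModelSizeCollector(Path("model.tflite"))
    analyzers=[],   # e.g. ModelSizeAnalyzer()
    producers=[],   # e.g. EthosUAdviceProducer()
)

# Publishes ExecutionStarted/Finished events and drives the
# collect -> analyze -> advise stages in order.
executor.run()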
diff --git a/src/mlia/devices/__init__.py b/src/mlia/devices/__init__.py
new file mode 100644
index 0000000..d533f4a
--- /dev/null
+++ b/src/mlia/devices/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Devices module."""
diff --git a/src/mlia/devices/config.py b/src/mlia/devices/config.py
new file mode 100644
index 0000000..7ab6b43
--- /dev/null
+++ b/src/mlia/devices/config.py
@@ -0,0 +1,11 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""IP configuration module."""
+
+
+class IPConfiguration: # pylint: disable=too-few-public-methods
+ """Base class for IP configuration."""
+
+ def __init__(self, target: str) -> None:
+ """Init IP configuration instance."""
+ self.target = target
diff --git a/src/mlia/devices/ethosu/__init__.py b/src/mlia/devices/ethosu/__init__.py
new file mode 100644
index 0000000..73925e1
--- /dev/null
+++ b/src/mlia/devices/ethosu/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Ethos-U devices module."""
diff --git a/src/mlia/devices/ethosu/advice_generation.py b/src/mlia/devices/ethosu/advice_generation.py
new file mode 100644
index 0000000..7a818c9
--- /dev/null
+++ b/src/mlia/devices/ethosu/advice_generation.py
@@ -0,0 +1,209 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Ethos-U advice generation."""
+from functools import singledispatchmethod
+from typing import List
+from typing import Union
+
+from mlia.core.advice_generation import Advice
+from mlia.core.advice_generation import advice_category
+from mlia.core.advice_generation import ContextAwareAdviceProducer
+from mlia.core.advice_generation import FactBasedAdviceProducer
+from mlia.core.common import AdviceCategory
+from mlia.core.common import DataItem
+from mlia.devices.ethosu.data_analysis import AllOperatorsSupportedOnNPU
+from mlia.devices.ethosu.data_analysis import HasCPUOnlyOperators
+from mlia.devices.ethosu.data_analysis import HasUnsupportedOnNPUOperators
+from mlia.devices.ethosu.data_analysis import OptimizationResults
+from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+
+
+class EthosUAdviceProducer(FactBasedAdviceProducer):
+ """Ethos-U advice producer."""
+
+ @singledispatchmethod
+ def produce_advice(self, data_item: DataItem) -> None:
+ """Produce advice."""
+
+ @produce_advice.register
+ @advice_category(AdviceCategory.OPERATORS, AdviceCategory.ALL)
+ def handle_cpu_only_ops(self, data_item: HasCPUOnlyOperators) -> None:
+ """Advice for CPU only operators."""
+ cpu_only_ops = ",".join(sorted(set(data_item.cpu_only_ops)))
+ cpu_only_ops_num = len(data_item.cpu_only_ops)
+
+ self.add_advice(
+ [
+ f"You have at least {cpu_only_ops_num} "
+ f"operator{'s' if cpu_only_ops_num > 1 else ''} that is CPU "
+ f"only: {cpu_only_ops}.",
+ "Using operators that are supported by the NPU will "
+ "improve performance.",
+ ]
+ + self.context.action_resolver.supported_operators_info()
+ )
+
+ @produce_advice.register
+ @advice_category(AdviceCategory.OPERATORS, AdviceCategory.ALL)
+ def handle_unsupported_operators(
+ self, data_item: HasUnsupportedOnNPUOperators
+ ) -> None:
+ """Advice for the unsupported operators."""
+ self.add_advice(
+ [
+ f"You have {data_item.npu_unsupported_ratio*100:.0f}% of operators "
+ "that cannot be placed on the NPU.",
+ "For better performance, please review the reasons reported "
+ "in the table, and adjust the model accordingly "
+ "where possible.",
+ ]
+ )
+
+ @produce_advice.register
+ @advice_category(AdviceCategory.OPERATORS, AdviceCategory.ALL)
+ def handle_all_operators_supported(
+ self, _data_item: AllOperatorsSupportedOnNPU
+ ) -> None:
+ """Advice if all operators supported."""
+ self.add_advice(
+ [
+ "You don't have any unsupported operators, your model will "
+ "run completely on NPU."
+ ]
+ + self.context.action_resolver.check_performance()
+ )
+
+ @produce_advice.register
+ @advice_category(AdviceCategory.OPTIMIZATION, AdviceCategory.ALL)
+ def handle_optimization_results(self, data_item: OptimizationResults) -> None:
+ """Advice based on optimization results."""
+ if not data_item.diffs or len(data_item.diffs) != 1:
+ return
+
+ optim_details = data_item.diffs[0]
+ metrics = [
+ (metric_name, optim_details.opt_diffs[metric_key])
+ for (metric_name, metric_key) in (
+ ("DRAM used (KB)", "dram"),
+ ("SRAM used (KB)", "sram"),
+ ("On chip flash used (KB)", "on_chip_flash"),
+ ("Off chip flash used (KB)", "off_chip_flash"),
+ ("NPU total cycles", "npu_total_cycles"),
+ )
+ if metric_key in optim_details.opt_diffs
+ and not optim_details.opt_diffs[metric_key].same
+ ]
+
+ improved = [
+ f"- You have achieved {abs(metric_value.diff):.2f}% performance "
+ f"improvement in {metric_name}"
+ for metric_name, metric_value in metrics
+ if metric_value.improved
+ ]
+
+ degraded = [
+ f"- {metric_name} have degraded by {abs(metric_value.diff):.2f}%"
+ for metric_name, metric_value in metrics
+ if metric_value.degraded
+ ]
+
+ opts = ", ".join(str(s) for s in optim_details.opt_type)
+ messages = [f"With the selected optimization ({opts})", *improved, *degraded]
+
+ if improved:
+ if next_optimization_target := self.get_next_optimization_targets(
+ optim_details.opt_type
+ ):
+ next_optimization_target_as_str = " and/or ".join(
+ str(item) for item in next_optimization_target
+ )
+
+ messages.append(
+ "You can try to push the optimization target higher "
+ f"(e.g. {next_optimization_target_as_str}) "
+ "to check if those results can be further improved."
+ )
+ messages += self.context.action_resolver.apply_optimizations(
+ opt_settings=next_optimization_target
+ )
+
+ elif degraded:
+ messages.append(
+ "The performance seems to have degraded after "
+ "applying the selected optimizations, "
+ "try exploring different optimization types/targets."
+ )
+
+ self.add_advice(messages)
+
+ self.add_advice(
+ [
+ "The applied tooling techniques have an impact "
+ "on accuracy. Additional hyperparameter tuning may be required "
+ "after any optimization."
+ ]
+ )
+
+ @staticmethod
+ def get_next_optimization_targets(
+ opt_type: List[OptimizationSettings],
+ ) -> List[OptimizationSettings]:
+ """Get next optimization targets."""
+ next_targets = (item.next_target() for item in opt_type)
+
+ # filter out targets that have not been changed
+ valid_targets = [
+ next_
+ for next_, old in zip(next_targets, opt_type)
+ if (
+ old.optimization_type == "pruning"
+ and old.optimization_target < next_.optimization_target
+ )
+ or (
+ old.optimization_type == "clustering"
+ and old.optimization_target > next_.optimization_target
+ )
+ ]
+ return valid_targets
+
+
+class EthosUStaticAdviceProducer(ContextAwareAdviceProducer):
+ """Advice producer that not depends on input data."""
+
+ def produce_advice(self, data_item: DataItem) -> None:
+ """Do not process passed data items."""
+
+ def get_advice(self) -> Union[Advice, List[Advice]]:
+ """Return predefined advice based on category."""
+ if self.context.advice_category is None:
+ return []
+
+ advice_per_category = {
+ AdviceCategory.PERFORMANCE: [
+ Advice(
+ [
+ "You can improve the inference time by using only operators "
+ "that are supported by the NPU.",
+ ]
+ + self.context.action_resolver.check_operator_compatibility()
+ ),
+ Advice(
+ [
+ "Check if you can improve the performance by applying "
+ "tooling techniques to your model."
+ ]
+ + self.context.action_resolver.apply_optimizations()
+ ),
+ ],
+ AdviceCategory.OPTIMIZATION: [
+ Advice(
+ [
+ "For better performance, make sure that all the operators "
+ "of your final TFLite model are supported by the NPU.",
+ ]
+ + self.context.action_resolver.operator_compatibility_details()
+ )
+ ],
+ }
+
+ return advice_per_category.get(self.context.advice_category, [])
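For reference, the filtering rule encoded in get_next_optimization_targets above can be shown on plain values. The sketch below is illustrative only and does not call OptimizationSettings.next_target(), whose step size is defined elsewhere in the code base; it assumes nothing beyond the pruning/clustering comparison visible in the filter:

    # A follow-up pruning target is only suggested when it is higher than the
    # current one; a follow-up clustering target only when it is lower
    # (i.e. fewer unique weight values per layer).
    candidates = [
        ("pruning", 0.5, 0.6),   # kept: 0.6 > 0.5
        ("pruning", 0.9, 0.9),   # dropped: target did not increase
        ("clustering", 32, 16),  # kept: 16 < 32
        ("clustering", 4, 4),    # dropped: target did not decrease
    ]

    valid_targets = [
        (opt_type, nxt)
        for opt_type, current, nxt in candidates
        if (opt_type == "pruning" and nxt > current)
        or (opt_type == "clustering" and nxt < current)
    ]
    print(valid_targets)  # [('pruning', 0.6), ('clustering', 16)]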
diff --git a/src/mlia/devices/ethosu/advisor.py b/src/mlia/devices/ethosu/advisor.py
new file mode 100644
index 0000000..802826b
--- /dev/null
+++ b/src/mlia/devices/ethosu/advisor.py
@@ -0,0 +1,151 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Ethos-U MLIA module."""
+from pathlib import Path
+from typing import List
+from typing import Optional
+
+from mlia.core.advice_generation import AdviceProducer
+from mlia.core.advisor import InferenceAdvisor
+from mlia.core.common import AdviceCategory
+from mlia.core.context import Context
+from mlia.core.data_analysis import DataAnalyzer
+from mlia.core.data_collection import DataCollector
+from mlia.core.mixins import ParameterResolverMixin
+from mlia.core.workflow import DefaultWorkflowExecutor
+from mlia.core.workflow import WorkflowExecutor
+from mlia.devices.ethosu.advice_generation import EthosUAdviceProducer
+from mlia.devices.ethosu.advice_generation import EthosUStaticAdviceProducer
+from mlia.devices.ethosu.config import EthosUConfiguration
+from mlia.devices.ethosu.config import get_target
+from mlia.devices.ethosu.data_analysis import EthosUDataAnalyzer
+from mlia.devices.ethosu.data_collection import EthosUOperatorCompatibility
+from mlia.devices.ethosu.data_collection import EthosUOptimizationPerformance
+from mlia.devices.ethosu.data_collection import EthosUPerformance
+from mlia.devices.ethosu.events import EthosUAdvisorStartedEvent
+
+
+class EthosUInferenceAdvisor(InferenceAdvisor, ParameterResolverMixin):
+ """Ethos-U Inference Advisor."""
+
+ @classmethod
+ def name(cls) -> str:
+ """Return name of the advisor."""
+ return "ethos_u_inference_advisor"
+
+ def configure(self, context: Context) -> WorkflowExecutor:
+ """Configure advisor execution."""
+ model = self._get_model(context)
+ device = self._get_device(context)
+ backends = self._get_backends(context)
+
+ collectors = self._get_collectors(context, model, device, backends)
+ analyzers = self._get_analyzers()
+ producers = self._get_advice_producers()
+
+ return DefaultWorkflowExecutor(
+ context,
+ collectors,
+ analyzers,
+ producers,
+ before_start_events=[
+ EthosUAdvisorStartedEvent(device=device, model=model),
+ ],
+ )
+
+ def _get_collectors(
+ self,
+ context: Context,
+ model: Path,
+ device: EthosUConfiguration,
+ backends: Optional[List[str]],
+ ) -> List[DataCollector]:
+ """Get collectors."""
+ collectors: List[DataCollector] = []
+
+ if context.any_category_enabled(
+ AdviceCategory.OPERATORS,
+ AdviceCategory.ALL,
+ ):
+ collectors.append(EthosUOperatorCompatibility(model, device))
+
+ if context.category_enabled(AdviceCategory.PERFORMANCE):
+ collectors.append(EthosUPerformance(model, device, backends))
+
+ if context.any_category_enabled(
+ AdviceCategory.OPTIMIZATION,
+ AdviceCategory.ALL,
+ ):
+ optimization_settings = self._get_optimization_settings(context)
+ collectors.append(
+ EthosUOptimizationPerformance(
+ model, device, optimization_settings, backends
+ )
+ )
+
+ return collectors
+
+ @staticmethod
+ def _get_analyzers() -> List[DataAnalyzer]:
+ """Return data analyzers."""
+ return [
+ EthosUDataAnalyzer(),
+ ]
+
+ @staticmethod
+ def _get_advice_producers() -> List[AdviceProducer]:
+ """Return advice producers."""
+ return [
+ EthosUAdviceProducer(),
+ EthosUStaticAdviceProducer(),
+ ]
+
+ def _get_device(self, context: Context) -> EthosUConfiguration:
+ """Get device."""
+ device_params = self.get_parameter(
+ self.name(),
+ "device",
+ expected_type=dict,
+ context=context,
+ )
+
+ try:
+ target_profile = device_params["target_profile"]
+ except KeyError as err:
+ raise Exception("Unable to get device details") from err
+
+ return get_target(target_profile)
+
+ def _get_model(self, context: Context) -> Path:
+ """Get path to the model."""
+ model_param = self.get_parameter(
+ self.name(),
+ "model",
+ expected_type=str,
+ context=context,
+ )
+
+ if not (model := Path(model_param)).exists():
+ raise Exception(f"Path {model} does not exist")
+
+ return model
+
+ def _get_optimization_settings(self, context: Context) -> List[List[dict]]:
+ """Get optimization settings."""
+ return self.get_parameter( # type: ignore
+ EthosUOptimizationPerformance.name(),
+ "optimizations",
+ expected_type=list,
+ expected=False,
+ context=context,
+ )
+
+ def _get_backends(self, context: Context) -> Optional[List[str]]:
+ """Get list of backends."""
+ return self.get_parameter( # type: ignore
+ self.name(),
+ "backends",
+ expected_type=list,
+ expected=False,
+ context=context,
+ )
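The parameter keys that configure() resolves via ParameterResolverMixin.get_parameter are all visible above; the sketch below lays them out as a plain dictionary for orientation. The nesting is illustrative only: how the Context actually stores these values is defined in mlia.core.context (not part of this file), and the profile and model names are hypothetical.

    # Parameters read by EthosUInferenceAdvisor.configure() and its helpers.
    advisor_parameters = {
        "ethos_u_inference_advisor": {
            "device": {"target_profile": "ethos-u55-256"},  # hypothetical profile
            "model": "model.h5",                            # hypothetical path
            "backends": None,                               # optional backend list
        },
        "ethos_u_model_optimizations": {
            # Optional; one inner list per optimization combination to evaluate.
            "optimizations": [
                [{"optimization_type": "pruning", "optimization_target": 0.5}]
            ],
        },
    }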
diff --git a/src/mlia/devices/ethosu/config.py b/src/mlia/devices/ethosu/config.py
new file mode 100644
index 0000000..cecbb27
--- /dev/null
+++ b/src/mlia/devices/ethosu/config.py
@@ -0,0 +1,89 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Ethos-U configuration."""
+import logging
+from typing import Any
+from typing import Dict
+
+from mlia.devices.config import IPConfiguration
+from mlia.tools.vela_wrapper import resolve_compiler_config
+from mlia.tools.vela_wrapper import VelaCompilerOptions
+from mlia.utils.filesystem import get_profile
+from mlia.utils.filesystem import get_vela_config
+
+
+logger = logging.getLogger(__name__)
+
+
+class EthosUConfiguration(IPConfiguration):
+ """Ethos-U configuration."""
+
+ def __init__(self, target_profile: str) -> None:
+ """Init Ethos-U target configuration."""
+ target_data = get_profile(target_profile)
+ _check_target_data_complete(target_data)
+
+ target = target_data["target"]
+ super().__init__(target)
+
+ mac = target_data["mac"]
+ _check_device_options_valid(target, mac)
+
+ self.mac = mac
+ self.compiler_options = VelaCompilerOptions(
+ system_config=target_data["system_config"],
+ memory_mode=target_data["memory_mode"],
+ config_files=str(get_vela_config()),
+ accelerator_config=f"{self.target}-{mac}", # type: ignore
+ )
+
+ @property
+ def resolved_compiler_config(self) -> Dict[str, Any]:
+ """Resolve compiler configuration."""
+ return resolve_compiler_config(self.compiler_options)
+
+ def __str__(self) -> str:
+ """Return string representation."""
+ return (
+ f"Ethos-U target={self.target} "
+ f"mac={self.mac} "
+ f"compiler_options={self.compiler_options}"
+ )
+
+ def __repr__(self) -> str:
+ """Return string representation."""
+ return f"<Ethos-U configuration target={self.target}>"
+
+
+def get_target(target_profile: str) -> EthosUConfiguration:
+ """Get target instance based on provided params."""
+ if not target_profile:
+ raise Exception("No target profile given")
+
+ return EthosUConfiguration(target_profile)
+
+
+def _check_target_data_complete(target_data: Dict[str, Any]) -> None:
+ """Check if profile contains all needed data."""
+ mandatory_keys = {"target", "mac", "system_config", "memory_mode"}
+ missing_keys = sorted(mandatory_keys - target_data.keys())
+
+ if missing_keys:
+ raise Exception(f"Mandatory fields missing from target profile: {missing_keys}")
+
+
+def _check_device_options_valid(target: str, mac: int) -> None:
+ """Check if mac is valid for selected device."""
+ target_mac_ranges = {
+ "ethos-u55": [32, 64, 128, 256],
+ "ethos-u65": [256, 512],
+ }
+
+ if target not in target_mac_ranges:
+ raise Exception(f"Unsupported target: {target}")
+
+ target_mac_range = target_mac_ranges[target]
+ if mac not in target_mac_range:
+ raise Exception(
+ f"Mac value for selected device should be in {target_mac_range}"
+ )
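A quick sketch of the MAC validation above (assumes the mlia package from this change is importable; calling the private helper directly is for illustration only):

    from mlia.devices.ethosu.config import _check_device_options_valid

    _check_device_options_valid("ethos-u55", 256)  # OK: 256 is in [32, 64, 128, 256]
    _check_device_options_valid("ethos-u65", 512)  # OK: 512 is in [256, 512]

    try:
        _check_device_options_valid("ethos-u65", 128)  # 128 is not valid for U65
    except Exception as err:
        print(err)  # Mac value for selected device should be in [256, 512]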
diff --git a/src/mlia/devices/ethosu/data_analysis.py b/src/mlia/devices/ethosu/data_analysis.py
new file mode 100644
index 0000000..9ed32ff
--- /dev/null
+++ b/src/mlia/devices/ethosu/data_analysis.py
@@ -0,0 +1,154 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Ethos-U data analysis module."""
+from dataclasses import dataclass
+from functools import singledispatchmethod
+from typing import Dict
+from typing import List
+from typing import Union
+
+from mlia.core.common import DataItem
+from mlia.core.data_analysis import Fact
+from mlia.core.data_analysis import FactExtractor
+from mlia.devices.ethosu.performance import OptimizationPerformanceMetrics
+from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+from mlia.tools.vela_wrapper import Operators
+
+
+@dataclass
+class HasCPUOnlyOperators(Fact):
+ """Model has CPU only operators."""
+
+ cpu_only_ops: List[str]
+
+
+@dataclass
+class HasUnsupportedOnNPUOperators(Fact):
+ """Model has unsupported on NPU operators."""
+
+ npu_unsupported_ratio: float
+
+
+@dataclass
+class AllOperatorsSupportedOnNPU(Fact):
+ """All model's operators supported on NPU."""
+
+
+@dataclass
+class PerfMetricDiff:
+ """Performance metric difference."""
+
+ original_value: Union[int, float]
+ optimized_value: Union[int, float]
+
+ @property
+ def diff(self) -> float:
+ """Difference between metrics."""
+ if self.original_value == 0:
+ return 0
+
+ return 100 - ((self.optimized_value / self.original_value) * 100)
+
+ @property
+ def improved(self) -> bool:
+ """Return true if metric improved."""
+ return self.diff > 0
+
+ @property
+ def degraded(self) -> bool:
+ """Return true if metric degraded."""
+ return self.diff < 0
+
+ @property
+ def same(self) -> bool:
+ """Return true if metric stays the same."""
+ return self.diff == 0
+
+
+@dataclass
+class OptimizationDiff:
+ """Optimization performance impact."""
+
+ opt_type: List[OptimizationSettings]
+ opt_diffs: Dict[str, PerfMetricDiff]
+
+
+@dataclass
+class OptimizationResults(Fact):
+ """Optimization results."""
+
+ diffs: List[OptimizationDiff]
+
+
+class EthosUDataAnalyzer(FactExtractor):
+ """Ethos-U data analyzer."""
+
+ @singledispatchmethod
+ def analyze_data(self, data_item: DataItem) -> None:
+ """Analyse the data."""
+
+ @analyze_data.register
+ def analyze_operator_compatibility(self, operators: Operators) -> None:
+ """Analyse operator compatibility information."""
+ cpu_only = [op.op_type for op in operators.ops if op.cpu_only]
+ if cpu_only:
+ self.add_fact(HasCPUOnlyOperators(cpu_only))
+
+ if operators.npu_unsupported_ratio != 0:
+ self.add_fact(HasUnsupportedOnNPUOperators(operators.npu_unsupported_ratio))
+
+ if operators.npu_unsupported_ratio == 0:
+ self.add_fact(AllOperatorsSupportedOnNPU())
+
+ @analyze_data.register
+ def analyze_optimization_results(
+ self, optimization_results: OptimizationPerformanceMetrics
+ ) -> None:
+ """Analyse optimization performance metrics."""
+ optimizations = optimization_results.optimizations_perf_metrics
+ if not optimizations:
+ return
+
+ orig = optimization_results.original_perf_metrics.in_kilobytes()
+ orig_memory = orig.memory_usage
+ orig_cycles = orig.npu_cycles
+
+ diffs: List[OptimizationDiff] = []
+ for opt_type, opt_perf_metrics in optimizations:
+ opt = opt_perf_metrics.in_kilobytes()
+ opt_memory = opt.memory_usage
+ opt_cycles = opt.npu_cycles
+
+ opt_diffs: Dict[str, PerfMetricDiff] = {}
+
+ if orig_memory and opt_memory:
+ opt_diffs.update(
+ {
+ "sram": PerfMetricDiff(
+ orig_memory.sram_memory_area_size,
+ opt_memory.sram_memory_area_size,
+ ),
+ "dram": PerfMetricDiff(
+ orig_memory.dram_memory_area_size,
+ opt_memory.dram_memory_area_size,
+ ),
+ "on_chip_flash": PerfMetricDiff(
+ orig_memory.on_chip_flash_memory_area_size,
+ opt_memory.on_chip_flash_memory_area_size,
+ ),
+ "off_chip_flash": PerfMetricDiff(
+ orig_memory.off_chip_flash_memory_area_size,
+ opt_memory.off_chip_flash_memory_area_size,
+ ),
+ }
+ )
+ if orig_cycles and opt_cycles:
+ opt_diffs["npu_total_cycles"] = PerfMetricDiff(
+ orig_cycles.npu_total_cycles,
+ opt_cycles.npu_total_cycles,
+ )
+
+ diff = OptimizationDiff(opt_type=opt_type, opt_diffs=opt_diffs)
+ diffs.append(diff)
+
+ self.add_fact(OptimizationResults(diffs))
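The diff property above is a percentage relative to the original value (100 - optimized/original * 100), so positive values mean improvement. A minimal, self-contained illustration (assumes the package is importable):

    from mlia.devices.ethosu.data_analysis import PerfMetricDiff

    better = PerfMetricDiff(original_value=1000, optimized_value=750)
    print(better.diff, better.improved)  # 25.0 True  -> 25% smaller/faster

    worse = PerfMetricDiff(original_value=1000, optimized_value=1200)
    print(worse.diff, worse.degraded)    # -20.0 True -> 20% larger/slower

    zero = PerfMetricDiff(original_value=0, optimized_value=0)
    print(zero.same)                     # True: zero originals are reported as "same"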
diff --git a/src/mlia/devices/ethosu/data_collection.py b/src/mlia/devices/ethosu/data_collection.py
new file mode 100644
index 0000000..291f1b8
--- /dev/null
+++ b/src/mlia/devices/ethosu/data_collection.py
@@ -0,0 +1,188 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Data collection module for Ethos-U."""
+import logging
+from pathlib import Path
+from typing import List
+from typing import Optional
+
+from mlia.core.context import Context
+from mlia.core.data_collection import ContextAwareDataCollector
+from mlia.core.errors import FunctionalityNotSupportedError
+from mlia.core.performance import estimate_performance
+from mlia.devices.ethosu.config import EthosUConfiguration
+from mlia.devices.ethosu.performance import EthosUPerformanceEstimator
+from mlia.devices.ethosu.performance import OptimizationPerformanceMetrics
+from mlia.devices.ethosu.performance import PerformanceMetrics
+from mlia.nn.tensorflow.config import get_keras_model
+from mlia.nn.tensorflow.config import get_tflite_model
+from mlia.nn.tensorflow.config import KerasModel
+from mlia.nn.tensorflow.optimizations.select import get_optimizer
+from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+from mlia.nn.tensorflow.utils import save_keras_model
+from mlia.tools.vela_wrapper import Operators
+from mlia.tools.vela_wrapper import supported_operators
+from mlia.utils.types import is_list_of
+
+logger = logging.getLogger(__name__)
+
+
+class EthosUOperatorCompatibility(ContextAwareDataCollector):
+ """Collect operator compatibility information."""
+
+ def __init__(self, model: Path, device: EthosUConfiguration) -> None:
+ """Init operator compatibility data collector."""
+ self.model = model
+ self.device = device
+
+ def collect_data(self) -> Operators:
+ """Collect operator compatibility information."""
+ tflite_model = get_tflite_model(self.model, self.context)
+
+ logger.info("Checking operator compatibility ...")
+ ops = supported_operators(
+ Path(tflite_model.model_path), self.device.compiler_options
+ )
+ logger.info("Done\n")
+ return ops
+
+ @classmethod
+ def name(cls) -> str:
+ """Return name of the collector."""
+ return "ethos_u_operator_compatibility"
+
+
+class EthosUPerformance(ContextAwareDataCollector):
+ """Collect performance metrics."""
+
+ def __init__(
+ self,
+ model: Path,
+ device: EthosUConfiguration,
+ backends: Optional[List[str]] = None,
+ ) -> None:
+ """Init performance data collector."""
+ self.model = model
+ self.device = device
+ self.backends = backends
+
+ def collect_data(self) -> PerformanceMetrics:
+ """Collect model performance metrics."""
+ tflite_model = get_tflite_model(self.model, self.context)
+ estimator = EthosUPerformanceEstimator(
+ self.context,
+ self.device,
+ self.backends,
+ )
+
+ return estimator.estimate(tflite_model)
+
+ @classmethod
+ def name(cls) -> str:
+ """Return name of the collector."""
+ return "ethos_u_performance"
+
+
+class OptimizeModel:
+ """Helper class for model optimization."""
+
+ def __init__(
+ self, context: Context, opt_settings: List[OptimizationSettings]
+ ) -> None:
+ """Init helper."""
+ self.context = context
+ self.opt_settings = opt_settings
+
+ def __call__(self, keras_model: KerasModel) -> KerasModel:
+ """Run optimization."""
+ optimizer = get_optimizer(keras_model, self.opt_settings)
+
+ opts_as_str = ", ".join(str(opt) for opt in self.opt_settings)
+ logger.info("Applying model optimizations - [%s]", opts_as_str)
+ optimizer.apply_optimization()
+
+ model = optimizer.get_model()
+ model_path = self.context.get_model_path("optimized_model.h5")
+ save_keras_model(model, model_path)
+
+ return KerasModel(model_path)
+
+
+class EthosUOptimizationPerformance(ContextAwareDataCollector):
+ """Collect performance metrics for the optimizations."""
+
+ def __init__(
+ self,
+ model: Path,
+ device: EthosUConfiguration,
+ optimizations: List[List[dict]],
+ backends: Optional[List[str]] = None,
+ ) -> None:
+ """Init performance optimizations data collector."""
+ self.model = model
+ self.device = device
+ self.optimizations = optimizations
+ self.backends = backends
+
+ def collect_data(self) -> Optional[OptimizationPerformanceMetrics]:
+ """Collect performance metrics for the optimizations."""
+ logger.info("Estimate performance ...")
+
+ if not self.optimizations:
+ raise FunctionalityNotSupportedError(
+ reason="Unable to estimate model optimizations impact",
+ description="No optimization targets provided",
+ )
+
+ opt_settings = self._parse_optimization_params(self.optimizations)
+
+ try:
+ keras_model = get_keras_model(self.model, self.context)
+ except NotImplementedError as err:
+ raise FunctionalityNotSupportedError(
+ reason="Unable to run model optimizations",
+ description=f"{self.model} is not a Keras model and "
+ "could not be converted to a Keras model",
+ ) from err
+
+ optimizers = [OptimizeModel(self.context, opts) for opts in opt_settings]
+
+ estimator = EthosUPerformanceEstimator(
+ self.context,
+ self.device,
+ self.backends,
+ )
+ original_metrics, *optimized_metrics = estimate_performance(
+ keras_model, estimator, optimizers # type: ignore
+ )
+
+ result = OptimizationPerformanceMetrics(
+ original_perf_metrics=original_metrics,
+ optimizations_perf_metrics=list(zip(opt_settings, optimized_metrics)),
+ )
+ return result
+
+ @staticmethod
+ def _parse_optimization_params(
+ optimizations: List[List[dict]],
+ ) -> List[List[OptimizationSettings]]:
+ """Parse optimization parameters."""
+ if not is_list_of(optimizations, list):
+ raise Exception("Optimization parameters expected to be a list")
+
+ return [
+ [
+ OptimizationSettings(
+ item.get("optimization_type"), # type: ignore
+ item.get("optimization_target"), # type: ignore
+ item.get("layers_to_optimized"),
+ )
+ for item in opt_configuration
+ ]
+ for opt_configuration in optimizations
+ ]
+
+ @classmethod
+ def name(cls) -> str:
+ """Return name of the collector."""
+ return "ethos_u_model_optimizations"
diff --git a/src/mlia/devices/ethosu/events.py b/src/mlia/devices/ethosu/events.py
new file mode 100644
index 0000000..d5408b0
--- /dev/null
+++ b/src/mlia/devices/ethosu/events.py
@@ -0,0 +1,24 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Ethos-U MLIA module events."""
+from dataclasses import dataclass
+from pathlib import Path
+
+from mlia.core.events import Event
+from mlia.core.events import EventDispatcher
+from mlia.devices.ethosu.config import EthosUConfiguration
+
+
+@dataclass
+class EthosUAdvisorStartedEvent(Event):
+ """Event with Ethos-U advisor parameters."""
+
+ model: Path
+ device: EthosUConfiguration
+
+
+class EthosUAdvisorEventHandler(EventDispatcher):
+ """Event handler for the Ethos-U inference advisor."""
+
+ def on_ethos_u_advisor_started(self, event: EthosUAdvisorStartedEvent) -> None:
+ """Handle EthosUAdvisorStarted event."""
diff --git a/src/mlia/devices/ethosu/handlers.py b/src/mlia/devices/ethosu/handlers.py
new file mode 100644
index 0000000..7a0c31c
--- /dev/null
+++ b/src/mlia/devices/ethosu/handlers.py
@@ -0,0 +1,146 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Event handler."""
+import logging
+from pathlib import Path
+from typing import Dict
+from typing import List
+from typing import Optional
+
+from mlia.core._typing import OutputFormat
+from mlia.core._typing import PathOrFileLike
+from mlia.core.advice_generation import Advice
+from mlia.core.advice_generation import AdviceEvent
+from mlia.core.events import AdviceStageFinishedEvent
+from mlia.core.events import AdviceStageStartedEvent
+from mlia.core.events import CollectedDataEvent
+from mlia.core.events import DataAnalysisStageFinishedEvent
+from mlia.core.events import DataCollectionStageStartedEvent
+from mlia.core.events import DataCollectorSkippedEvent
+from mlia.core.events import ExecutionFailedEvent
+from mlia.core.events import ExecutionStartedEvent
+from mlia.core.events import SystemEventsHandler
+from mlia.core.reporting import Reporter
+from mlia.devices.ethosu.events import EthosUAdvisorEventHandler
+from mlia.devices.ethosu.events import EthosUAdvisorStartedEvent
+from mlia.devices.ethosu.performance import OptimizationPerformanceMetrics
+from mlia.devices.ethosu.performance import PerformanceMetrics
+from mlia.devices.ethosu.reporters import find_appropriate_formatter
+from mlia.tools.vela_wrapper import Operators
+from mlia.utils.console import create_section_header
+
+logger = logging.getLogger(__name__)
+
+ADV_EXECUTION_STARTED = create_section_header("ML Inference Advisor started")
+MODEL_ANALYSIS_MSG = create_section_header("Model Analysis")
+MODEL_ANALYSIS_RESULTS_MSG = create_section_header("Model Analysis Results")
+ADV_GENERATION_MSG = create_section_header("Advice Generation")
+REPORT_GENERATION_MSG = create_section_header("Report Generation")
+
+
+class WorkflowEventsHandler(SystemEventsHandler):
+ """Event handler for the system events."""
+
+ def on_execution_started(self, event: ExecutionStartedEvent) -> None:
+ """Handle ExecutionStarted event."""
+ logger.info(ADV_EXECUTION_STARTED)
+
+ def on_execution_failed(self, event: ExecutionFailedEvent) -> None:
+ """Handle ExecutionFailed event."""
+ raise event.err
+
+ def on_data_collection_stage_started(
+ self, event: DataCollectionStageStartedEvent
+ ) -> None:
+ """Handle DataCollectionStageStarted event."""
+ logger.info(MODEL_ANALYSIS_MSG)
+
+ def on_advice_stage_started(self, event: AdviceStageStartedEvent) -> None:
+ """Handle AdviceStageStarted event."""
+ logger.info(ADV_GENERATION_MSG)
+
+ def on_data_collector_skipped(self, event: DataCollectorSkippedEvent) -> None:
+ """Handle DataCollectorSkipped event."""
+ logger.info("Skipped: %s", event.reason)
+
+
+class EthosUEventHandler(WorkflowEventsHandler, EthosUAdvisorEventHandler):
+ """CLI event handler."""
+
+ def __init__(self, output: Optional[PathOrFileLike] = None) -> None:
+ """Init event handler."""
+ output_format = self.resolve_output_format(output)
+
+ self.reporter = Reporter(find_appropriate_formatter, output_format)
+ self.output = output
+ self.advice: List[Advice] = []
+
+ def on_advice_stage_finished(self, event: AdviceStageFinishedEvent) -> None:
+ """Handle AdviceStageFinishedEvent event."""
+ self.reporter.submit(
+ self.advice,
+ show_title=False,
+ show_headers=False,
+ space="between",
+ table_style="no_borders",
+ )
+
+ self.reporter.generate_report(self.output)
+
+ if self.output is not None:
+ logger.info(REPORT_GENERATION_MSG)
+ logger.info("Report(s) and advice list saved to: %s", self.output)
+
+ def on_data_analysis_stage_finished(
+ self, event: DataAnalysisStageFinishedEvent
+ ) -> None:
+ """Handle DataAnalysisStageFinished event."""
+ logger.info(MODEL_ANALYSIS_RESULTS_MSG)
+ self.reporter.print_delayed()
+
+ def on_collected_data(self, event: CollectedDataEvent) -> None:
+ """Handle CollectedDataEvent event."""
+ data_item = event.data_item
+
+ if isinstance(data_item, Operators):
+ self.reporter.submit([data_item.ops, data_item], delay_print=True)
+
+ if isinstance(data_item, PerformanceMetrics):
+ self.reporter.submit(data_item, delay_print=True)
+
+ if isinstance(data_item, OptimizationPerformanceMetrics):
+ original_metrics = data_item.original_perf_metrics
+ if not data_item.optimizations_perf_metrics:
+ return
+
+ _opt_settings, optimized_metrics = data_item.optimizations_perf_metrics[0]
+
+ self.reporter.submit(
+ [original_metrics, optimized_metrics],
+ delay_print=True,
+ columns_name="Metrics",
+ title="Performance metrics",
+ space=True,
+ )
+
+ def on_advice_event(self, event: AdviceEvent) -> None:
+ """Handle Advice event."""
+ self.advice.append(event.advice)
+
+ def on_ethos_u_advisor_started(self, event: EthosUAdvisorStartedEvent) -> None:
+ """Handle EthosUAdvisorStarted event."""
+ self.reporter.submit(event.device)
+
+ @staticmethod
+ def resolve_output_format(output: Optional[PathOrFileLike]) -> OutputFormat:
+ """Resolve output format based on the output name."""
+ output_format: OutputFormat = "plain_text"
+
+ if isinstance(output, str):
+ output_path = Path(output)
+ output_formats: Dict[str, OutputFormat] = {".csv": "csv", ".json": "json"}
+
+ if (suffix := output_path.suffix) in output_formats:
+ return output_formats[suffix]
+
+ return output_format
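resolve_output_format only inspects the file suffix of a string output name; anything else falls back to plain text. For example (assumes the package is importable):

    from mlia.devices.ethosu.handlers import EthosUEventHandler

    print(EthosUEventHandler.resolve_output_format("report.json"))  # json
    print(EthosUEventHandler.resolve_output_format("report.csv"))   # csv
    print(EthosUEventHandler.resolve_output_format("report.txt"))   # plain_text
    print(EthosUEventHandler.resolve_output_format(None))           # plain_text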
diff --git a/src/mlia/devices/ethosu/operators.py b/src/mlia/devices/ethosu/operators.py
new file mode 100644
index 0000000..ff0d99f
--- /dev/null
+++ b/src/mlia/devices/ethosu/operators.py
@@ -0,0 +1,14 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Operators module."""
+import logging
+
+from mlia.tools import vela_wrapper
+
+
+logger = logging.getLogger(__name__)
+
+
+def generate_supported_operators_report() -> None:
+ """Generate supported operators report."""
+ vela_wrapper.generate_supported_operators_report()
diff --git a/src/mlia/devices/ethosu/performance.py b/src/mlia/devices/ethosu/performance.py
new file mode 100644
index 0000000..b0718a5
--- /dev/null
+++ b/src/mlia/devices/ethosu/performance.py
@@ -0,0 +1,257 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Performance estimation."""
+import logging
+from dataclasses import dataclass
+from enum import Enum
+from pathlib import Path
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import mlia.tools.aiet_wrapper as aiet
+import mlia.tools.vela_wrapper as vela
+from mlia.core.context import Context
+from mlia.core.performance import PerformanceEstimator
+from mlia.devices.ethosu.config import EthosUConfiguration
+from mlia.nn.tensorflow.config import get_tflite_model
+from mlia.nn.tensorflow.config import ModelConfiguration
+from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class NPUCycles:
+ """NPU cycles metrics."""
+
+ npu_active_cycles: int
+ npu_idle_cycles: int
+ npu_total_cycles: int
+ npu_axi0_rd_data_beat_received: int
+ npu_axi0_wr_data_beat_written: int
+ npu_axi1_rd_data_beat_received: int
+
+
+BYTES_PER_KILOBYTE = 1024
+
+
+class MemorySizeType(Enum):
+ """Memory size type enumeration."""
+
+ BYTES = 0
+ KILOBYTES = 1
+
+
+@dataclass
+class MemoryUsage:
+ """Memory usage metrics."""
+
+ sram_memory_area_size: Union[int, float]
+ dram_memory_area_size: Union[int, float]
+ unknown_memory_area_size: Union[int, float]
+ on_chip_flash_memory_area_size: Union[int, float]
+ off_chip_flash_memory_area_size: Union[int, float]
+ memory_size_type: MemorySizeType = MemorySizeType.BYTES
+
+ _default_columns = [
+ "SRAM used",
+ "DRAM used",
+ "Unknown memory used",
+ "On chip flash used",
+ "Off chip flash used",
+ ]
+
+ def in_kilobytes(self) -> "MemoryUsage":
+ """Return memory usage with values in kilobytes."""
+ if self.memory_size_type == MemorySizeType.KILOBYTES:
+ return self
+
+ kilobytes = [
+ value / BYTES_PER_KILOBYTE
+ for value in [
+ self.sram_memory_area_size,
+ self.dram_memory_area_size,
+ self.unknown_memory_area_size,
+ self.on_chip_flash_memory_area_size,
+ self.off_chip_flash_memory_area_size,
+ ]
+ ]
+
+ return MemoryUsage(
+ *kilobytes, # type: ignore
+ memory_size_type=MemorySizeType.KILOBYTES,
+ )
+
+
+@dataclass
+class PerformanceMetrics:
+ """Performance metrics."""
+
+ device: EthosUConfiguration
+ npu_cycles: Optional[NPUCycles]
+ memory_usage: Optional[MemoryUsage]
+
+ def in_kilobytes(self) -> "PerformanceMetrics":
+ """Return metrics with memory usage in KiB."""
+ if self.memory_usage is None:
+ return PerformanceMetrics(self.device, self.npu_cycles, self.memory_usage)
+
+ return PerformanceMetrics(
+ self.device, self.npu_cycles, self.memory_usage.in_kilobytes()
+ )
+
+
+@dataclass
+class OptimizationPerformanceMetrics:
+ """Optimization performance metrics."""
+
+ original_perf_metrics: PerformanceMetrics
+ optimizations_perf_metrics: List[
+ Tuple[List[OptimizationSettings], PerformanceMetrics]
+ ]
+
+
+class VelaPerformanceEstimator(
+ PerformanceEstimator[Union[Path, ModelConfiguration], MemoryUsage]
+):
+ """Vela based performance estimator."""
+
+ def __init__(self, context: Context, device: EthosUConfiguration) -> None:
+ """Init Vela based performance estimator."""
+ self.context = context
+ self.device = device
+
+ def estimate(self, model: Union[Path, ModelConfiguration]) -> MemoryUsage:
+ """Estimate performance."""
+ logger.info("Getting the memory usage metrics ...")
+
+ model_path = (
+ Path(model.model_path) if isinstance(model, ModelConfiguration) else model
+ )
+
+ vela_perf_metrics = vela.estimate_performance(
+ model_path, self.device.compiler_options
+ )
+
+ memory_usage = MemoryUsage(
+ vela_perf_metrics.sram_memory_area_size,
+ vela_perf_metrics.dram_memory_area_size,
+ vela_perf_metrics.unknown_memory_area_size,
+ vela_perf_metrics.on_chip_flash_memory_area_size,
+ vela_perf_metrics.off_chip_flash_memory_area_size,
+ )
+ logger.info("Done\n")
+ return memory_usage
+
+
+class AIETPerformanceEstimator(
+ PerformanceEstimator[Union[Path, ModelConfiguration], NPUCycles]
+):
+ """AIET based performance estimator."""
+
+ def __init__(
+ self, context: Context, device: EthosUConfiguration, backend: str
+ ) -> None:
+ """Init AIET based performance estimator."""
+ self.context = context
+ self.device = device
+ self.backend = backend
+
+ def estimate(self, model: Union[Path, ModelConfiguration]) -> NPUCycles:
+ """Estimate performance."""
+ logger.info("Getting the performance metrics for '%s' ...", self.backend)
+ logger.info(
+ "WARNING: This task may require several minutes (press ctrl-c to interrupt)"
+ )
+
+ model_path = (
+ Path(model.model_path) if isinstance(model, ModelConfiguration) else model
+ )
+
+ optimized_model_path = self.context.get_model_path(
+ f"{model_path.stem}_vela.tflite"
+ )
+
+ vela.optimize_model(
+ model_path, self.device.compiler_options, optimized_model_path
+ )
+
+ model_info = aiet.ModelInfo(model_path=optimized_model_path)
+ device_info = aiet.DeviceInfo(
+ device_type=self.device.target, # type: ignore
+ mac=self.device.mac,
+ memory_mode=self.device.compiler_options.memory_mode, # type: ignore
+ )
+
+ aiet_perf_metrics = aiet.estimate_performance(
+ model_info, device_info, self.backend
+ )
+
+ npu_cycles = NPUCycles(
+ aiet_perf_metrics.npu_active_cycles,
+ aiet_perf_metrics.npu_idle_cycles,
+ aiet_perf_metrics.npu_total_cycles,
+ aiet_perf_metrics.npu_axi0_rd_data_beat_received,
+ aiet_perf_metrics.npu_axi0_wr_data_beat_written,
+ aiet_perf_metrics.npu_axi1_rd_data_beat_received,
+ )
+
+ logger.info("Done\n")
+ return npu_cycles
+
+
+class EthosUPerformanceEstimator(
+ PerformanceEstimator[Union[Path, ModelConfiguration], PerformanceMetrics]
+):
+ """Ethos-U performance estimator."""
+
+ def __init__(
+ self,
+ context: Context,
+ device: EthosUConfiguration,
+ backends: Optional[List[str]] = None,
+ ) -> None:
+ """Init performance estimator."""
+ self.context = context
+ self.device = device
+ if backends is None:
+ backends = ["Vela"] # Only Vela is always available as default
+ for backend in backends:
+ if backend != "Vela" and not aiet.is_supported(backend):
+ raise ValueError(
+ f"Unsupported backend '{backend}'. "
+ f"Only 'Vela' and {aiet.supported_backends()} are supported."
+ )
+ self.backends = set(backends)
+
+ def estimate(self, model: Union[Path, ModelConfiguration]) -> PerformanceMetrics:
+ """Estimate performance."""
+ model_path = (
+ Path(model.model_path) if isinstance(model, ModelConfiguration) else model
+ )
+
+ tflite_model = get_tflite_model(model_path, self.context)
+
+ memory_usage = None
+ npu_cycles = None
+
+ for backend in self.backends:
+ if backend == "Vela":
+ vela_estimator = VelaPerformanceEstimator(self.context, self.device)
+ memory_usage = vela_estimator.estimate(tflite_model)
+ elif backend in aiet.supported_backends():
+ aiet_estimator = AIETPerformanceEstimator(
+ self.context, self.device, backend
+ )
+ npu_cycles = aiet_estimator.estimate(tflite_model)
+ else:
+ logger.warning(
+ "Backend '%s' is not supported for Ethos-U performance "
+ "estimation.",
+ backend,
+ )
+
+ return PerformanceMetrics(self.device, npu_cycles, memory_usage)
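A small worked example of the byte-to-KiB conversion done by MemoryUsage.in_kilobytes (the values are made up; assumes the package is importable):

    from mlia.devices.ethosu.performance import MemorySizeType, MemoryUsage

    usage = MemoryUsage(
        sram_memory_area_size=393_216,            # bytes
        dram_memory_area_size=0,
        unknown_memory_area_size=0,
        on_chip_flash_memory_area_size=0,
        off_chip_flash_memory_area_size=524_288,  # bytes
    )

    in_kib = usage.in_kilobytes()
    print(in_kib.sram_memory_area_size)            # 384.0
    print(in_kib.off_chip_flash_memory_area_size)  # 512.0
    print(in_kib.memory_size_type is MemorySizeType.KILOBYTES)  # True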
diff --git a/src/mlia/devices/ethosu/reporters.py b/src/mlia/devices/ethosu/reporters.py
new file mode 100644
index 0000000..d28c68f
--- /dev/null
+++ b/src/mlia/devices/ethosu/reporters.py
@@ -0,0 +1,398 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Reports module."""
+from collections import defaultdict
+from typing import Any
+from typing import Callable
+from typing import List
+from typing import Tuple
+from typing import Union
+
+from mlia.core.advice_generation import Advice
+from mlia.core.reporting import BytesCell
+from mlia.core.reporting import Cell
+from mlia.core.reporting import ClockCell
+from mlia.core.reporting import Column
+from mlia.core.reporting import CompoundFormatter
+from mlia.core.reporting import CyclesCell
+from mlia.core.reporting import Format
+from mlia.core.reporting import NestedReport
+from mlia.core.reporting import Report
+from mlia.core.reporting import ReportItem
+from mlia.core.reporting import SingleRow
+from mlia.core.reporting import Table
+from mlia.devices.ethosu.config import EthosUConfiguration
+from mlia.devices.ethosu.performance import PerformanceMetrics
+from mlia.tools.vela_wrapper import Operator
+from mlia.tools.vela_wrapper import Operators
+from mlia.utils.console import style_improvement
+from mlia.utils.types import is_list_of
+
+
+def report_operators_stat(operators: Operators) -> Report:
+ """Return table representation for the ops stats."""
+ columns = [
+ Column("Number of operators", alias="num_of_operators"),
+ Column("Number of NPU supported operators", "num_of_npu_supported_operators"),
+ Column("Unsupported ops ratio", "npu_unsupported_ratio"),
+ ]
+ rows = [
+ (
+ operators.total_number,
+ operators.npu_supported_number,
+ Cell(
+ operators.npu_unsupported_ratio * 100,
+ fmt=Format(str_fmt="{0:.0f}%".format),
+ ),
+ )
+ ]
+
+ return SingleRow(
+ columns, rows, name="Operators statistics", alias="operators_stats"
+ )
+
+
+def report_operators(ops: List[Operator]) -> Report:
+ """Return table representation for the list of operators."""
+ columns = [
+ Column("#", only_for=["plain_text"]),
+ Column(
+ "Operator name",
+ alias="operator_name",
+ fmt=Format(wrap_width=30),
+ ),
+ Column(
+ "Operator type",
+ alias="operator_type",
+ fmt=Format(wrap_width=25),
+ ),
+ Column(
+ "Placement",
+ alias="placement",
+ fmt=Format(wrap_width=20),
+ ),
+ Column(
+ "Notes",
+ alias="notes",
+ fmt=Format(wrap_width=35),
+ ),
+ ]
+
+ rows = [
+ (
+ i + 1,
+ op.name,
+ op.op_type,
+ Cell(
+ "NPU" if (npu := op.run_on_npu.supported) else "CPU",
+ Format(style=style_improvement(npu)),
+ ),
+ Table(
+ columns=[
+ Column(
+ "Note",
+ alias="note",
+ fmt=Format(wrap_width=35),
+ )
+ ],
+ rows=[
+ (Cell(item, Format(str_fmt=lambda x: f"* {x}")),)
+ for reason in op.run_on_npu.reasons
+ for item in reason
+ if item
+ ],
+ name="Notes",
+ ),
+ )
+ for i, op in enumerate(ops)
+ ]
+
+ return Table(columns, rows, name="Operators", alias="operators")
+
+
+def report_device_details(device: EthosUConfiguration) -> Report:
+ """Return table representation for the device."""
+ compiler_config = device.resolved_compiler_config
+
+ memory_settings = [
+ ReportItem(
+ "Const mem area",
+ "const_mem_area",
+ compiler_config["const_mem_area"],
+ ),
+ ReportItem(
+ "Arena mem area",
+ "arena_mem_area",
+ compiler_config["arena_mem_area"],
+ ),
+ ReportItem(
+ "Cache mem area",
+ "cache_mem_area",
+ compiler_config["cache_mem_area"],
+ ),
+ ReportItem(
+ "Arena cache size",
+ "arena_cache_size",
+ BytesCell(compiler_config["arena_cache_size"]),
+ ),
+ ]
+
+ mem_areas_settings = [
+ ReportItem(
+ f"{mem_area_name}",
+ mem_area_name,
+ None,
+ nested_items=[
+ ReportItem(
+ "Clock scales",
+ "clock_scales",
+ mem_area_settings["clock_scales"],
+ ),
+ ReportItem(
+ "Burst length",
+ "burst_length",
+ BytesCell(mem_area_settings["burst_length"]),
+ ),
+ ReportItem(
+ "Read latency",
+ "read_latency",
+ CyclesCell(mem_area_settings["read_latency"]),
+ ),
+ ReportItem(
+ "Write latency",
+ "write_latency",
+ CyclesCell(mem_area_settings["write_latency"]),
+ ),
+ ],
+ )
+ for mem_area_name, mem_area_settings in compiler_config["memory_area"].items()
+ ]
+
+ system_settings = [
+ ReportItem(
+ "Accelerator clock",
+ "accelerator_clock",
+ ClockCell(compiler_config["core_clock"]),
+ ),
+ ReportItem(
+ "AXI0 port",
+ "axi0_port",
+ compiler_config["axi0_port"],
+ ),
+ ReportItem(
+ "AXI1 port",
+ "axi1_port",
+ compiler_config["axi1_port"],
+ ),
+ ReportItem(
+ "Memory area settings", "memory_area", None, nested_items=mem_areas_settings
+ ),
+ ]
+
+ arch_settings = [
+ ReportItem(
+ "Permanent storage mem area",
+ "permanent_storage_mem_area",
+ compiler_config["permanent_storage_mem_area"],
+ ),
+ ReportItem(
+ "Feature map storage mem area",
+ "feature_map_storage_mem_area",
+ compiler_config["feature_map_storage_mem_area"],
+ ),
+ ReportItem(
+ "Fast storage mem area",
+ "fast_storage_mem_area",
+ compiler_config["fast_storage_mem_area"],
+ ),
+ ]
+
+ return NestedReport(
+ "Device information",
+ "device",
+ [
+ ReportItem("Target", alias="target", value=device.target),
+ ReportItem("MAC", alias="mac", value=device.mac),
+ ReportItem(
+ "Memory mode",
+ alias="memory_mode",
+ value=compiler_config["memory_mode"],
+ nested_items=memory_settings,
+ ),
+ ReportItem(
+ "System config",
+ alias="system_config",
+ value=compiler_config["system_config"],
+ nested_items=system_settings,
+ ),
+ ReportItem(
+ "Architecture settings",
+ "arch_settings",
+ None,
+ nested_items=arch_settings,
+ ),
+ ],
+ )
+
+
+def metrics_as_records(perf_metrics: List[PerformanceMetrics]) -> List[Tuple]:
+ """Convert perf metrics object into list of records."""
+ perf_metrics = [item.in_kilobytes() for item in perf_metrics]
+
+ def _cycles_as_records(perf_metrics: List[PerformanceMetrics]) -> List[Tuple]:
+ metric_map = defaultdict(list)
+ for metrics in perf_metrics:
+ if not metrics.npu_cycles:
+ return []
+ metric_map["NPU active cycles"].append(metrics.npu_cycles.npu_active_cycles)
+ metric_map["NPU idle cycles"].append(metrics.npu_cycles.npu_idle_cycles)
+ metric_map["NPU total cycles"].append(metrics.npu_cycles.npu_total_cycles)
+
+ return [
+ (name, *(Cell(value, Format(str_fmt="12,d")) for value in values), "cycles")
+ for name, values in metric_map.items()
+ ]
+
+ def _memory_usage_as_records(perf_metrics: List[PerformanceMetrics]) -> List[Tuple]:
+ metric_map = defaultdict(list)
+ for metrics in perf_metrics:
+ if not metrics.memory_usage:
+ return []
+ metric_map["SRAM used"].append(metrics.memory_usage.sram_memory_area_size)
+ metric_map["DRAM used"].append(metrics.memory_usage.dram_memory_area_size)
+ metric_map["Unknown memory area used"].append(
+ metrics.memory_usage.unknown_memory_area_size
+ )
+ metric_map["On-chip flash used"].append(
+ metrics.memory_usage.on_chip_flash_memory_area_size
+ )
+ metric_map["Off-chip flash used"].append(
+ metrics.memory_usage.off_chip_flash_memory_area_size
+ )
+
+ return [
+ (name, *(Cell(value, Format(str_fmt="12.2f")) for value in values), "KiB")
+ for name, values in metric_map.items()
+ if all(val > 0 for val in values)
+ ]
+
+ def _data_beats_as_records(perf_metrics: List[PerformanceMetrics]) -> List[Tuple]:
+ metric_map = defaultdict(list)
+ for metrics in perf_metrics:
+ if not metrics.npu_cycles:
+ return []
+ metric_map["NPU AXI0 RD data beat received"].append(
+ metrics.npu_cycles.npu_axi0_rd_data_beat_received
+ )
+ metric_map["NPU AXI0 WR data beat written"].append(
+ metrics.npu_cycles.npu_axi0_wr_data_beat_written
+ )
+ metric_map["NPU AXI1 RD data beat received"].append(
+ metrics.npu_cycles.npu_axi1_rd_data_beat_received
+ )
+
+ return [
+ (name, *(Cell(value, Format(str_fmt="12,d")) for value in values), "beats")
+ for name, values in metric_map.items()
+ ]
+
+ return [
+ metrics
+ for metrics_func in (
+ _memory_usage_as_records,
+ _cycles_as_records,
+ _data_beats_as_records,
+ )
+ for metrics in metrics_func(perf_metrics)
+ ]
+
+
+def report_perf_metrics(
+ perf_metrics: Union[PerformanceMetrics, List[PerformanceMetrics]]
+) -> Report:
+ """Return comparison table for the performance metrics."""
+ if isinstance(perf_metrics, PerformanceMetrics):
+ perf_metrics = [perf_metrics]
+
+ rows = metrics_as_records(perf_metrics)
+
+ if len(perf_metrics) == 2:
+ return Table(
+ columns=[
+ Column("Metric", alias="metric", fmt=Format(wrap_width=30)),
+ Column("Original", alias="original", fmt=Format(wrap_width=15)),
+ Column("Optimized", alias="optimized", fmt=Format(wrap_width=15)),
+ Column("Unit", alias="unit", fmt=Format(wrap_width=15)),
+ Column("Improvement (%)", alias="improvement"),
+ ],
+ rows=[
+ (
+ metric,
+ original_value,
+ optimized_value,
+ unit,
+ Cell(
+ (
+ diff := 100
+ - (optimized_value.value / original_value.value * 100)
+ ),
+ Format(str_fmt="15.2f", style=style_improvement(diff > 0)),
+ )
+ if original_value.value != 0
+ else None,
+ )
+ for metric, original_value, optimized_value, unit in rows
+ ],
+ name="Performance metrics",
+ alias="performance_metrics",
+ notes="IMPORTANT: The performance figures above refer to NPU only",
+ )
+
+ return Table(
+ columns=[
+ Column("Metric", alias="metric", fmt=Format(wrap_width=30)),
+ Column("Value", alias="value", fmt=Format(wrap_width=15)),
+ Column("Unit", alias="unit", fmt=Format(wrap_width=15)),
+ ],
+ rows=rows,
+ name="Performance metrics",
+ alias="performance_metrics",
+ notes="IMPORTANT: The performance figures above refer to NPU only",
+ )
+
+
+def report_advice(advice: List[Advice]) -> Report:
+ """Generate report for the advice."""
+ return Table(
+ columns=[
+ Column("#", only_for=["plain_text"]),
+ Column("Advice", alias="advice_message"),
+ ],
+ rows=[(i + 1, a.messages) for i, a in enumerate(advice)],
+ name="Advice",
+ alias="advice",
+ )
+
+
+def find_appropriate_formatter(data: Any) -> Callable[[Any], Report]:
+ """Find appropriate formatter for the provided data."""
+ if isinstance(data, PerformanceMetrics) or is_list_of(data, PerformanceMetrics, 2):
+ return report_perf_metrics
+
+ if is_list_of(data, Advice):
+ return report_advice
+
+ if is_list_of(data, Operator):
+ return report_operators
+
+ if isinstance(data, Operators):
+ return report_operators_stat
+
+ if isinstance(data, EthosUConfiguration):
+ return report_device_details
+
+ if isinstance(data, (list, tuple)):
+ formatters = [find_appropriate_formatter(item) for item in data]
+ return CompoundFormatter(formatters)
+
+ raise Exception(f"Unable to find appropriate formatter for {data}")
diff --git a/src/mlia/nn/__init__.py b/src/mlia/nn/__init__.py
new file mode 100644
index 0000000..aac2830
--- /dev/null
+++ b/src/mlia/nn/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""NN related module."""
diff --git a/src/mlia/nn/tensorflow/__init__.py b/src/mlia/nn/tensorflow/__init__.py
new file mode 100644
index 0000000..ff061c1
--- /dev/null
+++ b/src/mlia/nn/tensorflow/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""TensorFlow related module."""
diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py
new file mode 100644
index 0000000..d3235d7
--- /dev/null
+++ b/src/mlia/nn/tensorflow/config.py
@@ -0,0 +1,134 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Model configuration."""
+import logging
+from pathlib import Path
+from typing import cast
+from typing import Dict
+from typing import List
+from typing import Union
+
+import tensorflow as tf
+
+from mlia.core.context import Context
+from mlia.nn.tensorflow.utils import convert_tf_to_tflite
+from mlia.nn.tensorflow.utils import convert_to_tflite
+from mlia.nn.tensorflow.utils import is_keras_model
+from mlia.nn.tensorflow.utils import is_tf_model
+from mlia.nn.tensorflow.utils import is_tflite_model
+from mlia.nn.tensorflow.utils import save_tflite_model
+
+logger = logging.getLogger(__name__)
+
+
+class ModelConfiguration:
+ """Base class for model configuration."""
+
+ def __init__(self, model_path: Union[str, Path]) -> None:
+ """Init model configuration instance."""
+ self.model_path = str(model_path)
+
+ def convert_to_tflite(
+ self, tflite_model_path: Union[str, Path], quantized: bool = False
+ ) -> "TFLiteModel":
+ """Convert model to TFLite format."""
+ raise NotImplementedError()
+
+ def convert_to_keras(self, keras_model_path: Union[str, Path]) -> "KerasModel":
+ """Convert model to Keras format."""
+ raise NotImplementedError()
+
+
+class KerasModel(ModelConfiguration):
+ """Keras model configuration.
+
+ Supports all models supported by Keras API: saved model, H5, HDF5
+ """
+
+ def get_keras_model(self) -> tf.keras.Model:
+ """Return associated Keras model."""
+ return tf.keras.models.load_model(self.model_path)
+
+ def convert_to_tflite(
+ self, tflite_model_path: Union[str, Path], quantized: bool = False
+ ) -> "TFLiteModel":
+ """Convert model to TFLite format."""
+ logger.info("Converting Keras to TFLite ...")
+
+ converted_model = convert_to_tflite(self.get_keras_model(), quantized)
+ logger.info("Done\n")
+
+ save_tflite_model(converted_model, tflite_model_path)
+ logger.debug(
+ "Model %s converted and saved to %s", self.model_path, tflite_model_path
+ )
+
+ return TFLiteModel(tflite_model_path)
+
+ def convert_to_keras(self, keras_model_path: Union[str, Path]) -> "KerasModel":
+ """Convert model to Keras format."""
+ return self
+
+
+class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method
+ """TFLite model configuration."""
+
+ def input_details(self) -> List[Dict]:
+ """Get model's input details."""
+ interpreter = tf.lite.Interpreter(model_path=self.model_path)
+ return cast(List[Dict], interpreter.get_input_details())
+
+ def convert_to_tflite(
+ self, tflite_model_path: Union[str, Path], quantized: bool = False
+ ) -> "TFLiteModel":
+ """Convert model to TFLite format."""
+ return self
+
+
+class TfModel(ModelConfiguration): # pylint: disable=abstract-method
+ """TensorFlow model configuration.
+
+ Supports models supported by TensorFlow API (not Keras)
+ """
+
+ def convert_to_tflite(
+ self, tflite_model_path: Union[str, Path], quantized: bool = False
+ ) -> "TFLiteModel":
+ """Convert model to TFLite format."""
+ converted_model = convert_tf_to_tflite(self.model_path, quantized)
+ save_tflite_model(converted_model, tflite_model_path)
+
+ return TFLiteModel(tflite_model_path)
+
+
+def get_model(model: Union[Path, str]) -> "ModelConfiguration":
+ """Return the model object."""
+ if is_tflite_model(model):
+ return TFLiteModel(model)
+
+ if is_keras_model(model):
+ return KerasModel(model)
+
+ if is_tf_model(model):
+ return TfModel(model)
+
+ raise Exception(
+ "The input model format is not supported"
+ "(supported formats: TFLite, Keras, TensorFlow saved model)!"
+ )
+
+
+def get_tflite_model(model: Union[str, Path], ctx: Context) -> "TFLiteModel":
+ """Convert input model to TFLite and returns TFLiteModel object."""
+ tflite_model_path = ctx.get_model_path("converted_model.tflite")
+ converted_model = get_model(model)
+
+ return converted_model.convert_to_tflite(tflite_model_path, True)
+
+
+def get_keras_model(model: Union[str, Path], ctx: Context) -> "KerasModel":
+ """Convert input model to Keras and returns KerasModel object."""
+ keras_model_path = ctx.get_model_path("converted_model.h5")
+ converted_model = get_model(model)
+
+ return converted_model.convert_to_keras(keras_model_path)
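get_model dispatches on the detected model format; the helpers is_tflite_model / is_keras_model / is_tf_model live in mlia.nn.tensorflow.utils and are not shown here, so the sketch below assumes the paths point at existing models of the corresponding formats:

    from mlia.nn.tensorflow.config import KerasModel, TFLiteModel, get_model

    keras_cfg = get_model("sample_model.h5")       # hypothetical Keras model path
    assert isinstance(keras_cfg, KerasModel)

    tflite_cfg = get_model("sample_model.tflite")  # hypothetical TFLite model path
    assert isinstance(tflite_cfg, TFLiteModel)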
diff --git a/src/mlia/nn/tensorflow/optimizations/__init__.py b/src/mlia/nn/tensorflow/optimizations/__init__.py
new file mode 100644
index 0000000..201c130
--- /dev/null
+++ b/src/mlia/nn/tensorflow/optimizations/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Optimizations module."""
diff --git a/src/mlia/nn/tensorflow/optimizations/clustering.py b/src/mlia/nn/tensorflow/optimizations/clustering.py
new file mode 100644
index 0000000..16d9e4b
--- /dev/null
+++ b/src/mlia/nn/tensorflow/optimizations/clustering.py
@@ -0,0 +1,109 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""
+Contains class Clusterer that clusters unique weights per layer to a specified number.
+
+In order to do this, we need to have a base model and corresponding training data.
+We also have to specify a subset of layers we want to cluster. For more details,
+please refer to the documentation for TensorFlow Model Optimization Toolkit.
+"""
+from dataclasses import dataclass
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+
+import tensorflow as tf
+import tensorflow_model_optimization as tfmot
+from tensorflow_model_optimization.python.core.clustering.keras.experimental import ( # pylint: disable=no-name-in-module
+ cluster as experimental_cluster,
+)
+
+from mlia.nn.tensorflow.optimizations.common import Optimizer
+from mlia.nn.tensorflow.optimizations.common import OptimizerConfiguration
+
+
+@dataclass
+class ClusteringConfiguration(OptimizerConfiguration):
+ """Clustering configuration."""
+
+ optimization_target: int
+ layers_to_optimize: Optional[List[str]] = None
+
+ def __str__(self) -> str:
+ """Return string representation of the configuration."""
+ return f"clustering: {self.optimization_target}"
+
+
+class Clusterer(Optimizer):
+ """
+ Clusterer class.
+
+ Used to cluster a model to a specified number of unique weights per layer.
+
+ Sample usage:
+ clusterer = Clusterer(
+ base_model,
+ optimizer_configuration)
+
+ clusterer.apply_optimization()
+ clustered_model = clusterer.get_model()
+ """
+
+ def __init__(
+ self, model: tf.keras.Model, optimizer_configuration: ClusteringConfiguration
+ ):
+ """Init Clusterer instance."""
+ self.model = model
+ self.optimizer_configuration = optimizer_configuration
+
+ def optimization_config(self) -> str:
+ """Return string representation of the optimization config."""
+ return str(self.optimizer_configuration)
+
+ def _setup_clustering_params(self) -> Dict[str, Any]:
+ CentroidInitialization = tfmot.clustering.keras.CentroidInitialization
+ return {
+ "number_of_clusters": self.optimizer_configuration.optimization_target,
+ "cluster_centroids_init": CentroidInitialization.LINEAR,
+ "preserve_sparsity": True,
+ }
+
+ def _apply_clustering_to_layer(
+ self, layer: tf.keras.layers.Layer
+ ) -> tf.keras.layers.Layer:
+ layers_to_optimize = self.optimizer_configuration.layers_to_optimize
+ assert layers_to_optimize, "List of the layers to optimize is empty"
+
+ if layer.name not in layers_to_optimize:
+ return layer
+
+ clustering_params = self._setup_clustering_params()
+ return experimental_cluster.cluster_weights(layer, **clustering_params)
+
+ def _init_for_clustering(self) -> None:
+ # Use `tf.keras.models.clone_model` to apply `apply_clustering_to_layer`
+ # to the layers of the model
+ if not self.optimizer_configuration.layers_to_optimize:
+ clustering_params = self._setup_clustering_params()
+ clustered_model = experimental_cluster.cluster_weights(
+ self.model, **clustering_params
+ )
+ else:
+ clustered_model = tf.keras.models.clone_model(
+ self.model, clone_function=self._apply_clustering_to_layer
+ )
+
+ self.model = clustered_model
+
+ def _strip_clustering(self) -> None:
+ self.model = tfmot.clustering.keras.strip_clustering(self.model)
+
+ def apply_optimization(self) -> None:
+ """Apply all steps of clustering at once."""
+ self._init_for_clustering()
+ self._strip_clustering()
+
+ def get_model(self) -> tf.keras.Model:
+ """Get model."""
+ return self.model
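Expanding slightly on the sample usage in the class docstring, a minimal end-to-end sketch with a toy Keras model (the model itself is illustrative; with layers_to_optimize left as None the whole model is clustered):

    import tensorflow as tf

    from mlia.nn.tensorflow.optimizations.clustering import (
        Clusterer,
        ClusteringConfiguration,
    )

    # Toy model, for illustration only.
    base_model = tf.keras.Sequential(
        [
            tf.keras.layers.Dense(16, activation="relu", input_shape=(8,)),
            tf.keras.layers.Dense(4),
        ]
    )

    # Cluster every layer to at most 8 unique weight values.
    clusterer = Clusterer(base_model, ClusteringConfiguration(optimization_target=8))
    clusterer.apply_optimization()
    clustered_model = clusterer.get_model()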
diff --git a/src/mlia/nn/tensorflow/optimizations/common.py b/src/mlia/nn/tensorflow/optimizations/common.py
new file mode 100644
index 0000000..1dce0b2
--- /dev/null
+++ b/src/mlia/nn/tensorflow/optimizations/common.py
@@ -0,0 +1,29 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Common items for the optimizations module."""
+from abc import ABC
+from abc import abstractmethod
+from dataclasses import dataclass
+
+import tensorflow as tf
+
+
+@dataclass
+class OptimizerConfiguration:
+ """Abstract optimizer configuration."""
+
+
+class Optimizer(ABC):
+ """Abstract class for the optimizer."""
+
+ @abstractmethod
+ def get_model(self) -> tf.keras.Model:
+ """Abstract method to return the model instance from the optimizer."""
+
+ @abstractmethod
+ def apply_optimization(self) -> None:
+ """Abstract method to apply optimization to the model."""
+
+ @abstractmethod
+ def optimization_config(self) -> str:
+ """Return string representation of the optimization config."""
diff --git a/src/mlia/nn/tensorflow/optimizations/pruning.py b/src/mlia/nn/tensorflow/optimizations/pruning.py
new file mode 100644
index 0000000..f629ba1
--- /dev/null
+++ b/src/mlia/nn/tensorflow/optimizations/pruning.py
@@ -0,0 +1,168 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""
+Contains class Pruner to prune a model to a specified sparsity.
+
+This requires a base model and, ideally, corresponding training data (random data
+is generated if none is provided). Optionally, a subset of layers to prune can be
+specified. For more details, please refer to the documentation for the TensorFlow
+Model Optimization Toolkit.
+"""
+from dataclasses import dataclass
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+import numpy as np
+import tensorflow as tf
+import tensorflow_model_optimization as tfmot
+from tensorflow_model_optimization.python.core.sparsity.keras import ( # pylint: disable=no-name-in-module
+ pruning_wrapper,
+)
+
+from mlia.nn.tensorflow.optimizations.common import Optimizer
+from mlia.nn.tensorflow.optimizations.common import OptimizerConfiguration
+
+
+@dataclass
+class PruningConfiguration(OptimizerConfiguration):
+ """Pruning configuration."""
+
+ optimization_target: float
+ layers_to_optimize: Optional[List[str]] = None
+ x_train: Optional[np.array] = None
+ y_train: Optional[np.array] = None
+ batch_size: int = 1
+ num_epochs: int = 1
+
+ def __str__(self) -> str:
+ """Return string representation of the configuration."""
+ return f"pruning: {self.optimization_target}"
+
+ def has_training_data(self) -> bool:
+ """Return True if training data provided."""
+ return self.x_train is not None and self.y_train is not None
+
+
+class Pruner(Optimizer):
+ """
+ Pruner class. Used to prune a model to a specified sparsity.
+
+ Sample usage:
+ pruner = Pruner(
+ base_model,
+ optimizer_configuration)
+
+ pruner.apply_optimization()
+ pruned_model = pruner.get_model()
+ """
+
+ def __init__(
+ self, model: tf.keras.Model, optimizer_configuration: PruningConfiguration
+ ):
+ """Init Pruner instance."""
+ self.model = model
+ self.optimizer_configuration = optimizer_configuration
+
+ if not optimizer_configuration.has_training_data():
+ mock_x_train, mock_y_train = self._mock_train_data()
+
+ self.optimizer_configuration.x_train = mock_x_train
+ self.optimizer_configuration.y_train = mock_y_train
+
+ def optimization_config(self) -> str:
+ """Return string representation of the optimization config."""
+ return str(self.optimizer_configuration)
+
+ def _mock_train_data(self) -> Tuple[np.array, np.array]:
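+ # Generate random stand-in data so the pruning schedule can run when no training data is provided.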
+ # get rid of the batch_size dimension in input and output shape
+ input_shape = tuple(x for x in self.model.input_shape if x is not None)
+ output_shape = tuple(x for x in self.model.output_shape if x is not None)
+
+ return (
+ np.random.rand(*input_shape),
+ np.random.randint(0, output_shape[-1], (output_shape[:-1])),
+ )
+
+ def _setup_pruning_params(self) -> dict:
+ return {
+ "pruning_schedule": tfmot.sparsity.keras.PolynomialDecay(
+ initial_sparsity=0,
+ final_sparsity=self.optimizer_configuration.optimization_target,
+ begin_step=0,
+ end_step=self.optimizer_configuration.num_epochs,
+ frequency=1,
+ ),
+ }
+
+ def _apply_pruning_to_layer(
+ self, layer: tf.keras.layers.Layer
+ ) -> tf.keras.layers.Layer:
+ layers_to_optimize = self.optimizer_configuration.layers_to_optimize
+ assert layers_to_optimize, "List of the layers to optimize is empty"
+
+ if layer.name not in layers_to_optimize:
+ return layer
+
+ pruning_params = self._setup_pruning_params()
+ return tfmot.sparsity.keras.prune_low_magnitude(layer, **pruning_params)
+
+ def _init_for_pruning(self) -> None:
+ # Use `tf.keras.models.clone_model` to apply `_apply_pruning_to_layer`
+ # to the layers of the model
+ if not self.optimizer_configuration.layers_to_optimize:
+ pruning_params = self._setup_pruning_params()
+ prunable_model = tfmot.sparsity.keras.prune_low_magnitude(
+ self.model, **pruning_params
+ )
+ else:
+ prunable_model = tf.keras.models.clone_model(
+ self.model, clone_function=self._apply_pruning_to_layer
+ )
+
+ self.model = prunable_model
+
+ def _train_pruning(self) -> None:
+ loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
+ self.model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"])
+
+ # Model callbacks
+ callbacks = [tfmot.sparsity.keras.UpdatePruningStep()]
+
+ # Fitting data
+ self.model.fit(
+ self.optimizer_configuration.x_train,
+ self.optimizer_configuration.y_train,
+ batch_size=self.optimizer_configuration.batch_size,
+ epochs=self.optimizer_configuration.num_epochs,
+ callbacks=callbacks,
+ verbose=0,
+ )
+
+ def _assert_sparsity_reached(self) -> None:
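+ # Verify that each prunable layer reached the target sparsity (within two significant figures).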
+ for layer in self.model.layers:
+ if not isinstance(layer, pruning_wrapper.PruneLowMagnitude):
+ continue
+
+ for weight in layer.layer.get_prunable_weights():
+ nonzero_weights = np.count_nonzero(tf.keras.backend.get_value(weight))
+ all_weights = tf.keras.backend.get_value(weight).size
+
+ np.testing.assert_approx_equal(
+ self.optimizer_configuration.optimization_target,
+ 1 - nonzero_weights / all_weights,
+ significant=2,
+ )
+
+ def _strip_pruning(self) -> None:
+ self.model = tfmot.sparsity.keras.strip_pruning(self.model)
+
+ def apply_optimization(self) -> None:
+ """Apply all steps of pruning sequentially."""
+ self._init_for_pruning()
+ self._train_pruning()
+ self._assert_sparsity_reached()
+ self._strip_pruning()
+
+ def get_model(self) -> tf.keras.Model:
+ """Get model."""
+ return self.model
diff --git a/src/mlia/nn/tensorflow/optimizations/select.py b/src/mlia/nn/tensorflow/optimizations/select.py
new file mode 100644
index 0000000..1b0c755
--- /dev/null
+++ b/src/mlia/nn/tensorflow/optimizations/select.py
@@ -0,0 +1,179 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for optimization selection."""
+import math
+from typing import List
+from typing import NamedTuple
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import tensorflow as tf
+
+from mlia.core.errors import ConfigurationError
+from mlia.nn.tensorflow.config import KerasModel
+from mlia.nn.tensorflow.optimizations.clustering import Clusterer
+from mlia.nn.tensorflow.optimizations.clustering import ClusteringConfiguration
+from mlia.nn.tensorflow.optimizations.common import Optimizer
+from mlia.nn.tensorflow.optimizations.common import OptimizerConfiguration
+from mlia.nn.tensorflow.optimizations.pruning import Pruner
+from mlia.nn.tensorflow.optimizations.pruning import PruningConfiguration
+from mlia.utils.types import is_list_of
+
+
+class OptimizationSettings(NamedTuple):
+ """Optimization settings."""
+
+ optimization_type: str
+ optimization_target: Union[int, float]
+ layers_to_optimize: Optional[List[str]]
+
+ @staticmethod
+ def create_from(
+ optimizer_params: List[Tuple[str, float]],
+ layers_to_optimize: Optional[List[str]] = None,
+ ) -> List["OptimizationSettings"]:
+ """Create optimization settings from the provided parameters."""
+ return [
+ OptimizationSettings(
+ optimization_type=opt_type,
+ optimization_target=opt_target,
+ layers_to_optimize=layers_to_optimize,
+ )
+ for opt_type, opt_target in optimizer_params
+ ]
+
+ def __str__(self) -> str:
+ """Return string representation."""
+ return f"{self.optimization_type}: {self.optimization_target}"
+
+ def next_target(self) -> "OptimizationSettings":
+ """Return next optimization target."""
+ if self.optimization_type == "pruning":
+ next_target = round(min(self.optimization_target + 0.1, 0.9), 2)
+ return OptimizationSettings(
+ self.optimization_type, next_target, self.layers_to_optimize
+ )
+
+ if self.optimization_type == "clustering":
+ # return next lowest power of two for clustering
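+ # e.g. 32 -> 16 and 20 -> 16; the target never drops below 4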
+ next_target = math.log(self.optimization_target, 2)
+ if next_target.is_integer():
+ next_target -= 1
+
+ next_target = max(int(2 ** int(next_target)), 4)
+ return OptimizationSettings(
+ self.optimization_type, next_target, self.layers_to_optimize
+ )
+
+ raise Exception(f"Unknown optimization type {self.optimization_type}")
+
+
+class MultiStageOptimizer(Optimizer):
+ """Optimizer with multiple stages."""
+
+ def __init__(
+ self,
+ model: tf.keras.Model,
+ optimizations: List[OptimizerConfiguration],
+ ) -> None:
+ """Init MultiStageOptimizer instance."""
+ self.model = model
+ self.optimizations = optimizations
+
+ def optimization_config(self) -> str:
+ """Return string representation of the optimization config."""
+ return " - ".join(str(opt) for opt in self.optimizations)
+
+ def get_model(self) -> tf.keras.Model:
+ """Return optimized model."""
+ return self.model
+
+ def apply_optimization(self) -> None:
+ """Apply optimization to the model."""
+ for config in self.optimizations:
+ optimizer = get_optimizer(self.model, config)
+ optimizer.apply_optimization()
+ self.model = optimizer.get_model()
+
+
+def get_optimizer(
+ model: Union[tf.keras.Model, KerasModel],
+ config: Union[
+ OptimizerConfiguration, OptimizationSettings, List[OptimizationSettings]
+ ],
+) -> Optimizer:
+ """Get optimizer for provided configuration."""
+ if isinstance(model, KerasModel):
+ model = model.get_keras_model()
+
+ if isinstance(config, PruningConfiguration):
+ return Pruner(model, config)
+
+ if isinstance(config, ClusteringConfiguration):
+ return Clusterer(model, config)
+
+ if isinstance(config, OptimizationSettings) or is_list_of(
+ config, OptimizationSettings
+ ):
+ return _get_optimizer(model, config) # type: ignore
+
+ raise ConfigurationError(f"Unknown optimization configuration {config}")
+
+
+def _get_optimizer(
+ model: tf.keras.Model,
+ optimization_settings: Union[OptimizationSettings, List[OptimizationSettings]],
+) -> Optimizer:
+ if isinstance(optimization_settings, OptimizationSettings):
+ optimization_settings = [optimization_settings]
+
+ optimizer_configs = []
+ for opt_type, opt_target, layers_to_optimize in optimization_settings:
+ _check_optimizer_params(opt_type, opt_target)
+
+ opt_config = _get_optimizer_configuration(
+ opt_type, opt_target, layers_to_optimize
+ )
+ optimizer_configs.append(opt_config)
+
+ if len(optimizer_configs) == 1:
+ return get_optimizer(model, optimizer_configs[0])
+
+ return MultiStageOptimizer(model, optimizer_configs)
+
+
+def _get_optimizer_configuration(
+ optimization_type: str,
+ optimization_target: Union[int, float],
+ layers_to_optimize: Optional[List[str]] = None,
+) -> OptimizerConfiguration:
+ """Get optimizer configuration for provided parameters."""
+ _check_optimizer_params(optimization_type, optimization_target)
+
+ opt_type = optimization_type.lower()
+ if opt_type == "pruning":
+ return PruningConfiguration(optimization_target, layers_to_optimize)
+
+ if opt_type == "clustering":
+ # make sure an integer is given as clustering target
+ if optimization_target == int(optimization_target):
+ return ClusteringConfiguration(int(optimization_target), layers_to_optimize)
+
+ raise ConfigurationError(
+ "Optimization target should be a positive integer. "
+ f"Optimization target provided: {optimization_target}"
+ )
+
+ raise ConfigurationError(f"Unsupported optimization type: {optimization_type}")
+
+
+def _check_optimizer_params(
+ optimization_type: str, optimization_target: Union[int, float]
+) -> None:
+ """Check optimizer params."""
+ if not optimization_target:
+ raise ConfigurationError("Optimization target is not provided")
+
+ if not optimization_type:
+ raise ConfigurationError("Optimization type is not provided")
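
A minimal usage sketch for the selection API above, assuming `model` is an
already loaded tf.keras.Model (the names below are illustrative only):

    from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
    from mlia.nn.tensorflow.optimizations.select import get_optimizer

    settings = OptimizationSettings.create_from([("pruning", 0.5), ("clustering", 32)])
    optimizer = get_optimizer(model, settings)  # multiple settings yield a MultiStageOptimizer
    optimizer.apply_optimization()
    optimized_model = optimizer.get_model()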
diff --git a/src/mlia/nn/tensorflow/tflite_metrics.py b/src/mlia/nn/tensorflow/tflite_metrics.py
new file mode 100644
index 0000000..b29fab3
--- /dev/null
+++ b/src/mlia/nn/tensorflow/tflite_metrics.py
@@ -0,0 +1,296 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""
+Contains class TFLiteMetrics to calculate metrics from a TFLite file.
+
+These metrics include:
+* Sparsity (per layer and overall)
+* Unique weights (clusters) (per layer)
+* gzip compression ratio
+"""
+import os
+from enum import Enum
+from pprint import pprint
+from typing import Any
+from typing import List
+from typing import Optional
+
+import numpy as np
+import tensorflow as tf
+
+DEFAULT_IGNORE_LIST = [
+ "relu",
+ "pooling",
+ "reshape",
+ "identity",
+ "input",
+ "add",
+ "flatten",
+ "StatefulPartitionedCall",
+ "bias",
+]
+
+
+def calculate_num_unique_weights(weights: np.array) -> int:
+ """Calculate the number of unique weights in the given weights."""
+ num_unique_weights = len(np.unique(weights))
+ return num_unique_weights
+
+
+def calculate_num_unique_weights_per_axis(weights: np.array, axis: int) -> List[int]:
+ """Calculate unique weights per quantization axis."""
+ # Make quantized dimension the first dimension
+ weights_trans = np.swapaxes(weights, 0, axis)
+ num_uniques_weights = [
+ calculate_num_unique_weights(weights_trans[i])
+ for i in range(weights_trans.shape[0])
+ ]
+ assert num_uniques_weights
+ return num_uniques_weights
+
+
+class SparsityAccumulator:
+ """Helper class to accumulate sparsity over several layers."""
+
+ def __init__(self) -> None:
+ """Create an empty accumulator."""
+ self.total_non_zero_weights: int = 0
+ self.total_weights: int = 0
+
+ def __call__(self, weights: np.array) -> None:
+ """Update the accumulator with the given weights."""
+ non_zero_weights = np.count_nonzero(weights)
+ self.total_non_zero_weights += non_zero_weights
+ self.total_weights += weights.size
+
+ def sparsity(self) -> float:
+ """Calculate the sparsity for all added weights."""
+ return 1.0 - self.total_non_zero_weights / float(self.total_weights)
+
+
+def calculate_sparsity(
+ weights: np.array, accumulator: Optional[SparsityAccumulator] = None
+) -> float:
+ """
+ Calculate the sparsity for the given weights.
+
+ If the accumulator is passed, it is updated as well.
+ """
+ non_zero_weights = np.count_nonzero(weights)
+ sparsity = 1.0 - float(non_zero_weights) / float(weights.size)
+ if accumulator is not None:
+ accumulator(weights)
+ return sparsity
+
+
+class ReportClusterMode(Enum):
+ """Specifies the way cluster values are aggregated and reported."""
+
+ NUM_CLUSTERS_HISTOGRAM = (
+ "A histogram of the number of clusters per axis. "
+ "I.e. the number of clusters is the index of the list (the bin) and "
+ "the value is the number of axes that have this number of clusters. "
+ "The first bin is 1."
+ )
+ NUM_CLUSTERS_PER_AXIS = "Number of clusters (unique weights) per axis."
+ NUM_CLUSTERS_MIN_MAX = "Min/max number of clusters over all axes."
+
+
+class TFLiteMetrics:
+ """Helper class to calculate metrics from a TFLite file.
+
+ Metrics include:
+ * sparsity (per-layer and overall)
+ * number of unique weights (clusters) per layer
+ * File compression via gzip
+ """
+
+ def __init__(
+ self, tflite_file: str, ignore_list: Optional[List[str]] = None
+ ) -> None:
+ """Load the TFLite file and filter layers."""
+ self.tflite_file = tflite_file
+ if ignore_list is None:
+ ignore_list = DEFAULT_IGNORE_LIST
+ self.ignore_list = [ignore.casefold() for ignore in ignore_list]
+ # Initialize the TFLite interpreter with the model file
+ self.interpreter = tf.lite.Interpreter(model_path=tflite_file)
+ self.interpreter.allocate_tensors()
+ self.details: dict = {}
+
+ def ignore(details: dict) -> bool:
+ name = details["name"].casefold()
+ if not name:
+ return True
+ for to_ignore in self.ignore_list:
+ if to_ignore in name:
+ return True
+ return False
+
+ self.filtered_details = {
+ details["name"]: details
+ for details in self.interpreter.get_tensor_details()
+ if not ignore(details)
+ }
+
+ def get_tensor(self, details: dict) -> Any:
+ """Return the weights/tensor specified in the given details map."""
+ return self.interpreter.tensor(details["index"])()
+
+ def sparsity_per_layer(self) -> dict:
+ """Return a dict of layer name and sparsity value."""
+ sparsity = {
+ name: calculate_sparsity(self.get_tensor(details))
+ for name, details in self.filtered_details.items()
+ }
+ return sparsity
+
+ def sparsity_overall(self) -> float:
+ """Return the overall sparsity accumulated over all filtered layers."""
+ acc = SparsityAccumulator()
+ for details in self.filtered_details.values():
+ acc(self.get_tensor(details))
+ return acc.sparsity()
+
+ def calc_num_clusters_per_axis(self, details: dict) -> List[int]:
+ """Calculate number of clusters per axis."""
+ quant_params = details["quantization_parameters"]
+ per_axis = len(quant_params["zero_points"]) > 1
+ if per_axis:
+ # Calculate unique weights along quantization axis
+ axis = quant_params["quantized_dimension"]
+ return calculate_num_unique_weights_per_axis(self.get_tensor(details), axis)
+
+ # Calculate unique weights over all axes/dimensions
+ return [calculate_num_unique_weights(self.get_tensor(details))]
+
+ def num_unique_weights(self, mode: ReportClusterMode) -> dict:
+ """Return a dict of layer name and number of unique weights."""
+ aggregation_func = None
+ if mode == ReportClusterMode.NUM_CLUSTERS_PER_AXIS:
+ aggregation_func = self.calc_num_clusters_per_axis
+ elif mode == ReportClusterMode.NUM_CLUSTERS_MIN_MAX:
+
+ def cluster_min_max(details: dict) -> List[int]:
+ num_clusters = self.calc_num_clusters_per_axis(details)
+ return [min(num_clusters), max(num_clusters)]
+
+ aggregation_func = cluster_min_max
+ elif mode == ReportClusterMode.NUM_CLUSTERS_HISTOGRAM:
+
+ def cluster_hist(details: dict) -> List[int]:
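+ # bin i counts how many axes have exactly i + 1 clusters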
+ num_clusters = self.calc_num_clusters_per_axis(details)
+ max_num = max(num_clusters)
+ hist = [0] * (max_num)
+ for num in num_clusters:
+ idx = num - 1
+ hist[idx] += 1
+ return hist
+
+ aggregation_func = cluster_hist
+ else:
+ raise NotImplementedError(
+ "ReportClusterMode '{}' not implemented.".format(mode)
+ )
+ uniques = {
+ name: aggregation_func(details)
+ for name, details in self.filtered_details.items()
+ }
+ return uniques
+
+ @staticmethod
+ def _prettify_name(name: str) -> str:
+ if name.startswith("model"):
+ return name.split("/", 1)[1]
+ return name
+
+ def summary(
+ self,
+ report_sparsity: bool,
+ report_cluster_mode: Optional[ReportClusterMode] = None,
+ max_num_clusters: int = 32,
+ verbose: bool = False,
+ ) -> None:
+ """Print a summary of all the model information."""
+ print("Model file: {}".format(self.tflite_file))
+ print("#" * 80)
+ print(" " * 28 + "### TFLITE SUMMARY ###")
+ print("File: {}".format(os.path.abspath(self.tflite_file)))
+ print("Input(s):")
+ self._print_in_outs(self.interpreter.get_input_details(), verbose)
+ print("Output(s):")
+ self._print_in_outs(self.interpreter.get_output_details(), verbose)
+ print()
+ header = ["Layer", "Index", "Type", "Num weights"]
+ if report_sparsity:
+ header.append("Sparsity")
+ rows = []
+ sparsity_accumulator = SparsityAccumulator()
+ for details in self.filtered_details.values():
+ name = details["name"]
+ weights = self.get_tensor(details)
+ row = [
+ self._prettify_name(name),
+ details["index"],
+ weights.dtype,
+ weights.size,
+ ]
+ if report_sparsity:
+ sparsity = calculate_sparsity(weights, sparsity_accumulator)
+ row.append("{:.2f}".format(sparsity))
+ rows.append(row)
+ if verbose:
+ # Print cluster centroids
+ print("{} cluster centroids:".format(name))
+ pprint(np.unique(weights))
+ # Add summary/overall values
+ empty_row = ["" for _ in range(len(header))]
+ summary_row = empty_row
+ summary_row[header.index("Layer")] = "=> OVERALL"
+ summary_row[header.index("Num weights")] = str(
+ sparsity_accumulator.total_weights
+ )
+ if report_sparsity:
+ summary_row[header.index("Sparsity")] = "{:.2f}".format(
+ sparsity_accumulator.sparsity()
+ )
+ rows.append(summary_row)
+ # Report detailed cluster info
+ if report_cluster_mode is not None:
+ print()
+ self._print_cluster_details(report_cluster_mode, max_num_clusters)
+ print("#" * 80)
+
+ def _print_cluster_details(
+ self, report_cluster_mode: ReportClusterMode, max_num_clusters: int
+ ) -> None:
+ print("{}:\n{}".format(report_cluster_mode.name, report_cluster_mode.value))
+ num_clusters = self.num_unique_weights(report_cluster_mode)
+ if (
+ report_cluster_mode == ReportClusterMode.NUM_CLUSTERS_HISTOGRAM
+ and max_num_clusters > 0
+ ):
+ # Only show cluster histogram if there are not more than
+ # max_num_clusters. This is a workaround for not showing a huge
+ # histogram for unclustered layers.
+ for name, value in num_clusters.items():
+ if len(value) > max_num_clusters:
+ num_clusters[name] = "More than {} unique values.".format(
+ max_num_clusters
+ )
+ for name, nums in num_clusters.items():
+ print("- {}: {}".format(self._prettify_name(name), nums))
+
+ @staticmethod
+ def _print_in_outs(ios: List[dict], verbose: bool = False) -> None:
+ for item in ios:
+ if verbose:
+ pprint(item)
+ else:
+ print(
+ "- {} ({}): {}".format(
+ item["name"],
+ np.dtype(item["dtype"]).name,
+ item["shape"],
+ )
+ )
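
A short usage sketch for TFLiteMetrics, assuming "model.tflite" is an existing
TFLite file (the path is illustrative only):

    from mlia.nn.tensorflow.tflite_metrics import ReportClusterMode
    from mlia.nn.tensorflow.tflite_metrics import TFLiteMetrics

    metrics = TFLiteMetrics("model.tflite")
    overall_sparsity = metrics.sparsity_overall()
    clusters = metrics.num_unique_weights(ReportClusterMode.NUM_CLUSTERS_MIN_MAX)
    metrics.summary(report_sparsity=True, report_cluster_mode=ReportClusterMode.NUM_CLUSTERS_HISTOGRAM)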
diff --git a/src/mlia/nn/tensorflow/utils.py b/src/mlia/nn/tensorflow/utils.py
new file mode 100644
index 0000000..4abf6cd
--- /dev/null
+++ b/src/mlia/nn/tensorflow/utils.py
@@ -0,0 +1,149 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright The TensorFlow Authors. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+"""Collection of useful functions for optimizations."""
+import logging
+from pathlib import Path
+from typing import Callable
+from typing import Iterable
+from typing import Union
+
+import numpy as np
+import tensorflow as tf
+from tensorflow.lite.python.interpreter import Interpreter
+
+from mlia.utils.logging import redirect_output
+
+
+def representative_dataset(model: tf.keras.Model) -> Callable:
+ """Sample dataset used for quantization."""
+ input_shape = model.input_shape
+
+ def dataset() -> Iterable:
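+ # yield 100 randomly generated samples used to calibrate post-training quantization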
+ for _ in range(100):
+ if input_shape[0] != 1:
+ raise Exception("Only the input batch_size=1 is supported!")
+ data = np.random.rand(*input_shape)
+ yield [data.astype(np.float32)]
+
+ return dataset
+
+
+ """Get the input shape of a TensorFlow SavedModel."""
+ """Get input shape for the TensorFlow tensor model."""
+ # Loading the model
+ loaded = tf.saved_model.load(model)
+ # The model signature must have 'serving_default' as a key
+ if "serving_default" not in loaded.signatures.keys():
+ raise Exception(
+ "Unsupported TensorFlow model signature, must have 'serving_default'"
+ )
+ # Get the signature inputs
+ inputs_tensor_info = loaded.signatures["serving_default"].inputs
+ dims = []
+ # Build a list of all inputs shape sizes
+ for input_key in inputs_tensor_info:
+ if input_key.get_shape():
+ dims.extend(list(input_key.get_shape()))
+ return dims
+
+
+def representative_tf_dataset(model: str) -> Callable:
+ """Sample dataset used for quantization."""
+ if not (input_shape := get_tf_tensor_shape(model)):
+ raise Exception("Unable to get input shape")
+
+ def dataset() -> Iterable:
+ for _ in range(100):
+ data = np.random.rand(*input_shape)
+ yield [data.astype(np.float32)]
+
+ return dataset
+
+
+def convert_to_tflite(model: tf.keras.Model, quantized: bool = False) -> Interpreter:
+ """Convert Keras model to TFLite."""
+ if not isinstance(model, tf.keras.Model):
+ raise Exception("Invalid model type")
+
+ converter = tf.lite.TFLiteConverter.from_keras_model(model)
+
+ if quantized:
+ converter.optimizations = [tf.lite.Optimize.DEFAULT]
+ converter.representative_dataset = representative_dataset(model)
+ converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+ converter.inference_input_type = tf.int8
+ converter.inference_output_type = tf.int8
+
+ with redirect_output(logging.getLogger("tensorflow")):
+ tflite_model = converter.convert()
+
+ return tflite_model
+
+
+def convert_tf_to_tflite(model: str, quantized: bool = False) -> Interpreter:
+ """Convert TensorFlow model to TFLite."""
+ if not isinstance(model, str):
+ raise Exception("Invalid model type")
+
+ converter = tf.lite.TFLiteConverter.from_saved_model(model)
+
+ if quantized:
+ converter.optimizations = [tf.lite.Optimize.DEFAULT]
+ converter.representative_dataset = representative_tf_dataset(model)
+ converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+ converter.inference_input_type = tf.int8
+ converter.inference_output_type = tf.int8
+
+ with redirect_output(logging.getLogger("tensorflow")):
+ tflite_model = converter.convert()
+
+ return tflite_model
+
+
+def save_keras_model(model: tf.keras.Model, save_path: Union[str, Path]) -> None:
+ """Save Keras model at provided path."""
+ # Checkpoint: saving the optimizer is necessary.
+ # Save the optimizer state as well so training can be resumed from the saved model.
+
+
+def save_tflite_model(
+ model: tf.lite.TFLiteConverter, save_path: Union[str, Path]
+) -> None:
+ """Save TFLite model at provided path."""
+ with open(save_path, "wb") as file:
+ file.write(model)
+
+
+def is_tflite_model(model: Union[Path, str]) -> bool:
+ """Check if model type is supported by TFLite API.
+
+ A TFLite model is identified by the .tflite file extension.
+ """
+ model_path = Path(model)
+ return model_path.suffix == ".tflite"
+
+
+def is_keras_model(model: Union[Path, str]) -> bool:
+ """Check if model type is supported by Keras API.
+
+ A Keras model is identified by either:
+ 1. a directory (SavedModel format) containing a keras_metadata.pb file, or
+ 2. a file with the .h5/.hdf5 extension
+ """
+ model_path = Path(model)
+
+ if model_path.is_dir():
+ return (model_path / "keras_metadata.pb").exists()
+ return model_path.suffix in (".h5", ".hdf5")
+
+
+def is_tf_model(model: Union[Path, str]) -> bool:
+ """Check if model type is supported by TensorFlow API.
+
+ A TensorFlow model is identified by a SavedModel directory that does not
+ contain a keras_metadata.pb file.
+ """
+ model_path = Path(model)
+ return model_path.is_dir() and not is_keras_model(model)
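
A minimal sketch of the conversion helpers above, assuming `model` is a
tf.keras.Model with batch size 1 (the output file name is illustrative only):

    from mlia.nn.tensorflow.utils import convert_to_tflite
    from mlia.nn.tensorflow.utils import save_tflite_model

    tflite_model = convert_to_tflite(model, quantized=True)
    save_tflite_model(tflite_model, "model_int8.tflite")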
diff --git a/src/mlia/resources/aiet/applications/APPLICATIONS.txt b/src/mlia/resources/aiet/applications/APPLICATIONS.txt
new file mode 100644
index 0000000..09127f8
--- /dev/null
+++ b/src/mlia/resources/aiet/applications/APPLICATIONS.txt
@@ -0,0 +1,6 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+SPDX-License-Identifier: Apache-2.0
+
+This directory contains the Generic Inference Runner application packages for AIET.
+
+Each package should contain its own aiet-config.json file.
diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json
new file mode 100644
index 0000000..757ccd1
--- /dev/null
+++ b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json
@@ -0,0 +1,18 @@
+[
+ {
+ "name": "Generic Inference Runner: Ethos-U55/65 Shared SRAM",
+ "description": "This application allows running inferences using custom NN TFLite models on Ethos-U. No data pre-/post-processing is executed.",
+ "supported_systems": [
+ {
+ "name": "Corstone-300: Cortex-M55+Ethos-U55"
+ },
+ {
+ "name": "Corstone-300: Cortex-M55+Ethos-U65"
+ }
+ ],
+ "lock": true,
+ "variables": {
+ "eval_app": "{software.config_dir}/ethos-u-inference_runner.axf"
+ }
+ }
+]
diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json.license b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf
new file mode 100644
index 0000000..4c50e1f
--- /dev/null
+++ b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf
Binary files differ
diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf.license b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf.license
new file mode 100644
index 0000000..8896f92
--- /dev/null
+++ b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf.license
@@ -0,0 +1,31 @@
+SPDX-FileCopyrightText: Copyright 2009-2022, Arm Limited and/or its affiliates.
+SPDX-FileCopyrightText: Copyright (c) 2006, 2008 Junio C Hamano
+SPDX-FileCopyrightText: Copyright 2011, The Dojo Foundation
+SPDX-FileCopyrightText: Copyright (c) 1999-2009 KEIL, 2009-2016 ARM Germany GmbH. All rights reserved.
+SPDX-FileCopyrightText: Copyright (c) 2006, 2007 CodeSourcery Inc
+SPDX-FileCopyrightText: Copyright (c) 2010 "Cowboy" Ben Alman
+SPDX-FileCopyrightText: Copyright 1997-2016 Freescale Semiconductor, Inc.
+SPDX-FileCopyrightText: Copyright 2011, AUTHORS.txt (http://jqueryui.com/about)
+SPDX-FileCopyrightText: Copyright 2011, John Resig
+SPDX-FileCopyrightText: Copyright 2016-2021 NXP
+SPDX-FileCopyrightText: Copyright (c) 2012 mbed.org
+SPDX-FileCopyrightText: Copyright (c) 2012-2017 Keil Software. All rights reserved.
+SPDX-FileCopyrightText: Copyright (C) 2009 by Dimitri van Heesch.
+SPDX-FileCopyrightText: Copyright (c) 2017-2021 IAR Systems
+SPDX-FileCopyrightText: Copyright (C) 2003-2008 Greg Valure
+SPDX-FileCopyrightText: Copyright 2015 gRPC authors.
+SPDX-FileCopyrightText: Copyright 2018 Dan Field. All rights reserved.
+SPDX-FileCopyrightText: Copyright 2014 Stefan.Eilemann@epfl.ch
+SPDX-FileCopyrightText: Copyright (c) 2016, the Dart project authors
+SPDX-FileCopyrightText: Copyright 2015 The Chromium Authors
+SPDX-FileCopyrightText: Copyright (C) 2019 Free Software Foundation, Inc.
+SPDX-FileCopyrightText: Copyright (c) 2021, Vasyl Gello.
+SPDX-FileCopyrightText: Copyright 2020 Jan Tojnar
+SPDX-FileCopyrightText: Copyright 2017-2022 The TensorFlow Authors. All Rights Reserved.
+SPDX-FileCopyrightText: Copyright 2014-2022 Google Inc. All rights reserved.
+SPDX-FileCopyrightText: Copyright 2015-2018 The Gemmlowp Authors. All rights reserved.
+SPDX-FileCopyrightText: Copyright (c) 2003-2019, Mark Borgerding. All rights reserved.
+SPDX-FileCopyrightText: Copyright 2019-2022 The Pigweed Authors
+SPDX-FileCopyrightText: Copyright 2019-2021 Google LLC. All Rights Reserved.
+
+SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause and CC-PDDC
diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json
new file mode 100644
index 0000000..cb7e113
--- /dev/null
+++ b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json
@@ -0,0 +1,15 @@
+[
+ {
+ "name": "Generic Inference Runner: Ethos-U55 SRAM",
+ "description": "This application allows running inferences using custom NN TFLite models on Ethos-U. No data pre-/post-processing is executed.",
+ "supported_systems": [
+ {
+ "name": "Corstone-300: Cortex-M55+Ethos-U55"
+ }
+ ],
+ "lock": true,
+ "variables": {
+ "eval_app": "{software.config_dir}/ethos-u-inference_runner.axf"
+ }
+ }
+]
diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json.license b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf
new file mode 100644
index 0000000..850e2eb
--- /dev/null
+++ b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf
Binary files differ
diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf.license b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf.license
new file mode 100644
index 0000000..8896f92
--- /dev/null
+++ b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf.license
@@ -0,0 +1,31 @@
+SPDX-FileCopyrightText: Copyright 2009-2022, Arm Limited and/or its affiliates.
+SPDX-FileCopyrightText: Copyright (c) 2006, 2008 Junio C Hamano
+SPDX-FileCopyrightText: Copyright 2011, The Dojo Foundation
+SPDX-FileCopyrightText: Copyright (c) 1999-2009 KEIL, 2009-2016 ARM Germany GmbH. All rights reserved.
+SPDX-FileCopyrightText: Copyright (c) 2006, 2007 CodeSourcery Inc
+SPDX-FileCopyrightText: Copyright (c) 2010 "Cowboy" Ben Alman
+SPDX-FileCopyrightText: Copyright 1997-2016 Freescale Semiconductor, Inc.
+SPDX-FileCopyrightText: Copyright 2011, AUTHORS.txt (http://jqueryui.com/about)
+SPDX-FileCopyrightText: Copyright 2011, John Resig
+SPDX-FileCopyrightText: Copyright 2016-2021 NXP
+SPDX-FileCopyrightText: Copyright (c) 2012 mbed.org
+SPDX-FileCopyrightText: Copyright (c) 2012-2017 Keil Software. All rights reserved.
+SPDX-FileCopyrightText: Copyright (C) 2009 by Dimitri van Heesch.
+SPDX-FileCopyrightText: Copyright (c) 2017-2021 IAR Systems
+SPDX-FileCopyrightText: Copyright (C) 2003-2008 Greg Valure
+SPDX-FileCopyrightText: Copyright 2015 gRPC authors.
+SPDX-FileCopyrightText: Copyright 2018 Dan Field. All rights reserved.
+SPDX-FileCopyrightText: Copyright 2014 Stefan.Eilemann@epfl.ch
+SPDX-FileCopyrightText: Copyright (c) 2016, the Dart project authors
+SPDX-FileCopyrightText: Copyright 2015 The Chromium Authors
+SPDX-FileCopyrightText: Copyright (C) 2019 Free Software Foundation, Inc.
+SPDX-FileCopyrightText: Copyright (c) 2021, Vasyl Gello.
+SPDX-FileCopyrightText: Copyright 2020 Jan Tojnar
+SPDX-FileCopyrightText: Copyright 2017-2022 The TensorFlow Authors. All Rights Reserved.
+SPDX-FileCopyrightText: Copyright 2014-2022 Google Inc. All rights reserved.
+SPDX-FileCopyrightText: Copyright 2015-2018 The Gemmlowp Authors. All rights reserved.
+SPDX-FileCopyrightText: Copyright (c) 2003-2019, Mark Borgerding. All rights reserved.
+SPDX-FileCopyrightText: Copyright 2019-2022 The Pigweed Authors
+SPDX-FileCopyrightText: Copyright 2019-2021 Google LLC. All Rights Reserved.
+
+SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause and CC-PDDC
diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/aiet-config.json b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/aiet-config.json
new file mode 100644
index 0000000..d524f64
--- /dev/null
+++ b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/aiet-config.json
@@ -0,0 +1,15 @@
+[
+ {
+ "name": "Generic Inference Runner: Ethos-U65 Dedicated SRAM",
+ "description": "This application allows running inferences using custom NN TFLite models on Ethos-U. No data pre-/post-processing is executed.",
+ "supported_systems": [
+ {
+ "name": "Corstone-300: Cortex-M55+Ethos-U65"
+ }
+ ],
+ "lock": true,
+ "variables": {
+ "eval_app": "{software.config_dir}/ethos-u-inference_runner.axf"
+ }
+ }
+]
diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/aiet-config.json.license b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/ethos-u-inference_runner.axf b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/ethos-u-inference_runner.axf
new file mode 100644
index 0000000..f881bb8
--- /dev/null
+++ b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/ethos-u-inference_runner.axf
Binary files differ
diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/ethos-u-inference_runner.axf.license b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/ethos-u-inference_runner.axf.license
new file mode 100644
index 0000000..8896f92
--- /dev/null
+++ b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/ethos-u-inference_runner.axf.license
@@ -0,0 +1,31 @@
+SPDX-FileCopyrightText: Copyright 2009-2022, Arm Limited and/or its affiliates.
+SPDX-FileCopyrightText: Copyright (c) 2006, 2008 Junio C Hamano
+SPDX-FileCopyrightText: Copyright 2011, The Dojo Foundation
+SPDX-FileCopyrightText: Copyright (c) 1999-2009 KEIL, 2009-2016 ARM Germany GmbH. All rights reserved.
+SPDX-FileCopyrightText: Copyright (c) 2006, 2007 CodeSourcery Inc
+SPDX-FileCopyrightText: Copyright (c) 2010 "Cowboy" Ben Alman
+SPDX-FileCopyrightText: Copyright 1997-2016 Freescale Semiconductor, Inc.
+SPDX-FileCopyrightText: Copyright 2011, AUTHORS.txt (http://jqueryui.com/about)
+SPDX-FileCopyrightText: Copyright 2011, John Resig
+SPDX-FileCopyrightText: Copyright 2016-2021 NXP
+SPDX-FileCopyrightText: Copyright (c) 2012 mbed.org
+SPDX-FileCopyrightText: Copyright (c) 2012-2017 Keil Software. All rights reserved.
+SPDX-FileCopyrightText: Copyright (C) 2009 by Dimitri van Heesch.
+SPDX-FileCopyrightText: Copyright (c) 2017-2021 IAR Systems
+SPDX-FileCopyrightText: Copyright (C) 2003-2008 Greg Valure
+SPDX-FileCopyrightText: Copyright 2015 gRPC authors.
+SPDX-FileCopyrightText: Copyright 2018 Dan Field. All rights reserved.
+SPDX-FileCopyrightText: Copyright 2014 Stefan.Eilemann@epfl.ch
+SPDX-FileCopyrightText: Copyright (c) 2016, the Dart project authors
+SPDX-FileCopyrightText: Copyright 2015 The Chromium Authors
+SPDX-FileCopyrightText: Copyright (C) 2019 Free Software Foundation, Inc.
+SPDX-FileCopyrightText: Copyright (c) 2021, Vasyl Gello.
+SPDX-FileCopyrightText: Copyright 2020 Jan Tojnar
+SPDX-FileCopyrightText: Copyright 2017-2022 The TensorFlow Authors. All Rights Reserved.
+SPDX-FileCopyrightText: Copyright 2014-2022 Google Inc. All rights reserved.
+SPDX-FileCopyrightText: Copyright 2015-2018 The Gemmlowp Authors. All rights reserved.
+SPDX-FileCopyrightText: Copyright (c) 2003-2019, Mark Borgerding. All rights reserved.
+SPDX-FileCopyrightText: Copyright 2019-2022 The Pigweed Authors
+SPDX-FileCopyrightText: Copyright 2019-2021 Google LLC. All Rights Reserved.
+
+SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause and CC-PDDC
diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json
new file mode 100644
index 0000000..2cbab70
--- /dev/null
+++ b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json
@@ -0,0 +1,15 @@
+[
+ {
+ "name": "Generic Inference Runner: Ethos-U55/65 Shared SRAM",
+ "description": "This application allows running inferences using custom NN TFLite models on Ethos-U. No data pre-/post-processing is executed.",
+ "supported_systems": [
+ {
+ "name": "Corstone-310: Cortex-M85+Ethos-U55"
+ }
+ ],
+ "lock": true,
+ "variables": {
+ "eval_app": "{software.config_dir}/ethos-u-inference_runner.axf"
+ }
+ }
+]
diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json.license b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf
new file mode 100644
index 0000000..846ee33
--- /dev/null
+++ b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf
Binary files differ
diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf.license b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf.license
new file mode 100644
index 0000000..8896f92
--- /dev/null
+++ b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf.license
@@ -0,0 +1,31 @@
+SPDX-FileCopyrightText: Copyright 2009-2022, Arm Limited and/or its affiliates.
+SPDX-FileCopyrightText: Copyright (c) 2006, 2008 Junio C Hamano
+SPDX-FileCopyrightText: Copyright 2011, The Dojo Foundation
+SPDX-FileCopyrightText: Copyright (c) 1999-2009 KEIL, 2009-2016 ARM Germany GmbH. All rights reserved.
+SPDX-FileCopyrightText: Copyright (c) 2006, 2007 CodeSourcery Inc
+SPDX-FileCopyrightText: Copyright (c) 2010 "Cowboy" Ben Alman
+SPDX-FileCopyrightText: Copyright 1997-2016 Freescale Semiconductor, Inc.
+SPDX-FileCopyrightText: Copyright 2011, AUTHORS.txt (http://jqueryui.com/about)
+SPDX-FileCopyrightText: Copyright 2011, John Resig
+SPDX-FileCopyrightText: Copyright 2016-2021 NXP
+SPDX-FileCopyrightText: Copyright (c) 2012 mbed.org
+SPDX-FileCopyrightText: Copyright (c) 2012-2017 Keil Software. All rights reserved.
+SPDX-FileCopyrightText: Copyright (C) 2009 by Dimitri van Heesch.
+SPDX-FileCopyrightText: Copyright (c) 2017-2021 IAR Systems
+SPDX-FileCopyrightText: Copyright (C) 2003-2008 Greg Valure
+SPDX-FileCopyrightText: Copyright 2015 gRPC authors.
+SPDX-FileCopyrightText: Copyright 2018 Dan Field. All rights reserved.
+SPDX-FileCopyrightText: Copyright 2014 Stefan.Eilemann@epfl.ch
+SPDX-FileCopyrightText: Copyright (c) 2016, the Dart project authors
+SPDX-FileCopyrightText: Copyright 2015 The Chromium Authors
+SPDX-FileCopyrightText: Copyright (C) 2019 Free Software Foundation, Inc.
+SPDX-FileCopyrightText: Copyright (c) 2021, Vasyl Gello.
+SPDX-FileCopyrightText: Copyright 2020 Jan Tojnar
+SPDX-FileCopyrightText: Copyright 2017-2022 The TensorFlow Authors. All Rights Reserved.
+SPDX-FileCopyrightText: Copyright 2014-2022 Google Inc. All rights reserved.
+SPDX-FileCopyrightText: Copyright 2015-2018 The Gemmlowp Authors. All rights reserved.
+SPDX-FileCopyrightText: Copyright (c) 2003-2019, Mark Borgerding. All rights reserved.
+SPDX-FileCopyrightText: Copyright 2019-2022 The Pigweed Authors
+SPDX-FileCopyrightText: Copyright 2019-2021 Google LLC. All Rights Reserved.
+
+SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause and CC-PDDC
diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json
new file mode 100644
index 0000000..01bec74
--- /dev/null
+++ b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json
@@ -0,0 +1,15 @@
+[
+ {
+ "name": "Generic Inference Runner: Ethos-U55 SRAM",
+ "description": "This application allows running inferences using custom NN TFLite models on Ethos-U. No data pre-/post-processing is executed.",
+ "supported_systems": [
+ {
+ "name": "Corstone-310: Cortex-M85+Ethos-U55"
+ }
+ ],
+ "lock": true,
+ "variables": {
+ "eval_app": "{software.config_dir}/ethos-u-inference_runner.axf"
+ }
+ }
+]
diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json.license b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf
new file mode 100644
index 0000000..e3eab97
--- /dev/null
+++ b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf
Binary files differ
diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf.license b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf.license
new file mode 100644
index 0000000..8896f92
--- /dev/null
+++ b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf.license
@@ -0,0 +1,31 @@
+SPDX-FileCopyrightText: Copyright 2009-2022, Arm Limited and/or its affiliates.
+SPDX-FileCopyrightText: Copyright (c) 2006, 2008 Junio C Hamano
+SPDX-FileCopyrightText: Copyright 2011, The Dojo Foundation
+SPDX-FileCopyrightText: Copyright (c) 1999-2009 KEIL, 2009-2016 ARM Germany GmbH. All rights reserved.
+SPDX-FileCopyrightText: Copyright (c) 2006, 2007 CodeSourcery Inc
+SPDX-FileCopyrightText: Copyright (c) 2010 "Cowboy" Ben Alman
+SPDX-FileCopyrightText: Copyright 1997-2016 Freescale Semiconductor, Inc.
+SPDX-FileCopyrightText: Copyright 2011, AUTHORS.txt (http://jqueryui.com/about)
+SPDX-FileCopyrightText: Copyright 2011, John Resig
+SPDX-FileCopyrightText: Copyright 2016-2021 NXP
+SPDX-FileCopyrightText: Copyright (c) 2012 mbed.org
+SPDX-FileCopyrightText: Copyright (c) 2012-2017 Keil Software. All rights reserved.
+SPDX-FileCopyrightText: Copyright (C) 2009 by Dimitri van Heesch.
+SPDX-FileCopyrightText: Copyright (c) 2017-2021 IAR Systems
+SPDX-FileCopyrightText: Copyright (C) 2003-2008 Greg Valure
+SPDX-FileCopyrightText: Copyright 2015 gRPC authors.
+SPDX-FileCopyrightText: Copyright 2018 Dan Field. All rights reserved.
+SPDX-FileCopyrightText: Copyright 2014 Stefan.Eilemann@epfl.ch
+SPDX-FileCopyrightText: Copyright (c) 2016, the Dart project authors
+SPDX-FileCopyrightText: Copyright 2015 The Chromium Authors
+SPDX-FileCopyrightText: Copyright (C) 2019 Free Software Foundation, Inc.
+SPDX-FileCopyrightText: Copyright (c) 2021, Vasyl Gello.
+SPDX-FileCopyrightText: Copyright 2020 Jan Tojnar
+SPDX-FileCopyrightText: Copyright 2017-2022 The TensorFlow Authors. All Rights Reserved.
+SPDX-FileCopyrightText: Copyright 2014-2022 Google Inc. All rights reserved.
+SPDX-FileCopyrightText: Copyright 2015-2018 The Gemmlowp Authors. All rights reserved.
+SPDX-FileCopyrightText: Copyright (c) 2003-2019, Mark Borgerding. All rights reserved.
+SPDX-FileCopyrightText: Copyright 2019-2022 The Pigweed Authors
+SPDX-FileCopyrightText: Copyright 2019-2021 Google LLC. All Rights Reserved.
+
+SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause and CC-PDDC
diff --git a/src/mlia/resources/aiet/systems/SYSTEMS.txt b/src/mlia/resources/aiet/systems/SYSTEMS.txt
new file mode 100644
index 0000000..bc27e73
--- /dev/null
+++ b/src/mlia/resources/aiet/systems/SYSTEMS.txt
@@ -0,0 +1,10 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+SPDX-License-Identifier: Apache-2.0
+
+This directory contains the configuration files of the systems for the AIET
+middleware.
+
+Supported systems:
+
+ * FVP Corstone-300 Ecosystem
+ * FVP Corstone-310 Ecosystem
diff --git a/src/mlia/resources/aiet/systems/corstone-300-vht/aiet-config.json b/src/mlia/resources/aiet/systems/corstone-300-vht/aiet-config.json
new file mode 100644
index 0000000..3ffa548
--- /dev/null
+++ b/src/mlia/resources/aiet/systems/corstone-300-vht/aiet-config.json
@@ -0,0 +1,80 @@
+[
+ {
+ "name": "Corstone-300: Cortex-M55+Ethos-U55",
+ "description": "Cortex-M55 and Ethos-U55 functional model implementations based on Corstone-300 design for MPS3 board.",
+ "annotations": {
+ "ip_class": "Ethos-U55",
+ "sim_type": "FM",
+ "variant": "Cortex-M55+Ethos-U55"
+ },
+ "data_transfer": {
+ "protocol": "local"
+ },
+ "lock": true,
+ "commands": {
+ "run": [
+ "/opt/VHT/VHT_Corstone_SSE-300_Ethos-U55 -a {software.variables:eval_app} {user_params:input_file}@0x90000000 -C {user_params:mac} -C mps3_board.telnetterminal0.start_telnet=0 -C mps3_board.uart0.out_file='-' -C mps3_board.uart0.shutdown_on_eot=1 -C mps3_board.visualisation.disable-visualisation=1 --stat"
+ ]
+ },
+ "user_params": {
+ "run": [
+ {
+ "name": "--data",
+ "description": "Full file name for a custom model. Model must be in TFLite format compiled with Vela.",
+ "values": [],
+ "alias": "input_file"
+ },
+ {
+ "name": "ethosu.num_macs=",
+ "description": "Arm Ethos-U55 configuration - the number represents MACs per cycle.",
+ "values": [
+ "32",
+ "64",
+ "128",
+ "256"
+ ],
+ "default_value": "256",
+ "alias": "mac"
+ }
+ ]
+ }
+ },
+ {
+ "name": "Corstone-300: Cortex-M55+Ethos-U65",
+ "description": "Cortex-M55 and Ethos-U65 functional model implementations based on Corstone-300 design for MPS3 board.",
+ "annotations": {
+ "ip_class": "Ethos-U65",
+ "sim_type": "FM",
+ "variant": "Cortex-M55+Ethos-U65"
+ },
+ "data_transfer": {
+ "protocol": "local"
+ },
+ "lock": true,
+ "commands": {
+ "run": [
+ "/opt/VHT/VHT_Corstone_SSE-300_Ethos-U65 -a {software.variables:eval_app} {user_params:input_file}@0x90000000 -C {user_params:mac} -C mps3_board.telnetterminal0.start_telnet=0 -C mps3_board.uart0.out_file='-' -C mps3_board.uart0.shutdown_on_eot=1 -C mps3_board.visualisation.disable-visualisation=1 --stat"
+ ]
+ },
+ "user_params": {
+ "run": [
+ {
+ "name": "--data",
+ "description": "Full file name for a custom model. Model must be in TFLite format compiled with Vela.",
+ "values": [],
+ "alias": "input_file"
+ },
+ {
+ "name": "ethosu.num_macs=",
+ "description": "Arm Ethos-U65 configuration - the number represents MACs per cycle.",
+ "values": [
+ "256",
+ "512"
+ ],
+ "default_value": "512",
+ "alias": "mac"
+ }
+ ]
+ }
+ }
+]
diff --git a/src/mlia/resources/aiet/systems/corstone-300-vht/aiet-config.json.license b/src/mlia/resources/aiet/systems/corstone-300-vht/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/src/mlia/resources/aiet/systems/corstone-300-vht/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/src/mlia/resources/aiet/systems/corstone-300/aiet-config.json b/src/mlia/resources/aiet/systems/corstone-300/aiet-config.json
new file mode 100644
index 0000000..6d6785d
--- /dev/null
+++ b/src/mlia/resources/aiet/systems/corstone-300/aiet-config.json
@@ -0,0 +1,80 @@
+[
+ {
+ "name": "Corstone-300: Cortex-M55+Ethos-U55",
+ "description": "Cortex-M55 and Ethos-U55 functional model implementations based on Corstone-300 design for MPS3 board.",
+ "annotations": {
+ "ip_class": "Ethos-U55",
+ "sim_type": "FM",
+ "variant": "Cortex-M55+Ethos-U55"
+ },
+ "data_transfer": {
+ "protocol": "local"
+ },
+ "lock": true,
+ "commands": {
+ "run": [
+ "FVP_Corstone_SSE-300_Ethos-U55 -a {software.variables:eval_app} {user_params:input_file}@0x90000000 -C {user_params:mac} -C mps3_board.telnetterminal0.start_telnet=0 -C mps3_board.uart0.out_file='-' -C mps3_board.uart0.shutdown_on_eot=1 -C mps3_board.visualisation.disable-visualisation=1 --stat"
+ ]
+ },
+ "user_params": {
+ "run": [
+ {
+ "name": "--data",
+ "description": "Full file name for a custom model. Model must be in TFLite format compiled with Vela.",
+ "values": [],
+ "alias": "input_file"
+ },
+ {
+ "name": "ethosu.num_macs=",
+ "description": "Arm Ethos-U55 configuration - the number represents MACs per cycle.",
+ "values": [
+ "32",
+ "64",
+ "128",
+ "256"
+ ],
+ "default_value": "256",
+ "alias": "mac"
+ }
+ ]
+ }
+ },
+ {
+ "name": "Corstone-300: Cortex-M55+Ethos-U65",
+ "description": "Cortex-M55 and Ethos-U65 functional model implementations based on Corstone-300 design for MPS3 board.",
+ "annotations": {
+ "ip_class": "Ethos-U65",
+ "sim_type": "FM",
+ "variant": "Cortex-M55+Ethos-U65"
+ },
+ "data_transfer": {
+ "protocol": "local"
+ },
+ "lock": true,
+ "commands": {
+ "run": [
+ "FVP_Corstone_SSE-300_Ethos-U65 -a {software.variables:eval_app} {user_params:input_file}@0x90000000 -C {user_params:mac} -C mps3_board.telnetterminal0.start_telnet=0 -C mps3_board.uart0.out_file='-' -C mps3_board.uart0.shutdown_on_eot=1 -C mps3_board.visualisation.disable-visualisation=1 --stat"
+ ]
+ },
+ "user_params": {
+ "run": [
+ {
+ "name": "--data",
+ "description": "Full file name for a custom model. Model must be in TFLite format compiled with Vela.",
+ "values": [],
+ "alias": "input_file"
+ },
+ {
+ "name": "ethosu.num_macs=",
+ "description": "Arm Ethos-U65 configuration - the number represents MACs per cycle.",
+ "values": [
+ "256",
+ "512"
+ ],
+ "default_value": "512",
+ "alias": "mac"
+ }
+ ]
+ }
+ }
+]
diff --git a/src/mlia/resources/aiet/systems/corstone-300/aiet-config.json.license b/src/mlia/resources/aiet/systems/corstone-300/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/src/mlia/resources/aiet/systems/corstone-300/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/src/mlia/resources/aiet/systems/corstone-310-vht/aiet-config.json b/src/mlia/resources/aiet/systems/corstone-310-vht/aiet-config.json
new file mode 100644
index 0000000..dbc2622
--- /dev/null
+++ b/src/mlia/resources/aiet/systems/corstone-310-vht/aiet-config.json
@@ -0,0 +1,42 @@
+[
+ {
+ "name": "Corstone-310: Cortex-M85+Ethos-U55",
+ "description": "Cortex-M85 and Ethos-U55 functional model implementations based on Corstone-310 design for MPS3 board.",
+ "annotations": {
+ "ip_class": "Ethos-U55",
+ "sim_type": "FM",
+ "variant": "Cortex-M85+Ethos-U55"
+ },
+ "data_transfer": {
+ "protocol": "local"
+ },
+ "lock": true,
+ "commands": {
+ "run": [
+ "/opt/VHT/VHT_Corstone_SSE-310 -a {software.variables:eval_app} {user_params:input_file}@0x90000000 -C {user_params:mac} -C mps3_board.telnetterminal0.start_telnet=0 -C mps3_board.uart0.out_file='-' -C mps3_board.uart0.shutdown_on_eot=1 -C mps3_board.visualisation.disable-visualisation=1 --stat"
+ ]
+ },
+ "user_params": {
+ "run": [
+ {
+ "name": "--data",
+ "description": "Full file name for a custom model. Model must be in TFLite format compiled with Vela.",
+ "values": [],
+ "alias": "input_file"
+ },
+ {
+ "name": "ethosu.num_macs=",
+ "description": "Arm Ethos-U55 configuration - the number represents MACs per cycle.",
+ "values": [
+ "32",
+ "64",
+ "128",
+ "256"
+ ],
+ "default_value": "256",
+ "alias": "mac"
+ }
+ ]
+ }
+ }
+]
diff --git a/src/mlia/resources/aiet/systems/corstone-310-vht/aiet-config.json.license b/src/mlia/resources/aiet/systems/corstone-310-vht/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/src/mlia/resources/aiet/systems/corstone-310-vht/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/src/mlia/resources/aiet/systems/corstone-310/aiet-config.json b/src/mlia/resources/aiet/systems/corstone-310/aiet-config.json
new file mode 100644
index 0000000..7aa3b0a
--- /dev/null
+++ b/src/mlia/resources/aiet/systems/corstone-310/aiet-config.json
@@ -0,0 +1,42 @@
+[
+ {
+ "name": "Corstone-310: Cortex-M85+Ethos-U55",
+ "description": "Cortex-M85 and Ethos-U55 functional model implementations based on Corstone-310 design for MPS3 board.",
+ "annotations": {
+ "ip_class": "Ethos-U55",
+ "sim_type": "FM",
+ "variant": "Cortex-M85+Ethos-U55"
+ },
+ "data_transfer": {
+ "protocol": "local"
+ },
+ "lock": true,
+ "commands": {
+ "run": [
+ "FVP_Corstone_SSE-310 -a {software.variables:eval_app} {user_params:input_file}@0x90000000 -C {user_params:mac} -C mps3_board.telnetterminal0.start_telnet=0 -C mps3_board.uart0.out_file='-' -C mps3_board.uart0.shutdown_on_eot=1 -C mps3_board.visualisation.disable-visualisation=1 --stat"
+ ]
+ },
+ "user_params": {
+ "run": [
+ {
+ "name": "--data",
+ "description": "Full file name for a custom model. Model must be in TFLite format compiled with Vela.",
+ "values": [],
+ "alias": "input_file"
+ },
+ {
+ "name": "ethosu.num_macs=",
+ "description": "Arm Ethos-U55 configuration - the number represents MACs per cycle.",
+ "values": [
+ "32",
+ "64",
+ "128",
+ "256"
+ ],
+ "default_value": "256",
+ "alias": "mac"
+ }
+ ]
+ }
+ }
+]
diff --git a/src/mlia/resources/aiet/systems/corstone-310/aiet-config.json.license b/src/mlia/resources/aiet/systems/corstone-310/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/src/mlia/resources/aiet/systems/corstone-310/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/src/mlia/resources/profiles.json b/src/mlia/resources/profiles.json
new file mode 100644
index 0000000..4493d7b
--- /dev/null
+++ b/src/mlia/resources/profiles.json
@@ -0,0 +1,20 @@
+{
+ "ethos-u55-256": {
+ "target": "ethos-u55",
+ "mac": 256,
+ "system_config": "Ethos_U55_High_End_Embedded",
+ "memory_mode": "Shared_Sram"
+ },
+ "ethos-u55-128": {
+ "target": "ethos-u55",
+ "mac": 128,
+ "system_config": "Ethos_U55_High_End_Embedded",
+ "memory_mode": "Shared_Sram"
+ },
+ "ethos-u65-512": {
+ "target": "ethos-u65",
+ "mac": 512,
+ "system_config": "Ethos_U65_High_End",
+ "memory_mode": "Dedicated_Sram"
+ }
+}
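Editor's note (not part of the diff): the target profiles above are plain JSON, so as a rough, hedged illustration they could be resolved as follows. The resource path and helper name are assumptions made for the example only.

    import json
    from pathlib import Path

    def load_target_profile(profiles_path: Path, name: str) -> dict:
        """Return the named target profile, e.g. 'ethos-u65-512' (illustrative helper)."""
        with open(profiles_path, encoding="utf-8") as file:
            return json.load(file)[name]

    # profile = load_target_profile(Path("src/mlia/resources/profiles.json"), "ethos-u65-512")
    # profile["target"] == "ethos-u65", profile["mac"] == 512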
diff --git a/src/mlia/resources/profiles.json.license b/src/mlia/resources/profiles.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/src/mlia/resources/profiles.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/src/mlia/resources/vela/vela.ini b/src/mlia/resources/vela/vela.ini
new file mode 100644
index 0000000..382820d
--- /dev/null
+++ b/src/mlia/resources/vela/vela.ini
@@ -0,0 +1,75 @@
+; SPDX-FileCopyrightText: Copyright 2020, 2022, Arm Limited and/or its affiliates.
+; SPDX-License-Identifier: Apache-2.0
+
+; -----------------------------------------------------------------------------
+; Vela configuration file
+; -----------------------------------------------------------------------------
+
+; System Configuration
+
+; Ethos-U55 High-End Embedded: SRAM (4 GB/s) and Flash (0.5 GB/s)
+[System_Config.Ethos_U55_High_End_Embedded]
+core_clock=500e6
+axi0_port=Sram
+axi1_port=OffChipFlash
+Sram_clock_scale=1.0
+Sram_burst_length=32
+Sram_read_latency=32
+Sram_write_latency=32
+OffChipFlash_clock_scale=0.125
+OffChipFlash_burst_length=128
+OffChipFlash_read_latency=64
+OffChipFlash_write_latency=64
+
+; Ethos-U65 Embedded: SRAM (8 GB/s) and Flash (0.5 GB/s)
+[System_Config.Ethos_U65_Embedded]
+core_clock=500e6
+axi0_port=Sram
+axi1_port=OffChipFlash
+Sram_clock_scale=1.0
+Sram_burst_length=32
+Sram_read_latency=32
+Sram_write_latency=32
+OffChipFlash_clock_scale=0.0625
+OffChipFlash_burst_length=128
+OffChipFlash_read_latency=64
+OffChipFlash_write_latency=64
+
+; Ethos-U65 High-End: SRAM (16 GB/s) and DRAM (3.75 GB/s)
+[System_Config.Ethos_U65_High_End]
+core_clock=1e9
+axi0_port=Sram
+axi1_port=Dram
+Sram_clock_scale=1.0
+Sram_burst_length=32
+Sram_read_latency=32
+Sram_write_latency=32
+Dram_clock_scale=0.234375
+Dram_burst_length=128
+Dram_read_latency=500
+Dram_write_latency=250
+
+; -----------------------------------------------------------------------------
+
+; Memory Mode
+
+; SRAM Only: only one AXI port is used and the SRAM is used for all storage
+[Memory_Mode.Sram_Only]
+const_mem_area=Axi0
+arena_mem_area=Axi0
+cache_mem_area=Axi0
+
+; Shared SRAM: the SRAM is shared between the Ethos-U and the Cortex-M software
+; The non-SRAM memory is assumed to be read-only
+[Memory_Mode.Shared_Sram]
+const_mem_area=Axi1
+arena_mem_area=Axi0
+cache_mem_area=Axi0
+
+; Dedicated SRAM: the SRAM (384KB) is only for use by the Ethos-U
+; The non-SRAM memory is assumed to be read-writeable
+[Memory_Mode.Dedicated_Sram]
+const_mem_area=Axi1
+arena_mem_area=Axi1
+cache_mem_area=Axi0
+arena_cache_size=393216
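Editor's note (not part of the diff): vela.ini uses standard INI syntax, so the sections above can be inspected with Python's configparser; this is only a sketch and the relative path is an assumption. Note that arena_cache_size=393216 is exactly 384 * 1024 bytes, matching the 384KB dedicated SRAM mentioned in the comment.

    import configparser

    config = configparser.ConfigParser()
    config.read("src/mlia/resources/vela/vela.ini")  # path is an assumption for this sketch

    dedicated = config["Memory_Mode.Dedicated_Sram"]
    assert int(dedicated["arena_cache_size"]) == 384 * 1024  # 393216 bytes == 384KB
    print(config["System_Config.Ethos_U65_High_End"]["core_clock"])  # "1e9"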
diff --git a/src/mlia/tools/__init__.py b/src/mlia/tools/__init__.py
new file mode 100644
index 0000000..184e966
--- /dev/null
+++ b/src/mlia/tools/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tools module."""
diff --git a/src/mlia/tools/aiet_wrapper.py b/src/mlia/tools/aiet_wrapper.py
new file mode 100644
index 0000000..73e82ee
--- /dev/null
+++ b/src/mlia/tools/aiet_wrapper.py
@@ -0,0 +1,435 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for AIET integration."""
+import logging
+import re
+from abc import ABC
+from abc import abstractmethod
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Literal
+from typing import Optional
+from typing import Tuple
+
+from aiet.backend.application import get_available_applications
+from aiet.backend.application import install_application
+from aiet.backend.system import get_available_systems
+from aiet.backend.system import install_system
+from mlia.utils.proc import CommandExecutor
+from mlia.utils.proc import OutputConsumer
+from mlia.utils.proc import RunningCommand
+
+
+logger = logging.getLogger(__name__)
+
+# Mapping backend -> device_type -> system_name
+_SUPPORTED_SYSTEMS = {
+ "Corstone-300": {
+ "ethos-u55": "Corstone-300: Cortex-M55+Ethos-U55",
+ "ethos-u65": "Corstone-300: Cortex-M55+Ethos-U65",
+ },
+ "Corstone-310": {
+ "ethos-u55": "Corstone-310: Cortex-M85+Ethos-U55",
+ },
+}
+
+# Mapping system_name -> memory_mode -> application
+_SYSTEM_TO_APP_MAP = {
+ "Corstone-300: Cortex-M55+Ethos-U55": {
+ "Sram": "Generic Inference Runner: Ethos-U55 SRAM",
+ "Shared_Sram": "Generic Inference Runner: Ethos-U55/65 Shared SRAM",
+ },
+ "Corstone-300: Cortex-M55+Ethos-U65": {
+ "Shared_Sram": "Generic Inference Runner: Ethos-U55/65 Shared SRAM",
+ "Dedicated_Sram": "Generic Inference Runner: Ethos-U65 Dedicated SRAM",
+ },
+ "Corstone-310: Cortex-M85+Ethos-U55": {
+ "Sram": "Generic Inference Runner: Ethos-U55 SRAM",
+ "Shared_Sram": "Generic Inference Runner: Ethos-U55/65 Shared SRAM",
+ },
+}
+
+
+def get_system_name(backend: str, device_type: str) -> str:
+ """Get the AIET system name for the given backend and device type."""
+ return _SUPPORTED_SYSTEMS[backend][device_type]
+
+
+def is_supported(backend: str, device_type: Optional[str] = None) -> bool:
+ """Check if the backend (and optionally device type) is supported."""
+ if device_type is None:
+ return backend in _SUPPORTED_SYSTEMS
+
+ try:
+ get_system_name(backend, device_type)
+ return True
+ except KeyError:
+ return False
+
+
+def supported_backends() -> List[str]:
+ """Get a list of all backends supported by the AIET wrapper."""
+ return list(_SUPPORTED_SYSTEMS.keys())
+
+
+def get_all_system_names(backend: str) -> List[str]:
+ """Get all systems supported by the backend."""
+ return list(_SUPPORTED_SYSTEMS.get(backend, {}).values())
+
+
+def get_all_application_names(backend: str) -> List[str]:
+ """Get all applications supported by the backend."""
+ app_set = {
+ app
+ for sys in get_all_system_names(backend)
+ for app in _SYSTEM_TO_APP_MAP[sys].values()
+ }
+ return list(app_set)
+
+
+@dataclass
+class DeviceInfo:
+ """Device information."""
+
+ device_type: Literal["ethos-u55", "ethos-u65"]
+ mac: int
+ memory_mode: Literal["Sram", "Shared_Sram", "Dedicated_Sram"]
+
+
+@dataclass
+class ModelInfo:
+ """Model info."""
+
+ model_path: Path
+
+
+@dataclass
+class PerformanceMetrics:
+ """Performance metrics parsed from generic inference output."""
+
+ npu_active_cycles: int
+ npu_idle_cycles: int
+ npu_total_cycles: int
+ npu_axi0_rd_data_beat_received: int
+ npu_axi0_wr_data_beat_written: int
+ npu_axi1_rd_data_beat_received: int
+
+
+@dataclass
+class ExecutionParams:
+ """Application execution params."""
+
+ application: str
+ system: str
+ application_params: List[str]
+ system_params: List[str]
+ deploy_params: List[str]
+
+
+class AIETLogWriter(OutputConsumer):
+ """Redirect AIET command output to the logger."""
+
+ def feed(self, line: str) -> None:
+ """Process line from the output."""
+ logger.debug(line.strip())
+
+
+class GenericInferenceOutputParser(OutputConsumer):
+ """Generic inference app output parser."""
+
+ PATTERNS = {
+ name: tuple(re.compile(pattern, re.IGNORECASE) for pattern in patterns)
+ for name, patterns in (
+ (
+ "npu_active_cycles",
+ (
+ r"NPU ACTIVE cycles: (?P<value>\d+)",
+ r"NPU ACTIVE: (?P<value>\d+) cycles",
+ ),
+ ),
+ (
+ "npu_idle_cycles",
+ (
+ r"NPU IDLE cycles: (?P<value>\d+)",
+ r"NPU IDLE: (?P<value>\d+) cycles",
+ ),
+ ),
+ (
+ "npu_total_cycles",
+ (
+ r"NPU TOTAL cycles: (?P<value>\d+)",
+ r"NPU TOTAL: (?P<value>\d+) cycles",
+ ),
+ ),
+ (
+ "npu_axi0_rd_data_beat_received",
+ (
+ r"NPU AXI0_RD_DATA_BEAT_RECEIVED beats: (?P<value>\d+)",
+ r"NPU AXI0_RD_DATA_BEAT_RECEIVED: (?P<value>\d+) beats",
+ ),
+ ),
+ (
+ "npu_axi0_wr_data_beat_written",
+ (
+ r"NPU AXI0_WR_DATA_BEAT_WRITTEN beats: (?P<value>\d+)",
+ r"NPU AXI0_WR_DATA_BEAT_WRITTEN: (?P<value>\d+) beats",
+ ),
+ ),
+ (
+ "npu_axi1_rd_data_beat_received",
+ (
+ r"NPU AXI1_RD_DATA_BEAT_RECEIVED beats: (?P<value>\d+)",
+ r"NPU AXI1_RD_DATA_BEAT_RECEIVED: (?P<value>\d+) beats",
+ ),
+ ),
+ )
+ }
+
+ def __init__(self) -> None:
+ """Init generic inference output parser instance."""
+ self.result: Dict = {}
+
+ def feed(self, line: str) -> None:
+ """Feed new line to the parser."""
+ for name, patterns in self.PATTERNS.items():
+ for pattern in patterns:
+ match = pattern.search(line)
+
+ if match:
+ self.result[name] = int(match["value"])
+ return
+
+ def is_ready(self) -> bool:
+ """Return true if all expected data has been parsed."""
+ return self.result.keys() == self.PATTERNS.keys()
+
+ def missed_keys(self) -> List[str]:
+ """Return list of the keys that have not been found in the output."""
+ return sorted(self.PATTERNS.keys() - self.result.keys())
+
+
+class AIETRunner:
+ """AIET runner."""
+
+ def __init__(self, executor: CommandExecutor) -> None:
+ """Init AIET runner instance."""
+ self.executor = executor
+
+ @staticmethod
+ def get_installed_systems() -> List[str]:
+ """Get list of the installed systems."""
+ return [system.name for system in get_available_systems()]
+
+ @staticmethod
+ def get_installed_applications(system: Optional[str] = None) -> List[str]:
+ """Get list of the installed application."""
+ return [
+ app.name
+ for app in get_available_applications()
+ if system is None or app.can_run_on(system)
+ ]
+
+ def is_application_installed(self, application: str, system: str) -> bool:
+ """Return true if requested application installed."""
+ return application in self.get_installed_applications(system)
+
+ def is_system_installed(self, system: str) -> bool:
+ """Return true if requested system installed."""
+ return system in self.get_installed_systems()
+
+ def systems_installed(self, systems: List[str]) -> bool:
+ """Check if all provided systems are installed."""
+ if not systems:
+ return False
+
+ installed_systems = self.get_installed_systems()
+ return all(system in installed_systems for system in systems)
+
+ def applications_installed(self, applications: List[str]) -> bool:
+ """Check if all provided applications are installed."""
+ if not applications:
+ return False
+
+ installed_apps = self.get_installed_applications()
+ return all(app in installed_apps for app in applications)
+
+ def all_installed(self, systems: List[str], apps: List[str]) -> bool:
+ """Check if all provided artifacts are installed."""
+ return self.systems_installed(systems) and self.applications_installed(apps)
+
+ @staticmethod
+ def install_system(system_path: Path) -> None:
+ """Install system."""
+ install_system(system_path)
+
+ @staticmethod
+ def install_application(app_path: Path) -> None:
+ """Install application."""
+ install_application(app_path)
+
+ def run_application(self, execution_params: ExecutionParams) -> RunningCommand:
+ """Run requested application."""
+ command = [
+ "aiet",
+ "application",
+ "run",
+ "-n",
+ execution_params.application,
+ "-s",
+ execution_params.system,
+ *self._params("-p", execution_params.application_params),
+ *self._params("--system-param", execution_params.system_params),
+ *self._params("--deploy", execution_params.deploy_params),
+ ]
+
+ return self._submit(command)
+
+ @staticmethod
+ def _params(name: str, params: List[str]) -> List[str]:
+ return [p for item in [(name, param) for param in params] for p in item]
+
+ def _submit(self, command: List[str]) -> RunningCommand:
+ """Submit command for the execution."""
+ logger.debug("Submit command %s", " ".join(command))
+ return self.executor.submit(command)
+
+
+class GenericInferenceRunner(ABC):
+ """Abstract class for generic inference runner."""
+
+ def __init__(self, aiet_runner: AIETRunner):
+ """Init generic inference runner instance."""
+ self.aiet_runner = aiet_runner
+ self.running_inference: Optional[RunningCommand] = None
+
+ def run(
+ self, model_info: ModelInfo, output_consumers: List[OutputConsumer]
+ ) -> None:
+ """Run generic inference for the provided device/model."""
+ execution_params = self.get_execution_params(model_info)
+
+ self.running_inference = self.aiet_runner.run_application(execution_params)
+ self.running_inference.output_consumers = output_consumers
+ self.running_inference.consume_output()
+
+ def stop(self) -> None:
+ """Stop running inference."""
+ if self.running_inference is None:
+ return
+
+ self.running_inference.stop()
+
+ @abstractmethod
+ def get_execution_params(self, model_info: ModelInfo) -> ExecutionParams:
+ """Get execution params for the provided model."""
+
+ def __enter__(self) -> "GenericInferenceRunner":
+ """Enter context."""
+ return self
+
+ def __exit__(self, *_args: Any) -> None:
+ """Exit context."""
+ self.stop()
+
+ def check_system_and_application(self, system_name: str, app_name: str) -> None:
+ """Check if requested system and application installed."""
+ if not self.aiet_runner.is_system_installed(system_name):
+ raise Exception(f"System {system_name} is not installed")
+
+ if not self.aiet_runner.is_application_installed(app_name, system_name):
+ raise Exception(
+ f"Application {app_name} for the system {system_name} "
+ "is not installed"
+ )
+
+
+class GenericInferenceRunnerEthosU(GenericInferenceRunner):
+ """Generic inference runner on U55/65."""
+
+ def __init__(
+ self, aiet_runner: AIETRunner, device_info: DeviceInfo, backend: str
+ ) -> None:
+ """Init generic inference runner instance."""
+ super().__init__(aiet_runner)
+
+ system_name, app_name = self.resolve_system_and_app(device_info, backend)
+ self.system_name = system_name
+ self.app_name = app_name
+ self.device_info = device_info
+
+ @staticmethod
+ def resolve_system_and_app(
+ device_info: DeviceInfo, backend: str
+ ) -> Tuple[str, str]:
+ """Find appropriate system and application for the provided device/backend."""
+ try:
+ system_name = get_system_name(backend, device_info.device_type)
+ except KeyError as ex:
+ raise RuntimeError(
+ f"Unsupported device {device_info.device_type} "
+ f"for backend {backend}"
+ ) from ex
+
+ if system_name not in _SYSTEM_TO_APP_MAP:
+ raise RuntimeError(f"System {system_name} is not installed")
+
+ try:
+ app_name = _SYSTEM_TO_APP_MAP[system_name][device_info.memory_mode]
+ except KeyError as err:
+ raise RuntimeError(
+ f"Unsupported memory mode {device_info.memory_mode}"
+ ) from err
+
+ return system_name, app_name
+
+ def get_execution_params(self, model_info: ModelInfo) -> ExecutionParams:
+ """Get execution params for Ethos-U55/65."""
+ self.check_system_and_application(self.system_name, self.app_name)
+
+ system_params = [
+ f"mac={self.device_info.mac}",
+ f"input_file={model_info.model_path.absolute()}",
+ ]
+
+ return ExecutionParams(
+ self.app_name,
+ self.system_name,
+ [],
+ system_params,
+ [],
+ )
+
+
+def get_generic_runner(device_info: DeviceInfo, backend: str) -> GenericInferenceRunner:
+ """Get generic runner for provided device and backend."""
+ aiet_runner = get_aiet_runner()
+ return GenericInferenceRunnerEthosU(aiet_runner, device_info, backend)
+
+
+def estimate_performance(
+ model_info: ModelInfo, device_info: DeviceInfo, backend: str
+) -> PerformanceMetrics:
+ """Get performance estimations."""
+ with get_generic_runner(device_info, backend) as generic_runner:
+ output_parser = GenericInferenceOutputParser()
+ output_consumers = [output_parser, AIETLogWriter()]
+
+ generic_runner.run(model_info, output_consumers)
+
+ if not output_parser.is_ready():
+ missed_data = ",".join(output_parser.missed_keys())
+ logger.debug(
+ "Unable to get performance metrics, missed data %s", missed_data
+ )
+ raise Exception("Unable to get performance metrics, insufficient data")
+
+ return PerformanceMetrics(**output_parser.result)
+
+
+def get_aiet_runner() -> AIETRunner:
+ """Return AIET runner."""
+ executor = CommandExecutor()
+ return AIETRunner(executor)
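Editor's note (not part of the diff): a hedged, end-to-end usage sketch of the module above. It assumes the matching AIET system and inference-runner application are already installed; the model path is an example only.

    from pathlib import Path

    from mlia.tools.aiet_wrapper import DeviceInfo, ModelInfo, estimate_performance

    device = DeviceInfo(device_type="ethos-u55", mac=256, memory_mode="Shared_Sram")
    model = ModelInfo(model_path=Path("model_vela.tflite"))  # hypothetical Vela-compiled model

    metrics = estimate_performance(model, device, backend="Corstone-300")
    print(metrics.npu_total_cycles, metrics.npu_active_cycles)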
diff --git a/src/mlia/tools/metadata/__init__.py b/src/mlia/tools/metadata/__init__.py
new file mode 100644
index 0000000..f877e4f
--- /dev/null
+++ b/src/mlia/tools/metadata/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for the tools metadata."""
diff --git a/src/mlia/tools/metadata/common.py b/src/mlia/tools/metadata/common.py
new file mode 100644
index 0000000..c17a738
--- /dev/null
+++ b/src/mlia/tools/metadata/common.py
@@ -0,0 +1,290 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for installation process."""
+import logging
+from abc import ABC
+from abc import abstractmethod
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Callable
+from typing import List
+from typing import Optional
+from typing import Union
+
+from mlia.utils.misc import yes
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class InstallFromPath:
+ """Installation from the local path."""
+
+ backend_path: Path
+
+
+@dataclass
+class DownloadAndInstall:
+ """Download and install."""
+
+ eula_agreement: bool = True
+
+
+InstallationType = Union[InstallFromPath, DownloadAndInstall]
+
+
+class Installation(ABC):
+ """Base class for the installation process of the backends."""
+
+ @property
+ @abstractmethod
+ def name(self) -> str:
+ """Return name of the backend."""
+
+ @property
+ @abstractmethod
+ def description(self) -> str:
+ """Return description of the backend."""
+
+ @property
+ @abstractmethod
+ def could_be_installed(self) -> bool:
+ """Return true if backend could be installed in current environment."""
+
+ @property
+ @abstractmethod
+ def already_installed(self) -> bool:
+ """Return true if backend is already installed."""
+
+ @abstractmethod
+ def supports(self, install_type: InstallationType) -> bool:
+ """Return true if installation supports requested installation type."""
+
+ @abstractmethod
+ def install(self, install_type: InstallationType) -> None:
+ """Install the backend."""
+
+
+InstallationFilter = Callable[[Installation], bool]
+
+
+class AlreadyInstalledFilter:
+ """Filter for already installed backends."""
+
+ def __call__(self, installation: Installation) -> bool:
+ """Installation filter."""
+ return installation.already_installed
+
+
+class ReadyForInstallationFilter:
+ """Filter for ready to be installed backends."""
+
+ def __call__(self, installation: Installation) -> bool:
+ """Installation filter."""
+ return installation.could_be_installed and not installation.already_installed
+
+
+class SupportsInstallTypeFilter:
+ """Filter backends that support certain type of the installation."""
+
+ def __init__(self, installation_type: InstallationType) -> None:
+ """Init filter."""
+ self.installation_type = installation_type
+
+ def __call__(self, installation: Installation) -> bool:
+ """Installation filter."""
+ return installation.supports(self.installation_type)
+
+
+class SearchByNameFilter:
+ """Filter installation by name."""
+
+ def __init__(self, backend_name: Optional[str]) -> None:
+ """Init filter."""
+ self.backend_name = backend_name
+
+ def __call__(self, installation: Installation) -> bool:
+ """Installation filter."""
+ return not self.backend_name or installation.name == self.backend_name
+
+
+class InstallationManager(ABC):
+ """Helper class for managing installations."""
+
+ @abstractmethod
+ def install_from(self, backend_path: Path, backend_name: Optional[str]) -> None:
+ """Install backend from the local directory."""
+
+ @abstractmethod
+ def download_and_install(
+ self, backend_name: Optional[str], eula_agreement: bool
+ ) -> None:
+ """Download and install backends."""
+
+ @abstractmethod
+ def show_env_details(self) -> None:
+ """Show environment details."""
+
+ @abstractmethod
+ def backend_installed(self, backend_name: str) -> bool:
+ """Return true if requested backend installed."""
+
+
+class InstallationFiltersMixin:
+ """Mixin for filtering installation based on different conditions."""
+
+ installations: List[Installation]
+
+ def filter_by(self, *filters: InstallationFilter) -> List[Installation]:
+ """Filter installations."""
+ return [
+ installation
+ for installation in self.installations
+ if all(filter_(installation) for filter_ in filters)
+ ]
+
+ def could_be_installed_from(
+ self, backend_path: Path, backend_name: Optional[str]
+ ) -> List[Installation]:
+ """Return installations that could be installed from provided directory."""
+ return self.filter_by(
+ SupportsInstallTypeFilter(InstallFromPath(backend_path)),
+ SearchByNameFilter(backend_name),
+ )
+
+ def could_be_downloaded_and_installed(
+ self, backend_name: Optional[str] = None
+ ) -> List[Installation]:
+ """Return installations that could be downloaded and installed."""
+ return self.filter_by(
+ SupportsInstallTypeFilter(DownloadAndInstall()),
+ SearchByNameFilter(backend_name),
+ ReadyForInstallationFilter(),
+ )
+
+ def already_installed(
+ self, backend_name: Optional[str] = None
+ ) -> List[Installation]:
+ """Return list of backends that are already installed."""
+ return self.filter_by(
+ AlreadyInstalledFilter(), SearchByNameFilter(backend_name)
+ )
+
+ def ready_for_installation(self) -> List[Installation]:
+ """Return list of the backends that could be installed."""
+ return self.filter_by(ReadyForInstallationFilter())
+
+
+class DefaultInstallationManager(InstallationManager, InstallationFiltersMixin):
+ """Interactive installation manager."""
+
+ def __init__(
+ self, installations: List[Installation], noninteractive: bool = False
+ ) -> None:
+ """Init the manager."""
+ self.installations = installations
+ self.noninteractive = noninteractive
+
+ def choose_installation_for_path(
+ self, backend_path: Path, backend_name: Optional[str]
+ ) -> Optional[Installation]:
+ """Check available installation and select one if possible."""
+ installs = self.could_be_installed_from(backend_path, backend_name)
+
+ if not installs:
+ logger.info(
+ "Unfortunatelly, it was not possible to automatically "
+ "detect type of the installed FVP. "
+ "Please, check provided path to the installed FVP."
+ )
+ return None
+
+ if len(installs) != 1:
+ names = ",".join((install.name for install in installs))
+ logger.info(
+ "Unable to correctly detect type of the installed FVP."
+ "The following FVPs are detected %s. Installation skipped.",
+ names,
+ )
+ return None
+
+ installation = installs[0]
+ if installation.already_installed:
+ logger.info(
+ "%s was found in %s, but it has been already installed.",
+ installation.name,
+ backend_path,
+ )
+ return None
+
+ return installation
+
+ def install_from(self, backend_path: Path, backend_name: Optional[str]) -> None:
+ """Install from the provided directory."""
+ installation = self.choose_installation_for_path(backend_path, backend_name)
+
+ if not installation:
+ return
+
+ prompt = (
+ f"{installation.name} was found in {backend_path}. "
+ "Would you like to install it?"
+ )
+ self._install(installation, InstallFromPath(backend_path), prompt)
+
+ def download_and_install(
+ self, backend_name: Optional[str] = None, eula_agreement: bool = True
+ ) -> None:
+ """Download and install available backends."""
+ installations = self.could_be_downloaded_and_installed(backend_name)
+
+ if not installations:
+ logger.info("No backends available for the installation.")
+ return
+
+ names = ",".join((installation.name for installation in installations))
+ logger.info("Following backends are available for downloading: %s", names)
+
+ for installation in installations:
+ prompt = f"Would you like to download and install {installation.name}?"
+ self._install(
+ installation, DownloadAndInstall(eula_agreement=eula_agreement), prompt
+ )
+
+ def show_env_details(self) -> None:
+ """Print current state of the execution environment."""
+ if installed := self.already_installed():
+ logger.info("Installed backends:\n")
+
+ for installation in installed:
+ logger.info(" - %s", installation.name)
+
+ if could_be_installed := self.ready_for_installation():
+ logger.info("Following backends could be installed:")
+
+ for installation in could_be_installed:
+ logger.info(" - %s", installation.name)
+
+ if not installed and not could_be_installed:
+ logger.info("No backends installed")
+
+ def _install(
+ self,
+ installation: Installation,
+ installation_type: InstallationType,
+ prompt: str,
+ ) -> None:
+ proceed = self.noninteractive or yes(prompt)
+
+ if proceed:
+ installation.install(installation_type)
+ logger.info("%s successfully installed.", installation.name)
+ else:
+ logger.info("%s installation canceled.", installation.name)
+
+ def backend_installed(self, backend_name: str) -> bool:
+ """Return true if requested backend installed."""
+ installations = self.already_installed(backend_name)
+
+ return len(installations) == 1
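Editor's note (not part of the diff): a minimal sketch of how the manager above might be wired up, assuming the Corstone installations defined in the companion corstone.py module below. With noninteractive=True the yes/no prompt is skipped.

    from mlia.tools.metadata.common import DefaultInstallationManager
    from mlia.tools.metadata.corstone import get_corstone_installations

    manager = DefaultInstallationManager(get_corstone_installations(), noninteractive=True)
    manager.show_env_details()
    manager.download_and_install(backend_name="Corstone-300", eula_agreement=True)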
diff --git a/src/mlia/tools/metadata/corstone.py b/src/mlia/tools/metadata/corstone.py
new file mode 100644
index 0000000..7a9d113
--- /dev/null
+++ b/src/mlia/tools/metadata/corstone.py
@@ -0,0 +1,402 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for Corstone based FVPs."""
+import logging
+import platform
+import subprocess
+import tarfile
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Callable
+from typing import Iterable
+from typing import List
+from typing import Optional
+
+import mlia.tools.aiet_wrapper as aiet
+from mlia.tools.metadata.common import DownloadAndInstall
+from mlia.tools.metadata.common import Installation
+from mlia.tools.metadata.common import InstallationType
+from mlia.tools.metadata.common import InstallFromPath
+from mlia.utils.download import DownloadArtifact
+from mlia.utils.filesystem import all_files_exist
+from mlia.utils.filesystem import all_paths_valid
+from mlia.utils.filesystem import copy_all
+from mlia.utils.filesystem import get_mlia_resources
+from mlia.utils.filesystem import temp_directory
+from mlia.utils.proc import working_directory
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class BackendInfo:
+ """Backend information."""
+
+ backend_path: Path
+ copy_source: bool = True
+ system_config: Optional[str] = None
+
+
+PathChecker = Callable[[Path], Optional[BackendInfo]]
+BackendInstaller = Callable[[bool, Path], Path]
+
+
+class AIETMetadata:
+ """AIET installation metadata."""
+
+ def __init__(
+ self,
+ name: str,
+ description: str,
+ system_config: str,
+ apps_resources: List[str],
+ fvp_dir_name: str,
+ download_artifact: Optional[DownloadArtifact],
+ supported_platforms: Optional[List[str]] = None,
+ ) -> None:
+ """
+ Initialize AIETMetadata.
+
+ Members expected_systems and expected_apps are filled automatically.
+ """
+ self.name = name
+ self.description = description
+ self.system_config = system_config
+ self.apps_resources = apps_resources
+ self.fvp_dir_name = fvp_dir_name
+ self.download_artifact = download_artifact
+ self.supported_platforms = supported_platforms
+
+ self.expected_systems = aiet.get_all_system_names(name)
+ self.expected_apps = aiet.get_all_application_names(name)
+
+ @property
+ def expected_resources(self) -> Iterable[Path]:
+ """Return list of expected resources."""
+ resources = [self.system_config, *self.apps_resources]
+
+ return (get_mlia_resources() / resource for resource in resources)
+
+ @property
+ def supported_platform(self) -> bool:
+ """Return true if current platform supported."""
+ if not self.supported_platforms:
+ return True
+
+ return platform.system() in self.supported_platforms
+
+
+class AIETBasedInstallation(Installation):
+ """Backend installation based on AIET functionality."""
+
+ def __init__(
+ self,
+ aiet_runner: aiet.AIETRunner,
+ metadata: AIETMetadata,
+ path_checker: PathChecker,
+ backend_installer: Optional[BackendInstaller],
+ ) -> None:
+ """Init the tool installation."""
+ self.aiet_runner = aiet_runner
+ self.metadata = metadata
+ self.path_checker = path_checker
+ self.backend_installer = backend_installer
+
+ @property
+ def name(self) -> str:
+ """Return name of the tool."""
+ return self.metadata.name
+
+ @property
+ def description(self) -> str:
+ """Return description of the tool."""
+ return self.metadata.description
+
+ @property
+ def already_installed(self) -> bool:
+ """Return true if tool already installed."""
+ return self.aiet_runner.all_installed(
+ self.metadata.expected_systems, self.metadata.expected_apps
+ )
+
+ @property
+ def could_be_installed(self) -> bool:
+ """Return true if tool could be installed."""
+ if not self.metadata.supported_platform:
+ return False
+
+ return all_paths_valid(self.metadata.expected_resources)
+
+ def supports(self, install_type: InstallationType) -> bool:
+ """Return true if tools supported type of the installation."""
+ if isinstance(install_type, DownloadAndInstall):
+ return self.metadata.download_artifact is not None
+
+ if isinstance(install_type, InstallFromPath):
+ return self.path_checker(install_type.backend_path) is not None
+
+ return False # type: ignore
+
+ def install(self, install_type: InstallationType) -> None:
+ """Install the tool."""
+ if isinstance(install_type, DownloadAndInstall):
+ download_artifact = self.metadata.download_artifact
+ assert download_artifact is not None, "No artifact provided"
+
+ self.download_and_install(download_artifact, install_type.eula_agreement)
+ elif isinstance(install_type, InstallFromPath):
+ backend_path = self.path_checker(install_type.backend_path)
+ assert backend_path is not None, "Unable to resolve backend path"
+
+ self.install_from(backend_path)
+ else:
+ raise Exception(f"Unable to install {install_type}")
+
+ def install_from(self, backend_info: BackendInfo) -> None:
+ """Install tool from the directory."""
+ mlia_resources = get_mlia_resources()
+
+ with temp_directory() as tmpdir:
+ fvp_dist_dir = tmpdir / self.metadata.fvp_dir_name
+
+ system_config = self.metadata.system_config
+ if backend_info.system_config:
+ system_config = backend_info.system_config
+
+ resources_to_copy = [mlia_resources / system_config]
+ if backend_info.copy_source:
+ resources_to_copy.append(backend_info.backend_path)
+
+ copy_all(*resources_to_copy, dest=fvp_dist_dir)
+
+ self.aiet_runner.install_system(fvp_dist_dir)
+
+ for app in self.metadata.apps_resources:
+ self.aiet_runner.install_application(mlia_resources / app)
+
+ def download_and_install(
+ self, download_artifact: DownloadArtifact, eula_agreement: bool
+ ) -> None:
+ """Download and install the tool."""
+ with temp_directory() as tmpdir:
+ try:
+ downloaded_to = download_artifact.download_to(tmpdir)
+ except Exception as err:
+ raise Exception("Unable to download backend artifact") from err
+
+ with working_directory(tmpdir / "dist", create_dir=True) as dist_dir:
+ with tarfile.open(downloaded_to) as archive:
+ archive.extractall(dist_dir)
+
+ assert self.backend_installer, (
+ f"Backend '{self.metadata.name}' does not support "
+ "download and installation."
+ )
+ backend_path = self.backend_installer(eula_agreement, dist_dir)
+ if self.path_checker(backend_path) is None:
+ raise Exception("Downloaded artifact has invalid structure")
+
+ self.install(InstallFromPath(backend_path))
+
+
+class PackagePathChecker:
+ """Package path checker."""
+
+ def __init__(
+ self, expected_files: List[str], backend_subfolder: Optional[str] = None
+ ) -> None:
+ """Init the path checker."""
+ self.expected_files = expected_files
+ self.backend_subfolder = backend_subfolder
+
+ def __call__(self, backend_path: Path) -> Optional[BackendInfo]:
+ """Check if directory contains all expected files."""
+ resolved_paths = (backend_path / file for file in self.expected_files)
+ if not all_files_exist(resolved_paths):
+ return None
+
+ if self.backend_subfolder:
+ subfolder = backend_path / self.backend_subfolder
+
+ if not subfolder.is_dir():
+ return None
+
+ return BackendInfo(subfolder)
+
+ return BackendInfo(backend_path)
+
+
+class StaticPathChecker:
+ """Static path checker."""
+
+ def __init__(
+ self,
+ static_backend_path: Path,
+ expected_files: List[str],
+ copy_source: bool = False,
+ system_config: Optional[str] = None,
+ ) -> None:
+ """Init static path checker."""
+ self.static_backend_path = static_backend_path
+ self.expected_files = expected_files
+ self.copy_source = copy_source
+ self.system_config = system_config
+
+ def __call__(self, backend_path: Path) -> Optional[BackendInfo]:
+ """Check if directory equals static backend path with all expected files."""
+ if backend_path != self.static_backend_path:
+ return None
+
+ resolved_paths = (backend_path / file for file in self.expected_files)
+ if not all_files_exist(resolved_paths):
+ return None
+
+ return BackendInfo(
+ backend_path,
+ copy_source=self.copy_source,
+ system_config=self.system_config,
+ )
+
+
+class CompoundPathChecker:
+ """Compound path checker."""
+
+ def __init__(self, *path_checkers: PathChecker) -> None:
+ """Init compound path checker."""
+ self.path_checkers = path_checkers
+
+ def __call__(self, backend_path: Path) -> Optional[BackendInfo]:
+ """Iterate over checkers and return first non empty backend info."""
+ first_resolved_backend_info = (
+ backend_info
+ for path_checker in self.path_checkers
+ if (backend_info := path_checker(backend_path)) is not None
+ )
+
+ return next(first_resolved_backend_info, None)
+
+
+class Corstone300Installer:
+ """Helper class that wraps Corstone 300 installation logic."""
+
+ def __call__(self, eula_agreement: bool, dist_dir: Path) -> Path:
+ """Install Corstone-300 and return path to the models."""
+ with working_directory(dist_dir):
+ install_dir = "corstone-300"
+ try:
+ fvp_install_cmd = [
+ "./FVP_Corstone_SSE-300.sh",
+ "-q",
+ "-d",
+ install_dir,
+ ]
+ if not eula_agreement:
+ fvp_install_cmd += [
+ "--nointeractive",
+ "--i-agree-to-the-contained-eula",
+ ]
+
+ subprocess.check_call(fvp_install_cmd)
+ except subprocess.CalledProcessError as err:
+ raise Exception(
+ "Error occurred during Corstone-300 installation"
+ ) from err
+
+ return dist_dir / install_dir
+
+
+def get_corstone_300_installation() -> Installation:
+ """Get Corstone-300 installation."""
+ corstone_300 = AIETBasedInstallation(
+ aiet_runner=aiet.get_aiet_runner(),
+ # pylint: disable=line-too-long
+ metadata=AIETMetadata(
+ name="Corstone-300",
+ description="Corstone-300 FVP",
+ system_config="aiet/systems/corstone-300/aiet-config.json",
+ apps_resources=[
+ "aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA",
+ "aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA",
+ "aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA",
+ ],
+ fvp_dir_name="corstone_300",
+ download_artifact=DownloadArtifact(
+ name="Corstone-300 FVP",
+ url="https://developer.arm.com/-/media/Arm%20Developer%20Community/Downloads/OSS/FVP/Corstone-300/FVP_Corstone_SSE-300_11.16_26.tgz",
+ filename="FVP_Corstone_SSE-300_11.16_26.tgz",
+ version="11.16_26",
+ sha256_hash="e26139be756b5003a30d978c629de638aed1934d597dc24a17043d4708e934d7",
+ ),
+ supported_platforms=["Linux"],
+ ),
+ # pylint: enable=line-too-long
+ path_checker=CompoundPathChecker(
+ PackagePathChecker(
+ expected_files=[
+ "models/Linux64_GCC-6.4/FVP_Corstone_SSE-300_Ethos-U55",
+ "models/Linux64_GCC-6.4/FVP_Corstone_SSE-300_Ethos-U65",
+ ],
+ backend_subfolder="models/Linux64_GCC-6.4",
+ ),
+ StaticPathChecker(
+ static_backend_path=Path("/opt/VHT"),
+ expected_files=[
+ "VHT_Corstone_SSE-300_Ethos-U55",
+ "VHT_Corstone_SSE-300_Ethos-U65",
+ ],
+ copy_source=False,
+ system_config="aiet/systems/corstone-300-vht/aiet-config.json",
+ ),
+ ),
+ backend_installer=Corstone300Installer(),
+ )
+
+ return corstone_300
+
+
+def get_corstone_310_installation() -> Installation:
+ """Get Corstone-310 installation."""
+ corstone_310 = AIETBasedInstallation(
+ aiet_runner=aiet.get_aiet_runner(),
+ # pylint: disable=line-too-long
+ metadata=AIETMetadata(
+ name="Corstone-310",
+ description="Corstone-310 FVP",
+ system_config="aiet/systems/corstone-310/aiet-config.json",
+ apps_resources=[
+ "aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA",
+ "aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA",
+ ],
+ fvp_dir_name="corstone_310",
+ download_artifact=None,
+ supported_platforms=["Linux"],
+ ),
+ # pylint: enable=line-too-long
+ path_checker=CompoundPathChecker(
+ PackagePathChecker(
+ expected_files=[
+ "models/Linux64_GCC-9.3/FVP_Corstone_SSE-310",
+ ],
+ backend_subfolder="models/Linux64_GCC-9.3",
+ ),
+ StaticPathChecker(
+ static_backend_path=Path("/opt/VHT"),
+ expected_files=[
+ "VHT_Corstone_SSE-310",
+ ],
+ copy_source=False,
+ system_config="aiet/systems/corstone-310-vht/aiet-config.json",
+ ),
+ ),
+ backend_installer=None,
+ )
+
+ return corstone_310
+
+
+def get_corstone_installations() -> List[Installation]:
+ """Get Corstone installations."""
+ return [
+ get_corstone_300_installation(),
+ get_corstone_310_installation(),
+ ]
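Editor's note (not part of the diff): a hedged sketch of installing a Corstone-300 FVP from an already extracted local directory; the path below is purely an example.

    from pathlib import Path

    from mlia.tools.metadata.common import InstallFromPath
    from mlia.tools.metadata.corstone import get_corstone_300_installation

    installation = get_corstone_300_installation()
    fvp_path = Path("/home/user/FVP_Corstone_SSE-300")  # hypothetical extracted FVP location

    if installation.supports(InstallFromPath(fvp_path)):
        installation.install(InstallFromPath(fvp_path))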
diff --git a/src/mlia/tools/vela_wrapper.py b/src/mlia/tools/vela_wrapper.py
new file mode 100644
index 0000000..7225797
--- /dev/null
+++ b/src/mlia/tools/vela_wrapper.py
@@ -0,0 +1,500 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Vela wrapper module."""
+import itertools
+import logging
+import sys
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Literal
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import numpy as np
+from ethosu.vela.architecture_features import ArchitectureFeatures
+from ethosu.vela.compiler_driver import compiler_driver
+from ethosu.vela.compiler_driver import CompilerOptions
+from ethosu.vela.compiler_driver import TensorAllocator
+from ethosu.vela.model_reader import ModelReaderOptions
+from ethosu.vela.model_reader import read_model
+from ethosu.vela.nn_graph import Graph
+from ethosu.vela.nn_graph import NetworkType
+from ethosu.vela.npu_performance import PassCycles
+from ethosu.vela.operation import CustomType
+from ethosu.vela.operation import Op
+from ethosu.vela.scheduler import OptimizationStrategy
+from ethosu.vela.scheduler import SchedulerOptions
+from ethosu.vela.tensor import BandwidthDirection
+from ethosu.vela.tensor import MemArea
+from ethosu.vela.tensor import Tensor
+from ethosu.vela.tflite_mapping import optype_to_builtintype
+from ethosu.vela.tflite_model_semantic import TFLiteSemantic
+from ethosu.vela.tflite_supported_operators import TFLiteSupportedOperators
+from ethosu.vela.tflite_writer import write_tflite
+from ethosu.vela.vela import generate_supported_ops
+
+from mlia.utils.logging import redirect_output
+
+
+logger = logging.getLogger(__name__)
+
+VELA_INTERNAL_OPS = (Op.Placeholder, Op.SubgraphInput, Op.Const)
+
+
+@dataclass
+class PerformanceMetrics: # pylint: disable=too-many-instance-attributes
+ """Contains all the performance metrics Vela generates in a run."""
+
+ npu_cycles: int
+ sram_access_cycles: int
+ dram_access_cycles: int
+ on_chip_flash_access_cycles: int
+ off_chip_flash_access_cycles: int
+ total_cycles: int
+ batch_inference_time: float
+ inferences_per_second: float
+ batch_size: int
+ unknown_memory_area_size: int
+ sram_memory_area_size: int
+ dram_memory_area_size: int
+ on_chip_flash_memory_area_size: int
+ off_chip_flash_memory_area_size: int
+
+
+@dataclass
+class NpuSupported:
+ """Operator's npu supported attribute."""
+
+ supported: bool
+ reasons: List[Tuple[str, str]]
+
+
+@dataclass
+class Operator:
+ """Model operator."""
+
+ name: str
+ op_type: str
+ run_on_npu: NpuSupported
+
+ @property
+ def cpu_only(self) -> bool:
+ """Return true if operator is CPU only."""
+ cpu_only_reasons = [("CPU only operator", "")]
+ return (
+ not self.run_on_npu.supported
+ and self.run_on_npu.reasons == cpu_only_reasons
+ )
+
+
+@dataclass
+class Operators:
+ """Model's operators."""
+
+ ops: List[Operator]
+
+ @property
+ def npu_supported_ratio(self) -> float:
+ """Return NPU supported ratio."""
+ total = self.total_number
+ npu_supported = self.npu_supported_number
+
+ if total == 0 or npu_supported == 0:
+ return 0
+
+ return npu_supported / total
+
+ @property
+ def npu_unsupported_ratio(self) -> float:
+ """Return NPU unsupported ratio."""
+ return 1 - self.npu_supported_ratio
+
+ @property
+ def total_number(self) -> int:
+ """Return total number of operators."""
+ return len(self.ops)
+
+ @property
+ def npu_supported_number(self) -> int:
+ """Return number of npu supported operators."""
+ return sum(op.run_on_npu.supported for op in self.ops)
+
+
+@dataclass
+class Model:
+ """Model metadata."""
+
+ nng: Graph
+ network_type: NetworkType
+
+ @property
+ def optimized(self) -> bool:
+ """Return true if model is already optimized."""
+ return any(
+ op.attrs.get("custom_type") == CustomType.ExistingNpuOp
+ for sg in self.nng.subgraphs
+ for op in sg.get_all_ops()
+ )
+
+
+@dataclass
+class OptimizedModel:
+ """Instance of the Vela optimized model."""
+
+ nng: Graph
+ arch: ArchitectureFeatures
+ compiler_options: CompilerOptions
+ scheduler_options: SchedulerOptions
+
+ def save(self, output_filename: Union[str, Path]) -> None:
+ """Save instance of the optimized model to the file."""
+ write_tflite(self.nng, output_filename)
+
+
+AcceleratorConfigType = Literal[
+ "ethos-u55-32",
+ "ethos-u55-64",
+ "ethos-u55-128",
+ "ethos-u55-256",
+ "ethos-u65-256",
+ "ethos-u65-512",
+]
+
+TensorAllocatorType = Literal["LinearAlloc", "Greedy", "HillClimb"]
+
+OptimizationStrategyType = Literal["Performance", "Size"]
+
+
+@dataclass
+class VelaCompilerOptions: # pylint: disable=too-many-instance-attributes
+ """Vela compiler options."""
+
+ config_files: Optional[Union[str, List[str]]] = None
+ system_config: str = ArchitectureFeatures.DEFAULT_CONFIG
+ memory_mode: str = ArchitectureFeatures.DEFAULT_CONFIG
+ accelerator_config: Optional[AcceleratorConfigType] = None
+ max_block_dependency: int = ArchitectureFeatures.MAX_BLOCKDEP
+ arena_cache_size: Optional[int] = None
+ tensor_allocator: TensorAllocatorType = "HillClimb"
+ cpu_tensor_alignment: int = Tensor.AllocationQuantum
+ optimization_strategy: OptimizationStrategyType = "Performance"
+ output_dir: Optional[str] = None
+ recursion_limit: int = 1000
+
+
+class VelaCompiler: # pylint: disable=too-many-instance-attributes
+ """Vela compiler wrapper."""
+
+ def __init__(self, compiler_options: VelaCompilerOptions):
+ """Init Vela wrapper instance."""
+ self.config_files = compiler_options.config_files
+ self.system_config = compiler_options.system_config
+ self.memory_mode = compiler_options.memory_mode
+ self.accelerator_config = compiler_options.accelerator_config
+ self.max_block_dependency = compiler_options.max_block_dependency
+ self.arena_cache_size = compiler_options.arena_cache_size
+ self.tensor_allocator = TensorAllocator[compiler_options.tensor_allocator]
+ self.cpu_tensor_alignment = compiler_options.cpu_tensor_alignment
+ self.optimization_strategy = OptimizationStrategy[
+ compiler_options.optimization_strategy
+ ]
+ self.output_dir = compiler_options.output_dir
+ self.recursion_limit = compiler_options.recursion_limit
+
+ sys.setrecursionlimit(self.recursion_limit)
+
+ def read_model(self, model: Union[str, Path]) -> Model:
+ """Read model."""
+ logger.debug("Read model %s", model)
+
+ nng, network_type = self._read_model(model)
+ return Model(nng, network_type)
+
+ def compile_model(self, model: Union[str, Path, Model]) -> OptimizedModel:
+ """Compile the model."""
+ if isinstance(model, (str, Path)):
+ nng, network_type = self._read_model(model)
+ else:
+ nng, network_type = model.nng, NetworkType.TFLite
+
+ if not nng:
+ raise Exception("Unable to read model")
+
+ try:
+ arch = self._architecture_features()
+ compiler_options = self._compiler_options()
+ scheduler_options = self._scheduler_options()
+
+ with redirect_output(
+ logger, stdout_level=logging.DEBUG, stderr_level=logging.DEBUG
+ ):
+ compiler_driver(
+ nng, arch, compiler_options, scheduler_options, network_type
+ )
+
+ return OptimizedModel(nng, arch, compiler_options, scheduler_options)
+ except (SystemExit, Exception) as err:
+ raise Exception("Model could not be optimized with Vela compiler") from err
+
+ def get_config(self) -> Dict[str, Any]:
+ """Get compiler configuration."""
+ arch = self._architecture_features()
+
+ memory_area = {
+ mem.name: {
+ "clock_scales": arch.memory_clock_scales[mem],
+ "burst_length": arch.memory_burst_length[mem],
+ "read_latency": arch.memory_latency[mem][BandwidthDirection.Read],
+ "write_latency": arch.memory_latency[mem][BandwidthDirection.Write],
+ }
+ for mem in (
+ MemArea.Sram,
+ MemArea.Dram,
+ MemArea.OnChipFlash,
+ MemArea.OffChipFlash,
+ )
+ }
+
+ return {
+ "accelerator_config": arch.accelerator_config.value,
+ "system_config": arch.system_config,
+ "core_clock": arch.core_clock,
+ "axi0_port": arch.axi0_port.name,
+ "axi1_port": arch.axi1_port.name,
+ "memory_mode": arch.memory_mode,
+ "const_mem_area": arch.const_mem_area.name,
+ "arena_mem_area": arch.arena_mem_area.name,
+ "cache_mem_area": arch.cache_mem_area.name,
+ "arena_cache_size": arch.arena_cache_size,
+ "permanent_storage_mem_area": arch.permanent_storage_mem_area.name,
+ "feature_map_storage_mem_area": arch.feature_map_storage_mem_area.name,
+ "fast_storage_mem_area": arch.fast_storage_mem_area.name,
+ "memory_area": memory_area,
+ }
+
+ @staticmethod
+ def _read_model(model: Union[str, Path]) -> Tuple[Graph, NetworkType]:
+ """Read TFLite model."""
+ try:
+ model_path = str(model) if isinstance(model, Path) else model
+
+ with redirect_output(
+ logger, stdout_level=logging.DEBUG, stderr_level=logging.DEBUG
+ ):
+ return read_model(model_path, ModelReaderOptions()) # type: ignore
+ except (SystemExit, Exception) as err:
+ raise Exception(f"Unable to read model {model_path}") from err
+
+ def _architecture_features(self) -> ArchitectureFeatures:
+ """Return ArchitectureFeatures instance."""
+ return ArchitectureFeatures(
+ vela_config_files=self.config_files,
+ accelerator_config=self.accelerator_config,
+ system_config=self.system_config,
+ memory_mode=self.memory_mode,
+ max_blockdep=self.max_block_dependency,
+ verbose_config=False,
+ arena_cache_size=self.arena_cache_size,
+ )
+
+ def _scheduler_options(self) -> SchedulerOptions:
+ """Return SchedulerOptions instance."""
+ arch = self._architecture_features()
+
+ return SchedulerOptions(
+ optimization_strategy=self.optimization_strategy,
+ sram_target=arch.arena_cache_size,
+ verbose_schedule=False,
+ )
+
+ def _compiler_options(self) -> CompilerOptions:
+ """Return CompilerOptions instance."""
+ return CompilerOptions(
+ verbose_graph=False,
+ verbose_quantization=False,
+ verbose_packing=False,
+ verbose_tensor_purpose=False,
+ verbose_tensor_format=False,
+ verbose_allocation=False,
+ verbose_high_level_command_stream=False,
+ verbose_register_command_stream=False,
+ verbose_operators=False,
+ verbose_weights=False,
+ show_cpu_operations=False,
+ tensor_allocator=self.tensor_allocator,
+ timing=False,
+ output_dir=self.output_dir,
+ cpu_tensor_alignment=self.cpu_tensor_alignment,
+ )
+
+
+def resolve_compiler_config(
+ vela_compiler_options: VelaCompilerOptions,
+) -> Dict[str, Any]:
+ """Resolve passed compiler options.
+
+ Vela has a number of configuration parameters that are
+ resolved while the compiler options are parsed. E.g. Vela
+ reads configuration parameters from vela.ini and fills
+ its internal structures with the resolved values (memory mode,
+ system config, etc.).
+
+ In order to get this information we need to create an
+ instance of the Vela compiler first.
+ """
+ vela_compiler = VelaCompiler(vela_compiler_options)
+ return vela_compiler.get_config()
+
+
+def estimate_performance(
+ model_path: Path, compiler_options: VelaCompilerOptions
+) -> PerformanceMetrics:
+ """Return performance estimations for the model/device.
+
+ The logic for this function comes from the Vela module stats_writer.py.
+ """
+ logger.debug(
+ "Estimate performance for the model %s on %s",
+ model_path,
+ compiler_options.accelerator_config,
+ )
+
+ vela_compiler = VelaCompiler(compiler_options)
+
+ initial_model = vela_compiler.read_model(model_path)
+ if initial_model.optimized:
+ raise Exception("Unable to estimate performance for the given optimized model")
+
+ optimized_model = vela_compiler.compile_model(initial_model)
+
+ return _performance_metrics(optimized_model)
+
+
+def optimize_model(
+ model_path: Path, compiler_options: VelaCompilerOptions, output_model_path: Path
+) -> None:
+ """Optimize model and return it's path after optimization."""
+ logger.debug(
+ "Optimize model %s for device %s",
+ model_path,
+ compiler_options.accelerator_config,
+ )
+
+ vela_compiler = VelaCompiler(compiler_options)
+ optimized_model = vela_compiler.compile_model(model_path)
+
+ logger.debug("Save optimized model into %s", output_model_path)
+ optimized_model.save(output_model_path)
+
+
+def _performance_metrics(optimized_model: OptimizedModel) -> PerformanceMetrics:
+ """Return performance metrics for optimized model."""
+ cycles = optimized_model.nng.cycles
+
+ def memory_usage(mem_area: MemArea) -> int:
+ """Get memory usage for the proviced memory area type."""
+ memory_used: Dict[MemArea, int] = optimized_model.nng.memory_used
+ bandwidths = optimized_model.nng.bandwidths
+
+ return memory_used.get(mem_area, 0) if np.sum(bandwidths[mem_area]) > 0 else 0
+
+ midpoint_fps = np.nan
+ midpoint_inference_time = cycles[PassCycles.Total] / optimized_model.arch.core_clock
+ if midpoint_inference_time > 0:
+ midpoint_fps = 1 / midpoint_inference_time
+
+ return PerformanceMetrics(
+ npu_cycles=int(cycles[PassCycles.Npu]),
+ sram_access_cycles=int(cycles[PassCycles.SramAccess]),
+ dram_access_cycles=int(cycles[PassCycles.DramAccess]),
+ on_chip_flash_access_cycles=int(cycles[PassCycles.OnChipFlashAccess]),
+ off_chip_flash_access_cycles=int(cycles[PassCycles.OffChipFlashAccess]),
+ total_cycles=int(cycles[PassCycles.Total]),
+ batch_inference_time=midpoint_inference_time * 1000,
+ inferences_per_second=midpoint_fps,
+ batch_size=optimized_model.nng.batch_size,
+ unknown_memory_area_size=memory_usage(MemArea.Unknown),
+ sram_memory_area_size=memory_usage(MemArea.Sram),
+ dram_memory_area_size=memory_usage(MemArea.Dram),
+ on_chip_flash_memory_area_size=memory_usage(MemArea.OnChipFlash),
+ off_chip_flash_memory_area_size=memory_usage(MemArea.OffChipFlash),
+ )
+
+
+def supported_operators(
+ model_path: Path, compiler_options: VelaCompilerOptions
+) -> Operators:
+ """Return list of model's operators."""
+ logger.debug("Check supported operators for the model %s", model_path)
+
+ vela_compiler = VelaCompiler(compiler_options)
+ initial_model = vela_compiler.read_model(model_path)
+
+ return Operators(
+ [
+ Operator(op.name, optype_to_builtintype(op.type), run_on_npu(op))
+ for sg in initial_model.nng.subgraphs
+ for op in sg.get_all_ops()
+ if op.type not in VELA_INTERNAL_OPS
+ ]
+ )
+
+
+def run_on_npu(operator: Op) -> NpuSupported:
+ """Return information if operator can run on NPU.
+
+ Vela does a number of checks that can help establish whether
+ a particular operator is supported to run on NPU.
+
+ There are two groups of checks:
+ - general TFLite constraints
+ - operator specific constraints
+
+ If an operator is not supported on the NPU then this function
+ will return the reason why.
+
+ The reason is split into two parts:
+ - general description of why the operator cannot be placed on NPU
+ - details on the particular operator
+ """
+ semantic_checker = TFLiteSemantic()
+ semantic_constraints = itertools.chain(
+ semantic_checker.generic_constraints,
+ semantic_checker.specific_constraints[operator.type],
+ )
+
+ for constraint in semantic_constraints:
+ op_valid, op_reason = constraint(operator)
+ if not op_valid:
+ return NpuSupported(False, [(constraint.__doc__, op_reason)])
+
+ if operator.type not in TFLiteSupportedOperators.supported_operators:
+ reasons = (
+ [("CPU only operator", "")]
+ if operator.type not in VELA_INTERNAL_OPS
+ else []
+ )
+
+ return NpuSupported(False, reasons)
+
+ tflite_supported_operators = TFLiteSupportedOperators()
+ operation_constraints = itertools.chain(
+ tflite_supported_operators.generic_constraints,
+ tflite_supported_operators.specific_constraints[operator.type],
+ )
+ for constraint in operation_constraints:
+ op_valid, op_reason = constraint(operator)
+ if not op_valid:
+ return NpuSupported(False, [(constraint.__doc__, op_reason)])
+
+ return NpuSupported(True, [])
+
+
+def generate_supported_operators_report() -> None:
+ """Generate supported operators report in current working directory."""
+ with redirect_output(logger):
+ generate_supported_ops()
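Editor's note (not part of the diff): a hedged usage sketch of the wrapper above. The config file and model paths are assumptions; in practice the options would typically point at the vela.ini added earlier in this change.

    from pathlib import Path

    from mlia.tools.vela_wrapper import VelaCompilerOptions, supported_operators

    options = VelaCompilerOptions(
        config_files="src/mlia/resources/vela/vela.ini",  # assumption for this sketch
        accelerator_config="ethos-u55-256",
        system_config="Ethos_U55_High_End_Embedded",
        memory_mode="Shared_Sram",
    )
    ops = supported_operators(Path("model.tflite"), options)
    print(f"{ops.npu_supported_number}/{ops.total_number} operators can run on the NPU")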
diff --git a/src/mlia/utils/__init__.py b/src/mlia/utils/__init__.py
new file mode 100644
index 0000000..ecb5ca1
--- /dev/null
+++ b/src/mlia/utils/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Utils module."""
diff --git a/src/mlia/utils/console.py b/src/mlia/utils/console.py
new file mode 100644
index 0000000..7cb3d83
--- /dev/null
+++ b/src/mlia/utils/console.py
@@ -0,0 +1,97 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Console output utility functions."""
+from typing import Iterable
+from typing import List
+from typing import Optional
+
+from rich.console import Console
+from rich.console import RenderableType
+from rich.table import box
+from rich.table import Table
+from rich.text import Text
+
+
+def create_section_header(
+ section_name: Optional[str] = None, length: int = 80, sep: str = "-"
+) -> str:
+ """Return section header."""
+ if not section_name:
+ content = sep * length
+ else:
+ before = 3
+ spaces = 2
+ after = length - (len(section_name) + before + spaces)
+ if after < 0:
+ raise ValueError("Section name too long")
+ content = f"{sep * before} {section_name} {sep * after}"
+
+ return f"\n{content}\n"
+
+
+def apply_style(value: str, style: str) -> str:
+ """Apply style to the value."""
+ return f"[{style}]{value}"
+
+
+def style_improvement(result: bool) -> str:
+ """Return different text style based on result."""
+ return "green" if result else "yellow"
+
+
+def produce_table(
+ rows: Iterable,
+ headers: Optional[List[str]] = None,
+ table_style: str = "default",
+) -> str:
+ """Represent data in tabular form."""
+ table = _get_table(table_style)
+
+ if headers:
+ table.show_header = True
+ for header in headers:
+ table.add_column(header)
+
+ for row in rows:
+ table.add_row(*row)
+
+ return _convert_to_text(table)
+
+
+def _get_table(table_style: str) -> Table:
+ """Get Table instance for the provided style."""
+ if table_style == "default":
+ return Table(
+ show_header=False,
+ show_lines=True,
+ box=box.SQUARE_DOUBLE_HEAD,
+ )
+
+ if table_style == "nested":
+ return Table(
+ show_header=False,
+ box=None,
+ padding=(0, 1, 1, 0),
+ )
+
+ if table_style == "no_borders":
+ return Table(show_header=False, box=None)
+
+ raise Exception(f"Unsupported table style {table_style}")
+
+
+def _convert_to_text(*renderables: RenderableType) -> str:
+ """Convert renderable object to text."""
+ console = Console()
+ with console.capture() as capture:
+ for item in renderables:
+ console.print(item)
+
+ text = capture.get()
+ return text.rstrip()
+
+
+def remove_ascii_codes(value: str) -> str:
+ """Decode and remove ASCII codes."""
+ text = Text.from_ansi(value)
+ return text.plain
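A minimal usage sketch of the console helpers above; the row contents are made up, and the exact rendering depends on rich's table drawing:

from mlia.utils.console import (
    apply_style,
    create_section_header,
    produce_table,
    remove_ascii_codes,
)

# "--- Operators ---...---" padded with the separator to 80 characters
print(create_section_header("Operators"))

table = produce_table(
    rows=[
        ["conv2d", apply_style("yes", "green")],
        ["custom_op", apply_style("no", "yellow")],
    ],
    headers=["Operator", "Runs on NPU"],
    table_style="default",
)
print(table)
print(remove_ascii_codes(table))  # same table with any ANSI styling stripped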
diff --git a/src/mlia/utils/download.py b/src/mlia/utils/download.py
new file mode 100644
index 0000000..4658738
--- /dev/null
+++ b/src/mlia/utils/download.py
@@ -0,0 +1,89 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Utils for files downloading."""
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Iterable
+from typing import List
+from typing import Optional
+
+import requests
+from rich.progress import BarColumn
+from rich.progress import DownloadColumn
+from rich.progress import FileSizeColumn
+from rich.progress import Progress
+from rich.progress import ProgressColumn
+from rich.progress import TextColumn
+
+from mlia.utils.filesystem import sha256
+from mlia.utils.types import parse_int
+
+
+def download_progress(
+ content_chunks: Iterable[bytes], content_length: Optional[int], label: Optional[str]
+) -> Iterable[bytes]:
+ """Show progress info while reading content."""
+ columns: List[ProgressColumn] = [TextColumn("{task.description}")]
+
+ if content_length is None:
+ total = float("inf")
+ columns.append(FileSizeColumn())
+ else:
+ total = content_length
+ columns.extend([BarColumn(), DownloadColumn(binary_units=True)])
+
+ with Progress(*columns) as progress:
+ task = progress.add_task(label or "Downloading", total=total)
+
+ for chunk in content_chunks:
+ progress.update(task, advance=len(chunk))
+ yield chunk
+
+
+def download(
+ url: str,
+ dest: Path,
+ show_progress: bool = False,
+ label: Optional[str] = None,
+ chunk_size: int = 8192,
+) -> None:
+ """Download the file."""
+ with requests.get(url, stream=True) as resp:
+ resp.raise_for_status()
+ content_chunks = resp.iter_content(chunk_size=chunk_size)
+
+ if show_progress:
+ content_length = parse_int(resp.headers.get("Content-Length"))
+ content_chunks = download_progress(content_chunks, content_length, label)
+
+ with open(dest, "wb") as file:
+ for chunk in content_chunks:
+ file.write(chunk)
+
+
+@dataclass
+class DownloadArtifact:
+ """Download artifact attributes."""
+
+ name: str
+ url: str
+ filename: str
+ version: str
+ sha256_hash: str
+
+ def download_to(self, dest_dir: Path, show_progress: bool = True) -> Path:
+ """Download artifact into destination directory."""
+ if (dest := dest_dir / self.filename).exists():
+ raise ValueError(f"{dest} already exists")
+
+ download(
+ self.url,
+ dest,
+ show_progress=show_progress,
+ label=f"Downloading {self.name} ver. {self.version}",
+ )
+
+ if sha256(dest) != self.sha256_hash:
+ raise ValueError("Digests do not match")
+
+ return dest
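A sketch of both download paths; the URL, filename and digest below are placeholders, not real artifacts:

from pathlib import Path

from mlia.utils.download import DownloadArtifact, download

# Plain download with a progress bar (placeholder URL)
download(
    "https://example.com/model.tflite",
    Path("/tmp/model.tflite"),
    show_progress=True,
    label="Downloading model",
)

# Verified download: raises ValueError if the destination already exists
# or if the SHA256 digest does not match the expected value.
artifact = DownloadArtifact(
    name="example model",
    url="https://example.com/model.tflite",
    filename="model.tflite",
    version="1.0",
    sha256_hash="<expected sha256 hex digest>",  # placeholder
)
dest = artifact.download_to(Path("/tmp"))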
diff --git a/src/mlia/utils/filesystem.py b/src/mlia/utils/filesystem.py
new file mode 100644
index 0000000..73a88d9
--- /dev/null
+++ b/src/mlia/utils/filesystem.py
@@ -0,0 +1,124 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Utils related to file management."""
+import hashlib
+import importlib.resources as pkg_resources
+import json
+import os
+import shutil
+from contextlib import contextmanager
+from pathlib import Path
+from tempfile import mkstemp
+from tempfile import TemporaryDirectory
+from typing import Any
+from typing import Dict
+from typing import Generator
+from typing import Iterable
+from typing import List
+from typing import Optional
+from typing import Union
+
+
+def get_mlia_resources() -> Path:
+ """Get the path to the resources directory."""
+ with pkg_resources.path("mlia", "__init__.py") as init_path:
+ project_root = init_path.parent
+ return project_root / "resources"
+
+
+def get_vela_config() -> Path:
+ """Get the path to the default Vela config file."""
+ return get_mlia_resources() / "vela/vela.ini"
+
+
+def get_profiles_file() -> Path:
+ """Get the Ethos-U profiles file."""
+ return get_mlia_resources() / "profiles.json"
+
+
+def get_profiles_data() -> Dict[str, Dict[str, Any]]:
+ """Get the Ethos-U profile values as a dictionary."""
+ with open(get_profiles_file(), encoding="utf-8") as json_file:
+ profiles = json.load(json_file)
+
+ if not isinstance(profiles, dict):
+ raise Exception("Profiles data format is not valid")
+
+ return profiles
+
+
+def get_profile(target: str) -> Dict[str, Any]:
+ """Get settings for the provided target profile."""
+ profiles = get_profiles_data()
+
+ if target not in profiles:
+ raise Exception(f"Unable to find target profile {target}")
+
+ return profiles[target]
+
+
+def get_supported_profile_names() -> List[str]:
+ """Get the supported Ethos-U profile names."""
+ return list(get_profiles_data().keys())
+
+
+@contextmanager
+def temp_file(suffix: Optional[str] = None) -> Generator[Path, None, None]:
+ """Create temp file and remove it after."""
+ _, tmp_file = mkstemp(suffix=suffix)
+
+ try:
+ yield Path(tmp_file)
+ finally:
+ os.remove(tmp_file)
+
+
+@contextmanager
+def temp_directory(suffix: Optional[str] = None) -> Generator[Path, None, None]:
+ """Create temp directory and remove it after."""
+ with TemporaryDirectory(suffix=suffix) as tmpdir:
+ yield Path(tmpdir)
+
+
+def file_chunks(
+ filepath: Union[Path, str], chunk_size: int = 4096
+) -> Generator[bytes, None, None]:
+ """Return sequence of the file chunks."""
+ with open(filepath, "rb") as file:
+ while data := file.read(chunk_size):
+ yield data
+
+
+def hexdigest(filepath: Union[Path, str], hash_obj: "hashlib._Hash") -> str:
+ """Return hex digest of the file."""
+ for chunk in file_chunks(filepath):
+ hash_obj.update(chunk)
+
+ return hash_obj.hexdigest()
+
+
+def sha256(filepath: Path) -> str:
+ """Return SHA256 hash of the file."""
+ return hexdigest(filepath, hashlib.sha256())
+
+
+def all_files_exist(paths: Iterable[Path]) -> bool:
+ """Check if all files are exist."""
+ return all(item.is_file() for item in paths)
+
+
+def all_paths_valid(paths: Iterable[Path]) -> bool:
+ """Check if all paths are valid."""
+ return all(item.exists() for item in paths)
+
+
+def copy_all(*paths: Path, dest: Path) -> None:
+ """Copy files/directories into destination folder."""
+ dest.mkdir(exist_ok=True)
+
+ for path in paths:
+ if path.is_file():
+ shutil.copy2(path, dest)
+
+ if path.is_dir():
+ shutil.copytree(path, dest, dirs_exist_ok=True)
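A short sketch of the filesystem helpers; the file names are examples only:

from pathlib import Path

from mlia.utils.filesystem import (
    all_files_exist,
    copy_all,
    file_chunks,
    sha256,
    temp_directory,
    temp_file,
)

# Hash a file without loading it into memory all at once
digest = sha256(Path("model.tflite"))
total_size = sum(len(chunk) for chunk in file_chunks("model.tflite"))

# Temporary file and directory are removed when the contexts exit
with temp_file(suffix=".tflite") as tmp, temp_directory() as tmp_dir:
    copy_all(tmp, dest=tmp_dir)
    assert all_files_exist([tmp_dir / tmp.name])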
diff --git a/src/mlia/utils/logging.py b/src/mlia/utils/logging.py
new file mode 100644
index 0000000..86d7567
--- /dev/null
+++ b/src/mlia/utils/logging.py
@@ -0,0 +1,120 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Logging utility functions."""
+import logging
+from contextlib import contextmanager
+from contextlib import ExitStack
+from contextlib import redirect_stderr
+from contextlib import redirect_stdout
+from pathlib import Path
+from typing import Any
+from typing import Callable
+from typing import Generator
+from typing import List
+from typing import Optional
+
+
+class LoggerWriter:
+ """Redirect printed messages to the logger."""
+
+ def __init__(self, logger: logging.Logger, level: int):
+ """Init logger writer."""
+ self.logger = logger
+ self.level = level
+
+ def write(self, message: str) -> None:
+ """Write message."""
+ if message.strip() != "":
+ self.logger.log(self.level, message)
+
+ def flush(self) -> None:
+ """Flush buffers."""
+
+
+@contextmanager
+def redirect_output(
+ logger: logging.Logger,
+ stdout_level: int = logging.INFO,
+ stderr_level: int = logging.INFO,
+) -> Generator[None, None, None]:
+ """Redirect standard output to the logger."""
+ stdout_to_log = LoggerWriter(logger, stdout_level)
+ stderr_to_log = LoggerWriter(logger, stderr_level)
+
+ with ExitStack() as exit_stack:
+ exit_stack.enter_context(redirect_stdout(stdout_to_log)) # type: ignore
+ exit_stack.enter_context(redirect_stderr(stderr_to_log)) # type: ignore
+
+ yield
+
+
+class LogFilter(logging.Filter):
+ """Configurable log filter."""
+
+ def __init__(self, log_record_filter: Callable[[logging.LogRecord], bool]) -> None:
+ """Init log filter instance."""
+ super().__init__()
+ self.log_record_filter = log_record_filter
+
+ def filter(self, record: logging.LogRecord) -> bool:
+ """Filter log messages."""
+ return self.log_record_filter(record)
+
+ @classmethod
+ def equals(cls, log_level: int) -> "LogFilter":
+ """Return log filter that filters messages by log level."""
+
+ def filter_by_level(log_record: logging.LogRecord) -> bool:
+ return log_record.levelno == log_level
+
+ return cls(filter_by_level)
+
+ @classmethod
+ def skip(cls, log_level: int) -> "LogFilter":
+ """Return log filter that skips messages with particular level."""
+
+ def skip_by_level(log_record: logging.LogRecord) -> bool:
+ return log_record.levelno != log_level
+
+ return cls(skip_by_level)
+
+
+def create_log_handler(
+ *,
+ file_path: Optional[Path] = None,
+ stream: Optional[Any] = None,
+ log_level: Optional[int] = None,
+ log_format: Optional[str] = None,
+ log_filter: Optional[logging.Filter] = None,
+ delay: bool = True,
+) -> logging.Handler:
+ """Create logger handler."""
+ handler: Optional[logging.Handler] = None
+
+ if file_path is not None:
+ handler = logging.FileHandler(file_path, delay=delay)
+ elif stream is not None:
+ handler = logging.StreamHandler(stream)
+
+ if handler is None:
+ raise Exception("Unable to create logging handler")
+
+ if log_level:
+ handler.setLevel(log_level)
+
+ if log_format:
+ handler.setFormatter(logging.Formatter(log_format))
+
+ if log_filter:
+ handler.addFilter(log_filter)
+
+ return handler
+
+
+def attach_handlers(
+ handlers: List[logging.Handler], loggers: List[logging.Logger]
+) -> None:
+ """Attach handlers to the loggers."""
+ for handler in handlers:
+ for logger in loggers:
+ logger.addHandler(handler)
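Putting the pieces together: a sketch that routes warnings to stderr, everything else to a file, and captures print() output from third-party code via redirect_output. The logger name and log file name are illustrative:

import logging
import sys
from pathlib import Path

from mlia.utils.logging import (
    attach_handlers,
    create_log_handler,
    LogFilter,
    redirect_output,
)

logger = logging.getLogger("mlia.example")  # illustrative logger name
logger.setLevel(logging.DEBUG)

handlers = [
    # warnings only, to stderr
    create_log_handler(
        stream=sys.stderr,
        log_level=logging.WARNING,
        log_filter=LogFilter.equals(logging.WARNING),
    ),
    # everything except warnings, to a log file
    create_log_handler(
        file_path=Path("mlia.log"),
        log_format="%(asctime)s %(name)s %(message)s",
        log_filter=LogFilter.skip(logging.WARNING),
    ),
]
attach_handlers(handlers, [logger])

with redirect_output(logger):
    print("this print() call is logged instead of written to stdout")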
diff --git a/src/mlia/utils/misc.py b/src/mlia/utils/misc.py
new file mode 100644
index 0000000..de95448
--- /dev/null
+++ b/src/mlia/utils/misc.py
@@ -0,0 +1,9 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Various util functions."""
+
+
+def yes(prompt: str) -> bool:
+ """Return true if user confirms the action."""
+ response = input(f"{prompt} [y/n]: ")
+ return response in ["y", "Y"]
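yes() treats anything other than a literal "y"/"Y" as a refusal; a trivial sketch:

from mlia.utils.misc import yes

if yes("Apply the selected optimizations?"):
    ...  # proceed only on an explicit "y" or "Y"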
diff --git a/src/mlia/utils/proc.py b/src/mlia/utils/proc.py
new file mode 100644
index 0000000..39aca43
--- /dev/null
+++ b/src/mlia/utils/proc.py
@@ -0,0 +1,164 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Utils related to process management."""
+import os
+import signal
+import subprocess
+import time
+from abc import ABC
+from abc import abstractmethod
+from contextlib import contextmanager
+from contextlib import suppress
+from pathlib import Path
+from typing import Any
+from typing import Generator
+from typing import Iterable
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+
+class OutputConsumer(ABC):
+ """Base class for the output consumers."""
+
+ @abstractmethod
+ def feed(self, line: str) -> None:
+ """Feed new line to the consumerr."""
+
+
+class RunningCommand:
+ """Running command."""
+
+ def __init__(self, process: subprocess.Popen) -> None:
+ """Init running command instance."""
+ self.process = process
+ self._output_consumers: Optional[List[OutputConsumer]] = None
+
+ def is_alive(self) -> bool:
+ """Return true if process is still alive."""
+ return self.process.poll() is None
+
+ def exit_code(self) -> Optional[int]:
+ """Return process's return code."""
+ return self.process.poll()
+
+ def stdout(self) -> Iterable[str]:
+ """Return std output of the process."""
+ assert self.process.stdout is not None
+
+ for line in self.process.stdout:
+ yield line
+
+ def kill(self) -> None:
+ """Kill the process."""
+ self.process.kill()
+
+ def send_signal(self, signal_num: int) -> None:
+ """Send signal to the process."""
+ self.process.send_signal(signal_num)
+
+ @property
+ def output_consumers(self) -> Optional[List[OutputConsumer]]:
+ """Property output_consumers."""
+ return self._output_consumers
+
+ @output_consumers.setter
+ def output_consumers(self, output_consumers: List[OutputConsumer]) -> None:
+ """Set output consumers."""
+ self._output_consumers = output_consumers
+
+ def consume_output(self) -> None:
+ """Pass program's output to the consumers."""
+ if self.process is None or self.output_consumers is None:
+ return
+
+ for line in self.stdout():
+ for consumer in self.output_consumers:
+ with suppress(Exception):  # do not let a faulty consumer stop the output loop
+ consumer.feed(line)
+
+ def stop(
+ self, wait: bool = True, num_of_attempts: int = 5, interval: float = 0.5
+ ) -> None:
+ """Stop execution."""
+ try:
+ if not self.is_alive():
+ return
+
+ self.process.send_signal(signal.SIGINT)
+ self.consume_output()
+
+ if not wait:
+ return
+
+ for _ in range(num_of_attempts):
+ time.sleep(interval)
+ if not self.is_alive():
+ break
+ else:
+ raise Exception("Unable to stop running command")
+ finally:
+ self._close_fd()
+
+ def _close_fd(self) -> None:
+ """Close file descriptors."""
+
+ def close(file_descriptor: Any) -> None:
+ """Check and close file."""
+ if file_descriptor is not None and hasattr(file_descriptor, "close"):
+ file_descriptor.close()
+
+ close(self.process.stdout)
+ close(self.process.stderr)
+
+ def wait(self, redirect_output: bool = False) -> None:
+ """Redirect process output to stdout and wait for completion."""
+ if redirect_output:
+ for line in self.stdout():
+ print(line, end="")
+
+ self.process.wait()
+
+
+class CommandExecutor:
+ """Command executor."""
+
+ @staticmethod
+ def execute(command: List[str]) -> Tuple[int, bytes, bytes]:
+ """Execute the command."""
+ result = subprocess.run(
+ command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True
+ )
+
+ return (result.returncode, result.stdout, result.stderr)
+
+ @staticmethod
+ def submit(command: List[str]) -> RunningCommand:
+ """Submit command for the execution."""
+ process = subprocess.Popen( # pylint: disable=consider-using-with
+ command,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT, # redirect command stderr to stdout
+ universal_newlines=True,
+ bufsize=1,
+ )
+
+ return RunningCommand(process)
+
+
+@contextmanager
+def working_directory(
+ working_dir: Path, create_dir: bool = False
+) -> Generator[Path, None, None]:
+ """Temporary change working directory."""
+ current_working_dir = Path.cwd()
+
+ if create_dir:
+ working_dir.mkdir()
+
+ os.chdir(working_dir)
+
+ try:
+ yield working_dir
+ finally:
+ os.chdir(current_working_dir)
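A sketch of both execution modes: the blocking execute() and the non-blocking submit() with an output consumer. The commands shown are examples only:

from pathlib import Path

from mlia.utils.proc import CommandExecutor, OutputConsumer, working_directory

class EchoConsumer(OutputConsumer):
    """Example consumer that echoes every line it is fed."""

    def feed(self, line: str) -> None:
        print(f"[cmd] {line}", end="")

executor = CommandExecutor()

# Blocking call; check=True means a non-zero exit raises CalledProcessError
exit_code, stdout, stderr = executor.execute(["uname", "-a"])

# Non-blocking call with streamed output, run from a different directory
with working_directory(Path("/tmp")):
    command = executor.submit(["ping", "-c", "3", "localhost"])
    command.output_consumers = [EchoConsumer()]
    command.consume_output()  # blocks until the command's stdout is exhausted
    print("exit code:", command.exit_code())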
diff --git a/src/mlia/utils/types.py b/src/mlia/utils/types.py
new file mode 100644
index 0000000..9b63928
--- /dev/null
+++ b/src/mlia/utils/types.py
@@ -0,0 +1,37 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Types related utility functions."""
+from typing import Any
+from typing import Optional
+
+
+def is_list_of(data: Any, cls: type, elem_num: Optional[int] = None) -> bool:
+ """Check if data is a list of object of the same class."""
+ return (
+ isinstance(data, (tuple, list))
+ and all(isinstance(item, cls) for item in data)
+ and (elem_num is None or len(data) == elem_num)
+ )
+
+
+def is_number(value: str) -> bool:
+ """Return true if string contains a number."""
+ try:
+ float(value)
+ except ValueError:
+ return False
+
+ return True
+
+
+def parse_int(value: Any, default: Optional[int] = None) -> Optional[int]:
+ """Parse integer value."""
+ try:
+ return int(value)
+ except (TypeError, ValueError):
+ return default
+
+
+def only_one_selected(*options: bool) -> bool:
+ """Return true if only one True value found."""
+ return sum(options) == 1
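The type helpers are easiest to show by example:

from mlia.utils.types import is_list_of, is_number, only_one_selected, parse_int

assert is_list_of([1, 2, 3], int)
assert is_list_of((1.0, 2.0), float, elem_num=2)
assert not is_list_of([1, "2"], int)

assert is_number("3.14")
assert not is_number("abc")

assert parse_int("42") == 42
assert parse_int(None, default=0) == 0

assert only_one_selected(True, False, False)
assert not only_one_selected(True, True)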