aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGergely Nagy <gergely.nagy@arm.com>2023-06-22 14:35:21 +0100
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2023-10-11 16:16:11 +0100
commitbaaf4de286762c1955c874f78cd802d4703a8ba5 (patch)
tree3b80f906672f91e7e24723720b2d164d360f3edf
parent3cd84481fa25e64c29e57396d4bf32d7a3ca490a (diff)
downloadmlia-baaf4de286762c1955c874f78cd802d4703a8ba5.tar.gz
Re-factoring of rewrite management & added metrics
- List available rewrites - Refactor/rename 'Rewrite' class to 'RewritingOptimizer' - Introduce a registry for rewrite functions - Refactor 'Rewriter' to use the registry to look up rewrite functions - Remove mentions of hardcoded "fully_connected" from CLI help and error messages, using the registry instead - Add unit tests - Enable rewrites for all targets: Extract optimization (including rewrite specific code) from the Ethos-U-specific data collector into OptimizingDataCollector. This is reused in other targets' collectors, such as TOSA and Cortex-A. - Add more logging for rewrite - add display of MAE and NRMSE values for the trained result - add total model MAE and NRMSE metric Resolves: MLIA-891, MLIA-899, MLIA-906 Change-Id: Ie798749e1ed60cab14fdb6d9c2271c833960e93f Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com>
-rw-r--r--src/mlia/cli/options.py11
-rw-r--r--src/mlia/nn/rewrite/core/rewrite.py154
-rw-r--r--src/mlia/nn/rewrite/core/train.py21
-rw-r--r--src/mlia/nn/select.py4
-rw-r--r--src/mlia/target/common/optimization.py219
-rw-r--r--src/mlia/target/cortex_a/advisor.py16
-rw-r--r--src/mlia/target/ethos_u/advisor.py41
-rw-r--r--src/mlia/target/ethos_u/data_collection.py135
-rw-r--r--src/mlia/target/tosa/advisor.py18
-rw-r--r--src/mlia/utils/registry.py4
-rw-r--r--tests/test_cli_commands.py6
-rw-r--r--tests/test_common_optimization.py67
-rw-r--r--tests/test_nn_rewrite_core_rewrite.py80
-rw-r--r--tests/test_nn_rewrite_core_train.py2
-rw-r--r--tests/test_target_cortex_a_advisor.py24
-rw-r--r--tests/test_target_ethos_u_data_collection.py62
-rw-r--r--tests/test_target_tosa_advisor.py22
-rw-r--r--tests/test_utils_registry.py3
18 files changed, 647 insertions, 242 deletions
diff --git a/src/mlia/cli/options.py b/src/mlia/cli/options.py
index 7b3b373..57f54dd 100644
--- a/src/mlia/cli/options.py
+++ b/src/mlia/cli/options.py
@@ -14,6 +14,7 @@ from mlia.backend.manager import get_available_backends
from mlia.core.common import AdviceCategory
from mlia.core.errors import ConfigurationError
from mlia.core.typing import OutputFormat
+from mlia.nn.rewrite.core.rewrite import RewritingOptimizer
from mlia.target.registry import builtin_profile_names
from mlia.target.registry import registry as target_registry
@@ -111,7 +112,10 @@ def add_multi_optimization_options(parser: argparse.ArgumentParser) -> None:
multi_optimization_group.add_argument(
"--rewrite-target",
type=str,
- help="Type of rewrite to apply to the subgraph/layer.",
+ help=(
+ "Type of rewrite to apply to the subgraph/layer. "
+ f"Available rewrites: {RewritingOptimizer.builtin_rewrite_names()}"
+ ),
)
multi_optimization_group.add_argument(
@@ -327,9 +331,10 @@ def parse_optimization_parameters( # pylint: disable=too-many-arguments
]
if rewrite:
- if rewrite_target not in ["remove", "fully_connected"]:
+ if rewrite_target not in RewritingOptimizer.builtin_rewrite_names():
raise ConfigurationError(
- "Currently only remove and fully_connected are supported."
+ f"Invalid rewrite target: '{rewrite_target}'. "
+ f"Supported rewrites: {RewritingOptimizer.builtin_rewrite_names()}"
)
optimizer_params.append(
{
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py
index 6b27984..fdfd35c 100644
--- a/src/mlia/nn/rewrite/core/rewrite.py
+++ b/src/mlia/nn/rewrite/core/rewrite.py
@@ -1,6 +1,6 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
-"""Contains class Rewriter to replace a subgraph/layer of a model."""
+"""Contains class RewritingOptimizer to replace a subgraph/layer of a model."""
from __future__ import annotations
import importlib
@@ -9,16 +9,88 @@ import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any
+from typing import Callable
+from typing import cast
+
+import tensorflow as tf
from mlia.core.errors import ConfigurationError
+from mlia.core.reporting import Column
+from mlia.core.reporting import Format
+from mlia.core.reporting import Table
from mlia.nn.common import Optimizer
from mlia.nn.common import OptimizerConfiguration
from mlia.nn.rewrite.core.train import train
from mlia.nn.rewrite.core.train import TrainingParameters
from mlia.nn.tensorflow.config import TFLiteModel
-
+from mlia.utils.registry import Registry
logger = logging.getLogger(__name__)
+RewriteCallable = Callable[[Any, Any], tf.keras.Model]
+
+
+class Rewrite:
+ """Graph rewrite logic to be used by RewritingOptimizer."""
+
+ def __init__(self, name: str, rewrite_fn: RewriteCallable):
+        """Initialize a Rewrite instance with a given name and a rewrite function."""
+ self.name = name
+ self.function = rewrite_fn
+
+ def __call__(self, input_shape: Any, output_shape: Any) -> tf.keras.Model:
+ """Perform the rewrite operation using the configured function."""
+ try:
+ return self.function(input_shape, output_shape)
+ except Exception as ex:
+ raise RuntimeError(f"Rewrite '{self.name}' failed.") from ex
+
+
+@dataclass
+class DynamicallyLoadedRewrite(Rewrite):
+ """A rewrite which can load logic from a function loaded dynamically."""
+
+ def __init__(self, name: str, function_name: str):
+ """Initialize."""
+
+ def load_and_run(input_shape: Any, output_shape: Any) -> tf.keras.Model:
+ """Load the function from a file dynamically."""
+ self.load_function(function_name)
+ return self.function(input_shape, output_shape)
+
+ super().__init__(name, load_and_run)
+
+ def load_function(self, function_name: str) -> RewriteCallable:
+ """Return the rewrite function. Import using the auto_load attr if necessary."""
+ try:
+ name_parts = function_name.split(".")
+ module_name = ".".join(name_parts[:-1])
+ fn_name = name_parts[-1]
+ module = importlib.import_module(module_name)
+ self.function = cast(RewriteCallable, getattr(module, fn_name))
+ return self.function
+ except Exception as ex:
+ raise RuntimeError(
+ f"Unable to load rewrite function '{function_name}' for '{self.name}'."
+ ) from ex
+
+
+class RewriteRegistry(Registry[Rewrite]):
+    """Registry of rewrite functions."""
+
+ def __init__(self, rewrites: list[Rewrite] | None = None):
+ """Set up a rewrite registry.
+
+        Can optionally be initialised with a list of Rewrite
+        instances to be registered on construction
+ """
+ super().__init__()
+ if rewrites:
+ for rewrite in rewrites:
+ self.register_rewrite(rewrite)
+
+ def register_rewrite(self, rewrite: Rewrite) -> bool:
+ """Register a rewrite."""
+ return super().register(rewrite.name, rewrite)
@dataclass
@@ -35,34 +107,35 @@ class RewriteConfiguration(OptimizerConfiguration):
return f"rewrite: {self.optimization_target}"
-class Rewriter(Optimizer):
- """Rewriter class for basic rewrite flow."""
+class RewritingOptimizer(Optimizer):
+ """RewritingOptimizer class for basic rewrite flow."""
+
+ registry = RewriteRegistry(
+ [
+ DynamicallyLoadedRewrite(
+ "fully_connected", "mlia.nn.rewrite.library.fc_layer.get_keras_model"
+ )
+ ]
+ )
def __init__(
self, tflite_model_path: Path, optimizer_configuration: RewriteConfiguration
):
- """Init Rewriter instance."""
+ """Init RewritingOptimizer instance."""
self.model = TFLiteModel(tflite_model_path)
self.model_path = tflite_model_path
self.optimizer_configuration = optimizer_configuration
- def apply_optimization(self) -> None:
- """Apply the rewrite flow."""
+ @classmethod
+ def builtin_rewrite_names(cls) -> list:
+ """Return all registered rewrite names."""
+ return cls.registry.names()
- def get_function(arg: str) -> Any:
- module_name = ".".join(arg.split(".")[:-1])
- fn_name = arg.split(".")[-1]
- module = importlib.import_module(module_name)
- return getattr(module, fn_name)
-
- if self.optimizer_configuration.optimization_target == "fully_connected":
- replace_function = "mlia.nn.rewrite.library.fc_layer.get_keras_model"
- else:
- raise ConfigurationError(
- "Only fully_connected replacement is supported in rewrite module."
- )
-
- replace_fn = get_function(replace_function)
+ def apply_optimization(self) -> None: # pylint: disable=too-many-locals
+ """Apply the rewrite flow."""
+ rewrite = RewritingOptimizer.registry.items[
+ self.optimizer_configuration.optimization_target
+ ]
use_unmodified_model = True
tflite_model = self.model.model_path
@@ -75,25 +148,48 @@ class Rewriter(Optimizer):
raise ConfigurationError(
"Input and output tensor names need to be set for rewrite."
)
- result = train(
+
+ orig_vs_repl_stats, total_stats = train(
source_model=tflite_model,
unmodified_model=tflite_model if use_unmodified_model else None,
output_model=str(tmp_output),
input_tfrec=str(tfrecord),
- replace_fn=replace_fn,
+ replace_fn=rewrite,
input_tensors=[self.optimizer_configuration.layers_to_optimize[0]],
output_tensors=[self.optimizer_configuration.layers_to_optimize[1]],
train_params=self.optimizer_configuration.train_params,
)
- self.model = TFLiteModel(tmp_output)
+ if orig_vs_repl_stats:
+ orig_vs_repl = ["Replaced sub-graph only"] + [
+ f"{stat:.3f}" for stat in orig_vs_repl_stats
+ ]
+ total = ["Total model"] + [f"{stat:.3f}" for stat in total_stats]
+ notes = (
+ "These metrics show the difference between original model\n"
+ "and the model optimized by the rewrite. The models are\n"
+ "compared at two positions: directly after the replaced\n"
+ "sub-graph and at the model output.\n"
+ "MAE = Mean Absolute Error\n"
+ "NRMSE = Normalized Root Mean Square Error"
+ )
- if result:
- stats_as_str = ", ".join(str(stats) for stats in result)
- logger.info(
- "The MAE and NRMSE between original and replacement [%s]",
- stats_as_str,
+ table = Table(
+ columns=[
+ Column(
+ "Original vs. optimized",
+ alias="metric",
+ fmt=Format(wrap_width=40),
+ ),
+ Column("MAE", alias="value", fmt=Format(wrap_width=15)),
+ Column("NRMSE", alias="value", fmt=Format(wrap_width=15)),
+ ],
+ rows=[orig_vs_repl, total],
+ name="Rewrite performance metrics",
+ alias="rewrite_performance_metrics",
+ notes=notes,
)
+ logger.info(table.to_plain_text(show_title=True))
def get_model(self) -> TFLiteModel:
"""Return optimized model."""
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index 42bf653..82af747 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -136,12 +136,27 @@ def train(
output_filename = output_model
join_in_dir(train_dir, filename, output_filename)
+ # Assess the output diff between the parts after the rewrite subgraph
+ # in original and optimized model
+ optimized_end_path = Path(train_dir, "optimized_end.tfrec")
+ end_path = Path(train_dir, "end.tfrec")
+
+ record_model(
+ str(input_tfrec),
+ output_filename,
+ optimized_end_path,
+ num_procs=train_params.num_procs,
+ num_threads=train_params.num_threads,
+ )
+ mae, nrmse = diff_stats(end_path, str(optimized_end_path))
+
if unmodified_model_dir:
cast(tempfile.TemporaryDirectory, unmodified_model_dir).cleanup()
- return (
- results if train_params.checkpoint_at else results[0]
- ) # only return a list if multiple checkpoints are asked for
+ return (results if train_params.checkpoint_at else results[0]), [
+ mae,
+ nrmse,
+ ] # only return a list if multiple checkpoints are asked for
def eval_in_dir(
diff --git a/src/mlia/nn/select.py b/src/mlia/nn/select.py
index 983426b..6947206 100644
--- a/src/mlia/nn/select.py
+++ b/src/mlia/nn/select.py
@@ -16,7 +16,7 @@ from mlia.core.errors import ConfigurationError
from mlia.nn.common import Optimizer
from mlia.nn.common import OptimizerConfiguration
from mlia.nn.rewrite.core.rewrite import RewriteConfiguration
-from mlia.nn.rewrite.core.rewrite import Rewriter
+from mlia.nn.rewrite.core.rewrite import RewritingOptimizer
from mlia.nn.rewrite.core.rewrite import TrainingParameters
from mlia.nn.tensorflow.config import KerasModel
from mlia.nn.tensorflow.config import TFLiteModel
@@ -132,7 +132,7 @@ def get_optimizer(
return Clusterer(model, config)
if isinstance(config, RewriteConfiguration):
- return Rewriter(model, config)
+ return RewritingOptimizer(model, config)
if isinstance(config, OptimizationSettings):
return _get_optimizer(model, cast(OptimizationSettings, config))
diff --git a/src/mlia/target/common/optimization.py b/src/mlia/target/common/optimization.py
new file mode 100644
index 0000000..5f359c5
--- /dev/null
+++ b/src/mlia/target/common/optimization.py
@@ -0,0 +1,219 @@
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Data collector support for performance optimizations."""
+from __future__ import annotations
+
+import logging
+from abc import abstractmethod
+from functools import partial
+from pathlib import Path
+from typing import Any
+from typing import Callable
+
+from mlia.core.common import DataItem
+from mlia.core.context import Context
+from mlia.core.data_collection import ContextAwareDataCollector
+from mlia.core.errors import FunctionalityNotSupportedError
+from mlia.core.performance import estimate_performance
+from mlia.core.performance import P
+from mlia.core.performance import PerformanceEstimator
+from mlia.nn.select import get_optimizer
+from mlia.nn.select import OptimizationSettings
+from mlia.nn.tensorflow.config import get_keras_model
+from mlia.nn.tensorflow.config import KerasModel
+from mlia.nn.tensorflow.config import TFLiteModel
+from mlia.nn.tensorflow.utils import save_keras_model
+from mlia.nn.tensorflow.utils import save_tflite_model
+from mlia.target.config import TargetProfile
+from mlia.utils.types import is_list_of
+
+logger = logging.getLogger(__name__)
+
+
+class OptimizingDataCollector(ContextAwareDataCollector):
+ """Collect performance metrics for the optimizations."""
+
+ def __init__(
+ self,
+ model: Path,
+ target_config: TargetProfile,
+ backends: list[str] | None = None,
+ ) -> None:
+ """Init performance optimizations data collector."""
+ self.model = model
+ self.target = target_config
+ self.backends = backends
+
+ def collect_data(self) -> DataItem:
+ """Collect performance metrics for the optimizations."""
+ logger.info("Estimate performance ...")
+
+ optimizations = self._get_optimization_settings(self.context)
+
+ if not optimizations or optimizations == [[]]:
+ raise FunctionalityNotSupportedError(
+ reason="No optimization targets provided",
+ description="Unable to estimate model optimizations impact",
+ )
+
+ opt_settings = self._parse_optimization_params(optimizations)
+
+ optimization_types = {
+ setting.optimization_type for opt in opt_settings for setting in opt
+ }
+
+ if optimization_types != {"rewrite"}:
+ try:
+ model = get_keras_model(self.model, self.context)
+ except NotImplementedError as err:
+ raise FunctionalityNotSupportedError(
+ reason=f"{self.model} is not a Keras model and "
+ "could not be converted to a Keras model",
+ description="Unable to run model optimizations",
+ ) from err
+ else:
+ model = self.model # type: ignore
+
+ optimizers: list[Callable] = [
+ partial(self.optimize_model, opts) for opts in opt_settings
+ ]
+
+ return self.optimize_and_estimate_performance(model, optimizers, opt_settings)
+
+ def optimize_model(
+ self, opt_settings: list[OptimizationSettings], model: KerasModel | TFLiteModel
+ ) -> Any:
+ """Run optimization."""
+ optimizer = get_optimizer(model, opt_settings)
+
+ opts_as_str = ", ".join(str(opt) for opt in opt_settings)
+ logger.info("Applying model optimizations - [%s]", opts_as_str)
+ optimizer.apply_optimization()
+
+ model = optimizer.get_model() # type: ignore
+
+ if isinstance(model, Path):
+ return model
+
+ if isinstance(model, TFLiteModel):
+ model_path = self.context.get_model_path("optimized_model.tflite")
+ with open(model.model_path, "rb") as file_handle:
+ model_data = bytearray(file_handle.read())
+ save_tflite_model(model_data, model_path)
+ return TFLiteModel(model_path)
+
+ model_path = self.context.get_model_path("optimized_model.h5")
+ save_keras_model(model, model_path)
+ return KerasModel(model_path)
+
+ def _get_optimization_settings(self, context: Context) -> list[list[dict]]:
+ """Get optimization settings."""
+ return self.get_parameter( # type: ignore
+ OptimizingDataCollector.name(),
+ "optimizations",
+ expected_type=list,
+ expected=False,
+ context=context,
+ )
+
+ @staticmethod
+ def _parse_optimization_params(
+ optimizations: list[list[dict]],
+ ) -> list[list[OptimizationSettings]]:
+ """Parse optimization parameters."""
+ if not is_list_of(optimizations, list):
+ raise TypeError("Optimization parameters expected to be a list.")
+
+ return [
+ [
+ OptimizationSettings(
+ item.get("optimization_type"), # type: ignore
+ item.get("optimization_target"), # type: ignore
+ item.get("layers_to_optimize"),
+ item.get("dataset"),
+ )
+ for item in opt_configuration
+ ]
+ for opt_configuration in optimizations
+ ]
+
+ def optimize_and_estimate_performance(
+ self,
+ model: KerasModel | Path,
+ optimizers: list[Callable],
+ _: list[list[OptimizationSettings]],
+ ) -> DataItem:
+        """Run optimizers and estimate performance on the results."""
+ for optimizer in optimizers:
+ optimizer(model)
+
+ return {}
+
+ @classmethod
+ def name(cls) -> str:
+ """Return name of the collector."""
+ return "common_optimizations"
+
+
+class OptimizingPerformaceDataCollector(OptimizingDataCollector):
+ """Collect performance metrics for the optimizations."""
+
+ @abstractmethod
+ def create_estimator(self) -> PerformanceEstimator:
+ """Create a PerformanceEstimator, to be overridden in subclasses."""
+
+ @abstractmethod
+ def create_optimization_performance_metrics(
+ self, original_metrics: P, optimizations_perf_metrics: list[P]
+ ) -> Any:
+ """Create an optimization metrics object."""
+
+ def optimize_and_estimate_performance(
+ self,
+ model: KerasModel | Path,
+ optimizers: list[Callable],
+ opt_settings: list[list[OptimizationSettings]],
+ ) -> Any:
+        """Run optimizers and estimate performance on the results."""
+ estimator = self.create_estimator()
+
+ original_metrics, *optimized_metrics = estimate_performance(
+ model, estimator, optimizers
+ )
+
+ return self.create_optimization_performance_metrics(
+ original_metrics,
+ list(zip(opt_settings, optimized_metrics)),
+ )
+
+
+_DEFAULT_OPTIMIZATION_TARGETS = [
+ {
+ "optimization_type": "pruning",
+ "optimization_target": 0.5,
+ "layers_to_optimize": None,
+ },
+ {
+ "optimization_type": "clustering",
+ "optimization_target": 32,
+ "layers_to_optimize": None,
+ },
+]
+
+
+def add_common_optimization_params(advisor_parameters: dict, extra_args: dict) -> None:
+ """Add common optimization parameters."""
+ optimization_targets = extra_args.get("optimization_targets")
+ if not optimization_targets:
+ optimization_targets = _DEFAULT_OPTIMIZATION_TARGETS
+
+ if not is_list_of(optimization_targets, dict):
+ raise TypeError("Optimization targets value has wrong format.")
+
+ advisor_parameters.update(
+ {
+ "common_optimizations": {
+ "optimizations": [optimization_targets],
+ },
+ }
+ )
diff --git a/src/mlia/target/cortex_a/advisor.py b/src/mlia/target/cortex_a/advisor.py
index db07b96..eb7720a 100644
--- a/src/mlia/target/cortex_a/advisor.py
+++ b/src/mlia/target/cortex_a/advisor.py
@@ -16,6 +16,8 @@ from mlia.core.context import ExecutionContext
from mlia.core.data_analysis import DataAnalyzer
from mlia.core.data_collection import DataCollector
from mlia.core.events import Event
+from mlia.target.common.optimization import add_common_optimization_params
+from mlia.target.common.optimization import OptimizingDataCollector
from mlia.target.cortex_a.advice_generation import CortexAAdviceProducer
from mlia.target.cortex_a.config import CortexAConfiguration
from mlia.target.cortex_a.data_analysis import CortexADataAnalyzer
@@ -50,9 +52,7 @@ class CortexAInferenceAdvisor(DefaultInferenceAdvisor):
)
if context.category_enabled(AdviceCategory.OPTIMIZATION):
- raise RuntimeError(
- "Model optimizations are currently not supported for Cortex-A."
- )
+ collectors.append(OptimizingDataCollector(model, target_config))
return collectors
@@ -82,20 +82,22 @@ def configure_and_get_cortexa_advisor(
context: ExecutionContext,
target_profile: str | Path,
model: str | Path,
- **_extra_args: Any,
+ **extra_args: Any,
) -> InferenceAdvisor:
"""Create and configure Cortex-A advisor."""
if context.event_handlers is None:
context.event_handlers = [CortexAEventHandler()]
if context.config_parameters is None:
- context.config_parameters = _get_config_parameters(model, target_profile)
+ context.config_parameters = _get_config_parameters(
+ model, target_profile, **extra_args
+ )
return CortexAInferenceAdvisor()
def _get_config_parameters(
- model: str | Path, target_profile: str | Path
+ model: str | Path, target_profile: str | Path, **extra_args: Any
) -> dict[str, Any]:
"""Get configuration parameters for the advisor."""
advisor_parameters: dict[str, Any] = {
@@ -104,5 +106,5 @@ def _get_config_parameters(
"target_profile": target_profile,
},
}
-
+ add_common_optimization_params(advisor_parameters, extra_args)
return advisor_parameters
diff --git a/src/mlia/target/ethos_u/advisor.py b/src/mlia/target/ethos_u/advisor.py
index 321734c..9f5b3a6 100644
--- a/src/mlia/target/ethos_u/advisor.py
+++ b/src/mlia/target/ethos_u/advisor.py
@@ -17,6 +17,8 @@ from mlia.core.data_analysis import DataAnalyzer
from mlia.core.data_collection import DataCollector
from mlia.core.events import Event
from mlia.nn.tensorflow.utils import is_tflite_model
+from mlia.target.common.optimization import add_common_optimization_params
+from mlia.target.common.optimization import OptimizingDataCollector
from mlia.target.ethos_u.advice_generation import EthosUAdviceProducer
from mlia.target.ethos_u.advice_generation import EthosUStaticAdviceProducer
from mlia.target.ethos_u.config import EthosUConfiguration
@@ -65,9 +67,7 @@ class EthosUInferenceAdvisor(DefaultInferenceAdvisor):
)
collectors.append(
- EthosUOptimizationPerformance(
- model, target_config, optimization_settings, backends
- )
+ EthosUOptimizationPerformance(model, target_config, backends)
)
if context.category_enabled(AdviceCategory.PERFORMANCE):
collectors.append(EthosUPerformance(model, target_config, backends))
@@ -76,9 +76,7 @@ class EthosUInferenceAdvisor(DefaultInferenceAdvisor):
if context.category_enabled(AdviceCategory.OPTIMIZATION):
optimization_settings = self._get_optimization_settings(context)
collectors.append(
- EthosUOptimizationPerformance(
- model, target_config, optimization_settings, backends
- )
+ EthosUOptimizationPerformance(model, target_config, backends)
)
elif context.category_enabled(AdviceCategory.PERFORMANCE):
collectors.append(EthosUPerformance(model, target_config, backends))
@@ -115,7 +113,7 @@ class EthosUInferenceAdvisor(DefaultInferenceAdvisor):
def _get_optimization_settings(self, context: Context) -> list[list[dict]]:
"""Get optimization settings."""
return self.get_parameter( # type: ignore
- EthosUOptimizationPerformance.name(),
+ OptimizingDataCollector.name(),
"optimizations",
expected_type=list,
expected=False,
@@ -151,20 +149,6 @@ def configure_and_get_ethosu_advisor(
return EthosUInferenceAdvisor()
-_DEFAULT_OPTIMIZATION_TARGETS = [
- {
- "optimization_type": "pruning",
- "optimization_target": 0.5,
- "layers_to_optimize": None,
- },
- {
- "optimization_type": "clustering",
- "optimization_target": 32,
- "layers_to_optimize": None,
- },
-]
-
-
def _get_config_parameters(
model: str | Path,
target_profile: str | Path,
@@ -186,19 +170,6 @@ def _get_config_parameters(
advisor_parameters["ethos_u_inference_advisor"]["backends"] = backends
- optimization_targets = extra_args.get("optimization_targets")
- if not optimization_targets:
- optimization_targets = _DEFAULT_OPTIMIZATION_TARGETS
-
- if not is_list_of(optimization_targets, dict):
- raise ValueError("Optimization targets value has wrong format.")
-
- advisor_parameters.update(
- {
- "ethos_u_model_optimizations": {
- "optimizations": [optimization_targets],
- },
- }
- )
+ add_common_optimization_params(advisor_parameters, extra_args)
return advisor_parameters
diff --git a/src/mlia/target/ethos_u/data_collection.py b/src/mlia/target/ethos_u/data_collection.py
index 4ea6120..2b41d7d 100644
--- a/src/mlia/target/ethos_u/data_collection.py
+++ b/src/mlia/target/ethos_u/data_collection.py
@@ -6,30 +6,23 @@ from __future__ import annotations
import logging
from pathlib import Path
from typing import Any
+from typing import cast
from mlia.backend.vela.compat import Operators
from mlia.backend.vela.compat import supported_operators
-from mlia.core.context import Context
from mlia.core.data_collection import ContextAwareDataCollector
-from mlia.core.errors import FunctionalityNotSupportedError
-from mlia.core.performance import estimate_performance
-from mlia.nn.select import get_optimizer
-from mlia.nn.select import OptimizationSettings
-from mlia.nn.tensorflow.config import get_keras_model
+from mlia.core.performance import P
+from mlia.core.performance import PerformanceEstimator
from mlia.nn.tensorflow.config import get_tflite_model
-from mlia.nn.tensorflow.config import KerasModel
-from mlia.nn.tensorflow.config import TFLiteModel
from mlia.nn.tensorflow.tflite_compat import TFLiteChecker
from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
from mlia.nn.tensorflow.utils import is_tflite_model
-from mlia.nn.tensorflow.utils import save_keras_model
-from mlia.nn.tensorflow.utils import save_tflite_model
+from mlia.target.common.optimization import OptimizingPerformaceDataCollector
from mlia.target.ethos_u.config import EthosUConfiguration
from mlia.target.ethos_u.performance import EthosUPerformanceEstimator
from mlia.target.ethos_u.performance import OptimizationPerformanceMetrics
from mlia.target.ethos_u.performance import PerformanceMetrics
from mlia.utils.logging import log_action
-from mlia.utils.types import is_list_of
logger = logging.getLogger(__name__)
@@ -96,118 +89,26 @@ class EthosUPerformance(ContextAwareDataCollector):
return "ethos_u_performance"
-class OptimizeModel:
- """Helper class for model optimization."""
+# pylint: disable=too-many-ancestors
+class EthosUOptimizationPerformance(OptimizingPerformaceDataCollector):
+ """Collect performance metrics for performance optimizations."""
- def __init__(
- self, context: Context, opt_settings: list[OptimizationSettings]
- ) -> None:
- """Init helper."""
- self.context = context
- self.opt_settings = opt_settings
-
- def __call__(self, model: KerasModel | TFLiteModel) -> Any:
- """Run optimization."""
- optimizer = get_optimizer(model, self.opt_settings)
-
- opts_as_str = ", ".join(str(opt) for opt in self.opt_settings)
- logger.info("Applying model optimizations - [%s]", opts_as_str)
- optimizer.apply_optimization()
- model = optimizer.get_model() # type: ignore
-
- if isinstance(model, Path):
- return model
-
- if isinstance(model, TFLiteModel):
- model_path = self.context.get_model_path("optimized_model.tflite")
- with open(model.model_path, "rb") as file_handle:
- model_data = bytearray(file_handle.read())
- save_tflite_model(model_data, model_path)
- return TFLiteModel(model_path)
-
- model_path = self.context.get_model_path("optimized_model.h5")
- save_keras_model(model, model_path)
- return KerasModel(model_path)
-
-
-class EthosUOptimizationPerformance(ContextAwareDataCollector):
- """Collect performance metrics for the optimizations."""
-
- def __init__(
- self,
- model: Path,
- target_config: EthosUConfiguration,
- optimizations: list[list[dict]],
- backends: list[str] | None = None,
- ) -> None:
- """Init performance optimizations data collector."""
- self.model = model
- self.target = target_config
- self.optimizations = optimizations
- self.backends = backends
-
- def collect_data(self) -> OptimizationPerformanceMetrics | None:
- """Collect performance metrics for the optimizations."""
- logger.info("Estimate performance ...")
-
- if not self.optimizations:
- raise FunctionalityNotSupportedError(
- reason="Unable to estimate model optimizations impact",
- description="No optimization targets provided",
- )
-
- opt_settings = self._parse_optimization_params(self.optimizations)
-
- if opt_settings[0][0].optimization_type != "rewrite":
- try:
- model = get_keras_model(self.model, self.context)
- except NotImplementedError as err:
- raise FunctionalityNotSupportedError(
- reason="Unable to run model optimizations",
- description=f"{self.model} is not a Keras model and "
- "could not be converted to a Keras model",
- ) from err
- else:
- model = self.model # type: ignore
-
- optimizers = [OptimizeModel(self.context, opts) for opts in opt_settings]
-
- estimator = EthosUPerformanceEstimator(
+ def create_estimator(self) -> PerformanceEstimator:
+ """Create a PerformanceEstimator, to be overridden in subclasses."""
+ return EthosUPerformanceEstimator(
self.context,
- self.target,
+ cast(EthosUConfiguration, self.target),
self.backends,
)
- original_metrics, *optimized_metrics = estimate_performance(
- model, estimator, optimizers # type: ignore
- )
-
- result = OptimizationPerformanceMetrics(
- original_perf_metrics=original_metrics,
- optimizations_perf_metrics=list(zip(opt_settings, optimized_metrics)),
+ def create_optimization_performance_metrics(
+ self, original_metrics: P, optimizations_perf_metrics: list[P]
+ ) -> Any:
+ """Create an optimization metrics object."""
+ return OptimizationPerformanceMetrics(
+ original_perf_metrics=original_metrics, # type: ignore
+ optimizations_perf_metrics=optimizations_perf_metrics, # type: ignore
)
- return result
-
- @staticmethod
- def _parse_optimization_params(
- optimizations: list[list[dict]],
- ) -> list[list[OptimizationSettings]]:
- """Parse optimization parameters."""
- if not is_list_of(optimizations, list):
- raise ValueError("Optimization parameters expected to be a list.")
-
- return [
- [
- OptimizationSettings(
- item.get("optimization_type"), # type: ignore
- item.get("optimization_target"), # type: ignore
- item.get("layers_to_optimize"),
- item.get("dataset"),
- )
- for item in opt_configuration
- ]
- for opt_configuration in optimizations
- ]
@classmethod
def name(cls) -> str:
diff --git a/src/mlia/target/tosa/advisor.py b/src/mlia/target/tosa/advisor.py
index 2d5163e..3619a83 100644
--- a/src/mlia/target/tosa/advisor.py
+++ b/src/mlia/target/tosa/advisor.py
@@ -16,6 +16,8 @@ from mlia.core.context import ExecutionContext
from mlia.core.data_analysis import DataAnalyzer
from mlia.core.data_collection import DataCollector
from mlia.core.events import Event
+from mlia.target.common.optimization import add_common_optimization_params
+from mlia.target.common.optimization import OptimizingDataCollector
from mlia.target.registry import profile
from mlia.target.tosa.advice_generation import TOSAAdviceProducer
from mlia.target.tosa.config import TOSAConfiguration
@@ -49,9 +51,9 @@ class TOSAInferenceAdvisor(DefaultInferenceAdvisor):
)
if context.category_enabled(AdviceCategory.OPTIMIZATION):
- raise RuntimeError(
- "Model optimizations are currently not supported for TOSA."
- )
+ target_profile = self.get_target_profile(context)
+ target_config = profile(target_profile)
+ collectors.append(OptimizingDataCollector(model, target_config))
return collectors
@@ -85,20 +87,22 @@ def configure_and_get_tosa_advisor(
context: ExecutionContext,
target_profile: str | Path,
model: str | Path,
- **_extra_args: Any,
+ **extra_args: Any,
) -> InferenceAdvisor:
"""Create and configure TOSA advisor."""
if context.event_handlers is None:
context.event_handlers = [TOSAEventHandler()]
if context.config_parameters is None:
- context.config_parameters = _get_config_parameters(model, target_profile)
+ context.config_parameters = _get_config_parameters(
+ model, target_profile, **extra_args
+ )
return TOSAInferenceAdvisor()
def _get_config_parameters(
- model: str | Path, target_profile: str | Path
+ model: str | Path, target_profile: str | Path, **extra_args: Any
) -> dict[str, Any]:
"""Get configuration parameters for the advisor."""
advisor_parameters: dict[str, Any] = {
@@ -107,5 +111,5 @@ def _get_config_parameters(
"target_profile": target_profile,
}
}
-
+ add_common_optimization_params(advisor_parameters, extra_args)
return advisor_parameters
diff --git a/src/mlia/utils/registry.py b/src/mlia/utils/registry.py
index a886376..1303ed7 100644
--- a/src/mlia/utils/registry.py
+++ b/src/mlia/utils/registry.py
@@ -37,3 +37,7 @@ class Registry(Generic[T]):
def pretty_name(self, name: str) -> str:
"""Get the pretty name (if available) or return the name as is otherwise."""
return self.pretty_names[name] if name in self.pretty_names else name
+
+ def names(self) -> list[str]:
+ """Sorted list of registered item names."""
+ return sorted(list(self.items.keys()))
diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py
index e4bbe91..6b1f19d 100644
--- a/tests/test_cli_commands.py
+++ b/tests/test_cli_commands.py
@@ -3,6 +3,7 @@
"""Tests for cli.commands module."""
from __future__ import annotations
+import re
from contextlib import ExitStack as does_not_raise
from pathlib import Path
from typing import Any
@@ -116,7 +117,10 @@ def test_performance_unknown_target(
"node_y",
pytest.raises(
Exception,
- match=(r"Currently only remove and fully_connected are supported."),
+ match=re.escape(
+ "Invalid rewrite target: 'random'. "
+ "Supported rewrites: ['fully_connected']"
+ ),
),
],
[
diff --git a/tests/test_common_optimization.py b/tests/test_common_optimization.py
new file mode 100644
index 0000000..599610d
--- /dev/null
+++ b/tests/test_common_optimization.py
@@ -0,0 +1,67 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the common optimization module."""
+from pathlib import Path
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.core.context import ExecutionContext
+from mlia.nn.common import Optimizer
+from mlia.nn.tensorflow.config import TFLiteModel
+from mlia.target.common.optimization import OptimizingDataCollector
+from mlia.target.config import TargetProfile
+
+
+class FakeOptimizer(Optimizer):
+ """Optimizer for testing purposes."""
+
+ def __init__(self, optimized_model_path: Path) -> None:
+ """Initialize."""
+ super().__init__()
+ self.optimized_model_path = optimized_model_path
+ self.invocation_count = 0
+
+ def apply_optimization(self) -> None:
+ """Count the invocations."""
+ self.invocation_count += 1
+
+ def get_model(self) -> TFLiteModel:
+ """Return optimized model."""
+ return TFLiteModel(self.optimized_model_path)
+
+ def optimization_config(self) -> str:
+ """Return something: doesn't matter, not used."""
+ return ""
+
+
+def test_optimizing_data_collector(
+ test_keras_model: Path,
+ test_tflite_model: Path,
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """Test OptimizingDataCollector, base support for various targets."""
+ optimizations = [
+ [
+ {"optimization_type": "fake", "optimization_target": 42},
+ ]
+ ]
+ context = ExecutionContext(
+ config_parameters={"common_optimizations": {"optimizations": optimizations}}
+ )
+
+ target_profile = MagicMock(spec=TargetProfile)
+
+ fake_optimizer = FakeOptimizer(test_tflite_model)
+
+ monkeypatch.setattr(
+ "mlia.target.common.optimization.get_optimizer",
+ MagicMock(return_value=fake_optimizer),
+ )
+
+ collector = OptimizingDataCollector(test_keras_model, target_profile)
+
+ collector.set_context(context)
+ collector.collect_data()
+
+ assert fake_optimizer.invocation_count == 1
diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py
index 2542db2..d4aac56 100644
--- a/tests/test_nn_rewrite_core_rewrite.py
+++ b/tests/test_nn_rewrite_core_rewrite.py
@@ -6,15 +6,35 @@ from __future__ import annotations
from contextlib import ExitStack as does_not_raise
from pathlib import Path
from typing import Any
+from typing import cast
import pytest
+from mlia.nn.rewrite.core.rewrite import DynamicallyLoadedRewrite
+from mlia.nn.rewrite.core.rewrite import Rewrite
+from mlia.nn.rewrite.core.rewrite import RewriteCallable
from mlia.nn.rewrite.core.rewrite import RewriteConfiguration
-from mlia.nn.rewrite.core.rewrite import Rewriter
+from mlia.nn.rewrite.core.rewrite import RewriteRegistry
+from mlia.nn.rewrite.core.rewrite import RewritingOptimizer
from mlia.nn.tensorflow.config import TFLiteModel
from tests.utils.rewrite import TestTrainingParameters
+def mock_rewrite_function(*_: Any) -> Any:
+ """Mock function to test autoloading of rewrite functions."""
+
+
+def test_rewrite() -> None:
+ """Test the Rewrite class."""
+
+ def bad_rewrite_func() -> Any:
+ raise NotImplementedError()
+
+ rewrite = Rewrite("BAD_REWRITE", rewrite_fn=cast(RewriteCallable, bad_rewrite_func))
+ with pytest.raises(RuntimeError):
+ rewrite((1, 2), (1, 2))
+
+
@pytest.mark.parametrize(
"rewrite_name, expected_error",
[
@@ -35,9 +55,9 @@ def test_rewrite_configuration(
assert config_obj.optimization_target in str(config_obj)
- rewriter_obj = Rewriter(test_tflite_model_fp32, config_obj)
+ rewriter_obj = RewritingOptimizer(test_tflite_model_fp32, config_obj)
assert rewriter_obj.optimizer_configuration.optimization_target == rewrite_name
- assert isinstance(rewriter_obj, Rewriter)
+ assert isinstance(rewriter_obj, RewritingOptimizer)
def test_rewriting_optimizer(
@@ -52,8 +72,60 @@ def test_rewriting_optimizer(
train_params=TestTrainingParameters(),
)
- test_obj = Rewriter(test_tflite_model_fp32, config_obj)
+ test_obj = RewritingOptimizer(test_tflite_model_fp32, config_obj)
test_obj.apply_optimization()
trained_model = test_obj.get_model()
assert isinstance(trained_model, TFLiteModel)
+
+ cfg = test_obj.optimization_config()
+ assert isinstance(cfg, str)
+ assert cfg
+
+
+def test_register_rewrite_function() -> None:
+    """Test adding rewrite functions and verify they are reported via the registry."""
+ registry = RewriteRegistry()
+
+ rewrite1 = Rewrite("r1", cast(RewriteCallable, lambda: 1))
+ rewrite2 = Rewrite("r2", cast(RewriteCallable, lambda: 2))
+
+ registry.register_rewrite(rewrite1)
+ registry.register_rewrite(rewrite2)
+ assert registry.names() == ["r1", "r2"]
+
+
+def test_builtin_rewrite_names() -> None:
+ """Test if all builtin rewrites are properly registered and returned."""
+ assert RewritingOptimizer.builtin_rewrite_names() == ["fully_connected"]
+
+
+def test_rewrite_function_autoload() -> None:
+ """Test rewrite function loading."""
+ function_name = "tests.test_nn_rewrite_core_rewrite.mock_rewrite_function"
+ rewrite = DynamicallyLoadedRewrite(name="mock_rewrite", function_name=function_name)
+ assert rewrite.name == "mock_rewrite"
+
+ assert rewrite.function is not mock_rewrite_function
+ assert rewrite.load_function(function_name) is mock_rewrite_function
+ assert rewrite.function is mock_rewrite_function
+
+
+def test_rewrite_function_autoload_fail() -> None:
+ """Test rewrite function loading failure."""
+ function_name = "invalid_module.invalid_function"
+ rewrite = DynamicallyLoadedRewrite(
+ name="mock_rewrite",
+ function_name="invalid_module.invalid_function",
+ )
+ assert rewrite.name == "mock_rewrite"
+
+ with pytest.raises(Exception) as exc_info:
+ rewrite.load_function(function_name)
+
+ message = exc_info.value.args[0]
+
+ assert message == (
+ "Unable to load rewrite function 'invalid_module.invalid_function'"
+ " for 'mock_rewrite'."
+ )
diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py
index 4493671..b001a09 100644
--- a/tests/test_nn_rewrite_core_train.py
+++ b/tests/test_nn_rewrite_core_train.py
@@ -62,7 +62,7 @@ def check_train(
train_params=train_params,
)
assert len(result) == 2
- assert all(res >= 0.0 for res in result), f"Results out of bound: {result}"
+ assert all(res >= 0.0 for res in result[0]), f"Results out of bound: {result}"
assert output_file.is_file()
diff --git a/tests/test_target_cortex_a_advisor.py b/tests/test_target_cortex_a_advisor.py
index 9e0082f..6e370d6 100644
--- a/tests/test_target_cortex_a_advisor.py
+++ b/tests/test_target_cortex_a_advisor.py
@@ -31,7 +31,23 @@ def test_configure_and_get_cortex_a_advisor(test_tflite_model: Path) -> None:
"cortex_a_inference_advisor": {
"model": str(test_tflite_model),
"target_profile": "cortex-a",
- }
+ },
+ "common_optimizations": {
+ "optimizations": [
+ [
+ {
+ "layers_to_optimize": None,
+ "optimization_target": 0.5,
+ "optimization_type": "pruning",
+ },
+ {
+ "layers_to_optimize": None,
+ "optimization_target": 32,
+ "optimization_type": "clustering",
+ },
+ ]
+ ]
+ },
}
assert isinstance(workflow, DefaultWorkflowExecutor)
@@ -43,11 +59,7 @@ def test_configure_and_get_cortex_a_advisor(test_tflite_model: Path) -> None:
[
AdviceCategory.PERFORMANCE,
"Performance estimation is currently not supported for Cortex-A.",
- ],
- [
- AdviceCategory.OPTIMIZATION,
- "Model optimizations are currently not supported for Cortex-A.",
- ],
+ ]
],
)
def test_unsupported_advice_categories(
diff --git a/tests/test_target_ethos_u_data_collection.py b/tests/test_target_ethos_u_data_collection.py
index 6244f8b..be93c26 100644
--- a/tests/test_target_ethos_u_data_collection.py
+++ b/tests/test_target_ethos_u_data_collection.py
@@ -8,9 +8,11 @@ import pytest
from mlia.backend.vela.compat import Operators
from mlia.core.context import Context
+from mlia.core.context import ExecutionContext
from mlia.core.data_collection import DataCollector
from mlia.core.errors import FunctionalityNotSupportedError
from mlia.nn.select import OptimizationSettings
+from mlia.target.common.optimization import add_common_optimization_params
from mlia.target.ethos_u.config import EthosUConfiguration
from mlia.target.ethos_u.data_collection import EthosUOperatorCompatibility
from mlia.target.ethos_u.data_collection import EthosUOptimizationPerformance
@@ -46,6 +48,20 @@ def test_collectors_metadata(
assert collector.name() == expected_name
+def setup_optimization(optimizations: list) -> Context:
+ """Set up optimization params for the context."""
+ params: dict = {}
+ add_common_optimization_params(
+ params,
+ {
+ "optimization_targets": optimizations,
+ },
+ )
+
+ context = ExecutionContext(config_parameters=params)
+ return context
+
+
def test_operator_compatibility_collector(
sample_context: Context, test_tflite_model: Path
) -> None:
@@ -76,7 +92,6 @@ def test_performance_collector(
def test_optimization_performance_collector(
monkeypatch: pytest.MonkeyPatch,
- sample_context: Context,
test_keras_model: Path,
test_tflite_model: Path,
) -> None:
@@ -84,16 +99,14 @@ def test_optimization_performance_collector(
target = EthosUConfiguration.load_profile("ethos-u55-256")
mock_performance_estimation(monkeypatch, target)
- collector = EthosUOptimizationPerformance(
- test_keras_model,
- target,
+
+ context = setup_optimization(
[
- [
- {"optimization_type": "pruning", "optimization_target": 0.5},
- ]
+ {"optimization_type": "pruning", "optimization_target": 0.5},
],
)
- collector.set_context(sample_context)
+ collector = EthosUOptimizationPerformance(test_keras_model, target)
+ collector.set_context(context)
result = collector.collect_data()
assert isinstance(result, OptimizationPerformanceMetrics)
@@ -105,34 +118,39 @@ def test_optimization_performance_collector(
assert opt == [OptimizationSettings("pruning", 0.5, None)]
assert isinstance(metrics, PerformanceMetrics)
- collector_no_optimizations = EthosUOptimizationPerformance(
- test_keras_model,
- target,
- [],
+ context = ExecutionContext(
+ config_parameters={"common_optimizations": {"optimizations": [[]]}}
)
+
+ collector_no_optimizations = EthosUOptimizationPerformance(test_keras_model, target)
+ collector_no_optimizations.set_context(context)
with pytest.raises(FunctionalityNotSupportedError):
collector_no_optimizations.collect_data()
- collector_tflite = EthosUOptimizationPerformance(
- test_tflite_model,
- target,
+ context = setup_optimization(
[
- [
- {"optimization_type": "pruning", "optimization_target": 0.5},
- ]
+ {"optimization_type": "pruning", "optimization_target": 0.5},
],
)
- collector_tflite.set_context(sample_context)
+
+ collector_tflite = EthosUOptimizationPerformance(test_tflite_model, target)
+ collector_tflite.set_context(context)
with pytest.raises(FunctionalityNotSupportedError):
collector_tflite.collect_data()
with pytest.raises(
Exception, match="Optimization parameters expected to be a list"
):
- collector_bad_config = EthosUOptimizationPerformance(
- test_keras_model, target, {"optimization_type": "pruning"} # type: ignore
+ context = ExecutionContext(
+ config_parameters={
+ "common_optimizations": {
+ "optimizations": [{"optimization_type": "pruning"}]
+ }
+ }
)
- collector.set_context(sample_context)
+
+ collector_bad_config = EthosUOptimizationPerformance(test_keras_model, target)
+ collector_bad_config.set_context(context)
collector_bad_config.collect_data()
diff --git a/tests/test_target_tosa_advisor.py b/tests/test_target_tosa_advisor.py
index f4f1e36..36e52e9 100644
--- a/tests/test_target_tosa_advisor.py
+++ b/tests/test_target_tosa_advisor.py
@@ -32,10 +32,26 @@ def test_configure_and_get_tosa_advisor(
assert advisor.get_events(ctx) == get_events_mock
assert ctx.event_handlers is not None
assert ctx.config_parameters == {
+ "common_optimizations": {
+ "optimizations": [
+ [
+ {
+ "layers_to_optimize": None,
+ "optimization_target": 0.5,
+ "optimization_type": "pruning",
+ },
+ {
+ "layers_to_optimize": None,
+ "optimization_target": 32,
+ "optimization_type": "clustering",
+ },
+ ]
+ ]
+ },
"tosa_inference_advisor": {
"model": str(test_tflite_model),
"target_profile": "tosa",
- }
+ },
}
assert isinstance(workflow, DefaultWorkflowExecutor)
@@ -48,10 +64,6 @@ def test_configure_and_get_tosa_advisor(
AdviceCategory.PERFORMANCE,
"Performance estimation is currently not supported for TOSA.",
],
- [
- AdviceCategory.OPTIMIZATION,
- "Model optimizations are currently not supported for TOSA.",
- ],
],
)
def test_unsupported_advice_categories(
diff --git a/tests/test_utils_registry.py b/tests/test_utils_registry.py
index 95721fc..288c825 100644
--- a/tests/test_utils_registry.py
+++ b/tests/test_utils_registry.py
@@ -8,7 +8,9 @@ def test_registry() -> None:
"""Test Registry class."""
reg = Registry[str]()
assert not str(reg)
+ assert reg.names() == []
assert reg.register("name", "value")
+ assert reg.names() == ["name"]
assert not reg.register("name", "value")
assert "name" in reg.items
assert reg.items["name"] == "value"
@@ -17,3 +19,4 @@ def test_registry() -> None:
assert len(reg.items) == 2
assert "other_name" in reg.items
assert reg.items["other_name"] == "value_2"
+ assert reg.names() == ["name", "other_name"]