author     Ruomei Yan <ruomei.yan@arm.com>  2023-02-20 15:32:54 +0000
committer  Benjamin Klimczak <benjamin.klimczak@arm.com>  2023-10-11 15:42:28 +0100
commit     446c379c92e15ad8f24ed0db853dd0fc9c271151 (patch)
tree       fb9e2b20fba15d3aa44054eb76d76fbdb1459006
parent     f0b8ed75fed9dc69ab1f6313339f9f7e38bfc725 (diff)
download   mlia-446c379c92e15ad8f24ed0db853dd0fc9c271151.tar.gz
Add a CLI component to enable rewrites

* Add flags for rewrite (--rewrite, --rewrite-start, --rewrite-end, --rewrite-target)
* Refactor the CLI interfaces so that optimize accepts TFLite models for rewrite and Keras models for clustering and pruning
* Refactor and move common.py and select.py out of the folder nn/tensorflow/optimizations
* Add the file nn/rewrite/core/rewrite.py as a placeholder
* Update/add unit tests
* Refactor OptimizeModel in ethos_u/data_collection.py to accept the TFLite model case
* Extend the logic so that, when "--rewrite" is specified, pruning is not added by default and TFLite models are also accepted
* Update README.md

Resolves: MLIA-750, MLIA-854, MLIA-865

Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com>
Change-Id: I67d85f71fa253d2bad4efe304ad8225970b9622c
-rw-r--r--  README.md  24
-rw-r--r--  src/mlia/cli/commands.py  12
-rw-r--r--  src/mlia/cli/helpers.py  8
-rw-r--r--  src/mlia/cli/main.py  2
-rw-r--r--  src/mlia/cli/options.py  64
-rw-r--r--  src/mlia/nn/common.py (renamed from src/mlia/nn/tensorflow/optimizations/common.py)  6
-rw-r--r--  src/mlia/nn/rewrite/core/rewrite.py  45
-rw-r--r--  src/mlia/nn/select.py (renamed from src/mlia/nn/tensorflow/optimizations/select.py)  55
-rw-r--r--  src/mlia/nn/tensorflow/optimizations/clustering.py  6
-rw-r--r--  src/mlia/nn/tensorflow/optimizations/pruning.py  4
-rw-r--r--  src/mlia/target/ethos_u/advice_generation.py  2
-rw-r--r--  src/mlia/target/ethos_u/advisor.py  16
-rw-r--r--  src/mlia/target/ethos_u/data_analysis.py  2
-rw-r--r--  src/mlia/target/ethos_u/data_collection.py  42
-rw-r--r--  src/mlia/target/ethos_u/performance.py  2
-rw-r--r--  tests/test_cli_commands.py  144
-rw-r--r--  tests/test_cli_helpers.py  2
-rw-r--r--  tests/test_cli_main.py  15
-rw-r--r--  tests/test_cli_options.py  7
-rw-r--r--  tests/test_nn_select.py (renamed from tests/test_nn_tensorflow_optimizations_select.py)  6
-rw-r--r--  tests/test_target_ethos_u_advice_generation.py  2
-rw-r--r--  tests/test_target_ethos_u_advisor.py  51
-rw-r--r--  tests/test_target_ethos_u_data_analysis.py  2
-rw-r--r--  tests/test_target_ethos_u_data_collection.py  2
24 files changed, 447 insertions, 74 deletions
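Before the per-file diffs, here is a rough sketch of how the new flags surface through the Python entry point in src/mlia/cli/commands.py. This is a minimal illustration, not part of the change: paths are placeholders and the ExecutionContext construction is an assumption based on the tests further down.

```python
# Hypothetical sketch only: drive a rewrite via the extended optimize() command.
from pathlib import Path

from mlia.cli.commands import optimize
from mlia.core.context import ExecutionContext

# Placeholder context; real invocations go through the `mlia` CLI instead.
ctx = ExecutionContext(output_dir=Path("mlia_output"))

optimize(
    ctx=ctx,
    target_profile="ethos-u55-256",
    model="ds_cnn_large_fp32.tflite",  # placeholder TFLite model path
    pruning=False,
    clustering=False,
    pruning_target=None,
    clustering_target=None,
    rewrite=True,
    rewrite_target="fully_connected",
    rewrite_start="MobileNet/avg_pool/AvgPool",
    rewrite_end="MobileNet/fc1/BiasAdd",
    dataset=Path("input.tfrec"),  # placeholder training data (tfrec)
)
```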
diff --git a/README.md b/README.md
index 725af53..7a879a9 100644
--- a/README.md
+++ b/README.md
@@ -158,10 +158,11 @@ mlia check --help
## **optimize**
-This sub-command applies optimizations to a Keras model (.h5 or SavedModel) and
-shows the performance improvements compared to the original unoptimized model.
+This sub-command applies optimizations to a Keras model (.h5 or SavedModel) or
+a TensorFlow Lite model and shows the performance improvements compared to
+the original unoptimized model.
-There are currently two optimization techniques available to apply:
+There are currently three optimization techniques available to apply:
* **pruning**: Sets insignificant model weights to zero until the specified
sparsity is reached.
@@ -172,9 +173,13 @@ More information about these techniques can be found online in the TensorFlow
documentation, e.g. in the
[TensorFlow model optimization guides](https://www.tensorflow.org/model_optimization/guide).
+* **rewrite**: Replaces a specific subgraph/layer of the pre-trained model
+ with candidates from the rewrite library, optionally trained on a small
+ portion of the training data, to achieve local performance gains.
+
**Note:** A ***Keras model*** (.h5 or SavedModel) is required as input to
-perform the optimizations. Models in the TensorFlow Lite format are **not**
-supported.
+perform pruning and clustering. A ***TensorFlow Lite model*** is required as input
+to perform a rewrite.
*Examples:*
@@ -189,6 +194,15 @@ mlia optimize ~/models/ds_cnn_l.h5 \
# Get help and further information
mlia optimize --help
+
+# An example for using rewrite
+mlia optimize ~/models/ds_cnn_large_fp32.tflite \
+ --target-profile ethos-u55-256 \
+ --rewrite \
+ --dataset input.tfrec \
+ --rewrite-target fully_connected \
+ --rewrite-start MobileNet/avg_pool/AvgPool \
+ --rewrite-end MobileNet/fc1/BiasAdd
```
# Target profiles
diff --git a/src/mlia/cli/commands.py b/src/mlia/cli/commands.py
index 1f339ee..7af41d9 100644
--- a/src/mlia/cli/commands.py
+++ b/src/mlia/cli/commands.py
@@ -96,7 +96,7 @@ def check(
)
-def optimize( # pylint: disable=too-many-arguments
+def optimize( # pylint: disable=too-many-locals,too-many-arguments
ctx: ExecutionContext,
target_profile: str,
model: str,
@@ -104,8 +104,13 @@ def optimize( # pylint: disable=too-many-arguments
clustering: bool,
pruning_target: float | None,
clustering_target: int | None,
+ rewrite: bool | None = None,
+ rewrite_target: str | None = None,
+ rewrite_start: str | None = None,
+ rewrite_end: str | None = None,
layers_to_optimize: list[str] | None = None,
backend: list[str] | None = None,
+ dataset: Path | None = None,
) -> None:
"""Show the performance improvements (if any) after applying the optimizations.
@@ -145,7 +150,12 @@ def optimize( # pylint: disable=too-many-arguments
clustering,
pruning_target,
clustering_target,
+ rewrite,
+ rewrite_target,
+ rewrite_start,
+ rewrite_end,
layers_to_optimize,
+ dataset,
)
)
diff --git a/src/mlia/cli/helpers.py b/src/mlia/cli/helpers.py
index abc6df0..824db1b 100644
--- a/src/mlia/cli/helpers.py
+++ b/src/mlia/cli/helpers.py
@@ -10,7 +10,7 @@ from typing import cast
from mlia.cli.options import get_target_profile_opts
from mlia.core.helpers import ActionResolver
-from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+from mlia.nn.select import OptimizationSettings
from mlia.nn.tensorflow.utils import is_keras_model
from mlia.target.config import get_builtin_profile_path
from mlia.target.config import is_builtin_profile
@@ -47,7 +47,11 @@ class CLIActionResolver(ActionResolver):
) -> list[str]:
"""Return specific optimization command description."""
opt_types = " ".join("--" + opt.optimization_type for opt in opt_settings)
- opt_targs_strings = ["--pruning-target", "--clustering-target"]
+ opt_targs_strings = [
+ "--pruning-target",
+ "--clustering-target",
+ "--rewrite-target",
+ ]
opt_targs = ",".join(
f"{opt_targs_strings[i]} {opt.optimization_target}"
for i, opt in enumerate(opt_settings)
diff --git a/src/mlia/cli/main.py b/src/mlia/cli/main.py
index 88258d5..9e1b7cd 100644
--- a/src/mlia/cli/main.py
+++ b/src/mlia/cli/main.py
@@ -23,6 +23,7 @@ from mlia.cli.options import add_backend_install_options
from mlia.cli.options import add_backend_options
from mlia.cli.options import add_backend_uninstall_options
from mlia.cli.options import add_check_category_options
+from mlia.cli.options import add_dataset_options
from mlia.cli.options import add_debug_options
from mlia.cli.options import add_keras_model_options
from mlia.cli.options import add_model_options
@@ -89,6 +90,7 @@ def get_commands() -> list[CommandInfo]:
add_multi_optimization_options,
add_output_options,
add_debug_options,
+ add_dataset_options,
],
),
]
diff --git a/src/mlia/cli/options.py b/src/mlia/cli/options.py
index fe177eb..7b3b373 100644
--- a/src/mlia/cli/options.py
+++ b/src/mlia/cli/options.py
@@ -12,6 +12,7 @@ from typing import Sequence
from mlia.backend.corstone import is_corstone_backend
from mlia.backend.manager import get_available_backends
from mlia.core.common import AdviceCategory
+from mlia.core.errors import ConfigurationError
from mlia.core.typing import OutputFormat
from mlia.target.registry import builtin_profile_names
from mlia.target.registry import registry as target_registry
@@ -90,6 +91,10 @@ def add_multi_optimization_options(parser: argparse.ArgumentParser) -> None:
)
multi_optimization_group.add_argument(
+ "--rewrite", action="store_true", help="Apply rewrite optimization."
+ )
+
+ multi_optimization_group.add_argument(
"--pruning-target",
type=float,
help="Sparsity to be reached during optimization "
@@ -103,6 +108,24 @@ def add_multi_optimization_options(parser: argparse.ArgumentParser) -> None:
f"(default: {DEFAULT_CLUSTERING_TARGET})",
)
+ multi_optimization_group.add_argument(
+ "--rewrite-target",
+ type=str,
+ help="Type of rewrite to apply to the subgraph/layer.",
+ )
+
+ multi_optimization_group.add_argument(
+ "--rewrite-start",
+ type=str,
+ help="Starting node in the graph of the subgraph to be rewritten.",
+ )
+
+ multi_optimization_group.add_argument(
+ "--rewrite-end",
+ type=str,
+ help="Ending node in the graph of the subgraph to be rewritten.",
+ )
+
def add_model_options(parser: argparse.ArgumentParser) -> None:
"""Add model specific options."""
@@ -131,6 +154,16 @@ def add_debug_options(parser: argparse.ArgumentParser) -> None:
)
+def add_dataset_options(parser: argparse.ArgumentParser) -> None:
+ """Addd dataset options."""
+ dataset_group = parser.add_argument_group("dataset options")
+ dataset_group.add_argument(
+ "--dataset",
+ type=Path,
+ help="The path of input tfrec file",
+ )
+
+
def add_keras_model_options(parser: argparse.ArgumentParser) -> None:
"""Add model specific options."""
model_group = parser.add_argument_group("Keras model options")
@@ -239,12 +272,17 @@ def add_output_directory(parser: argparse.ArgumentParser) -> None:
)
-def parse_optimization_parameters(
+def parse_optimization_parameters( # pylint: disable=too-many-arguments
pruning: bool = False,
clustering: bool = False,
pruning_target: float | None = None,
clustering_target: int | None = None,
+ rewrite: bool | None = False,
+ rewrite_target: str | None = None,
+ rewrite_start: str | None = None,
+ rewrite_end: str | None = None,
layers_to_optimize: list[str] | None = None,
+ dataset: Path | None = None,
) -> list[dict[str, Any]]:
"""Parse provided optimization parameters."""
opt_types = []
@@ -263,7 +301,14 @@ def parse_optimization_parameters(
if not clustering_target:
clustering_target = DEFAULT_CLUSTERING_TARGET
- if (pruning is False and clustering is False) or pruning:
+ if rewrite:
+ if not rewrite_target or not rewrite_start or not rewrite_end:
+ raise ConfigurationError(
+ "To perform rewrite, rewrite-target, rewrite-start and "
+ "rewrite-end must be set."
+ )
+
+ if not any((pruning, clustering, rewrite)) or pruning:
opt_types.append("pruning")
opt_targets.append(pruning_target)
@@ -276,10 +321,25 @@ def parse_optimization_parameters(
"optimization_type": opt_type.strip(),
"optimization_target": float(opt_target),
"layers_to_optimize": layers_to_optimize,
+ "dataset": dataset,
}
for opt_type, opt_target in zip(opt_types, opt_targets)
]
+ if rewrite:
+ if rewrite_target not in ["remove", "fully_connected"]:
+ raise ConfigurationError(
+ "Currently only remove and fully_connected are supported."
+ )
+ optimizer_params.append(
+ {
+ "optimization_type": "rewrite",
+ "optimization_target": rewrite_target,
+ "layers_to_optimize": [rewrite_start, rewrite_end],
+ "dataset": dataset,
+ }
+ )
+
return optimizer_params
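The net effect of the added validation and the appended entry: a rewrite request parses into a single optimizer-parameter dict carrying the start/end nodes and the dataset. A minimal sketch, assuming the function as changed above (node names and the dataset path are placeholders):

```python
from pathlib import Path

from mlia.cli.options import parse_optimization_parameters

params = parse_optimization_parameters(
    rewrite=True,
    rewrite_target="fully_connected",
    rewrite_start="start_node",   # placeholder node name
    rewrite_end="end_node",       # placeholder node name
    dataset=Path("input.tfrec"),  # placeholder dataset path
)

# Expected result per the parsing logic above: pruning/clustering are not
# added because --rewrite was requested, so only the rewrite entry remains.
assert params == [
    {
        "optimization_type": "rewrite",
        "optimization_target": "fully_connected",
        "layers_to_optimize": ["start_node", "end_node"],
        "dataset": Path("input.tfrec"),
    }
]
```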
diff --git a/src/mlia/nn/tensorflow/optimizations/common.py b/src/mlia/nn/common.py
index 1dce0b2..5a8d685 100644
--- a/src/mlia/nn/tensorflow/optimizations/common.py
+++ b/src/mlia/nn/common.py
@@ -1,11 +1,11 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Common items for the optimizations module."""
from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
-import tensorflow as tf
+from mlia.nn.tensorflow.config import ModelConfiguration
@dataclass
@@ -17,7 +17,7 @@ class Optimizer(ABC):
"""Abstract class for the optimizer."""
@abstractmethod
- def get_model(self) -> tf.keras.Model:
+ def get_model(self) -> ModelConfiguration:
"""Abstract method to return the model instance from the optimizer."""
@abstractmethod
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py
new file mode 100644
index 0000000..d4f61c5
--- /dev/null
+++ b/src/mlia/nn/rewrite/core/rewrite.py
@@ -0,0 +1,45 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Contains class Rewriter to replace a subgraph/layer of a model."""
+from __future__ import annotations
+
+from dataclasses import dataclass
+from pathlib import Path
+
+from mlia.nn.common import Optimizer
+from mlia.nn.common import OptimizerConfiguration
+from mlia.nn.tensorflow.config import TFLiteModel
+
+
+@dataclass
+class RewriteConfiguration(OptimizerConfiguration):
+ """Rewrite configuration."""
+
+ optimization_target: str
+ layers_to_optimize: list[str] | None = None
+ dataset: Path | None = None
+
+ def __str__(self) -> str:
+ """Return string representation of the configuration."""
+ return f"rewrite: {self.optimization_target}"
+
+
+class Rewriter(Optimizer):
+ """Rewriter class for basic rewrite flow."""
+
+ def __init__(
+ self, tflite_model_path: Path, optimizer_configuration: RewriteConfiguration
+ ):
+ """Init Rewriter instance."""
+ self.model = TFLiteModel(tflite_model_path)
+ self.optimizer_configuration = optimizer_configuration
+
+ def apply_optimization(self) -> None:
+ """Apply the rewrite flow."""
+
+ def get_model(self) -> TFLiteModel:
+ """Return optimized model."""
+ return self.model
+
+ def optimization_config(self) -> str:
+ """Optimization configirations."""
diff --git a/src/mlia/nn/tensorflow/optimizations/select.py b/src/mlia/nn/select.py
index a78df12..7a25e47 100644
--- a/src/mlia/nn/tensorflow/optimizations/select.py
+++ b/src/mlia/nn/select.py
@@ -4,16 +4,21 @@
from __future__ import annotations
import math
+from pathlib import Path
+from typing import Any
from typing import NamedTuple
import tensorflow as tf
from mlia.core.errors import ConfigurationError
+from mlia.nn.common import Optimizer
+from mlia.nn.common import OptimizerConfiguration
+from mlia.nn.rewrite.core.rewrite import RewriteConfiguration
+from mlia.nn.rewrite.core.rewrite import Rewriter
from mlia.nn.tensorflow.config import KerasModel
+from mlia.nn.tensorflow.config import TFLiteModel
from mlia.nn.tensorflow.optimizations.clustering import Clusterer
from mlia.nn.tensorflow.optimizations.clustering import ClusteringConfiguration
-from mlia.nn.tensorflow.optimizations.common import Optimizer
-from mlia.nn.tensorflow.optimizations.common import OptimizerConfiguration
from mlia.nn.tensorflow.optimizations.pruning import Pruner
from mlia.nn.tensorflow.optimizations.pruning import PruningConfiguration
from mlia.utils.types import is_list_of
@@ -25,11 +30,13 @@ class OptimizationSettings(NamedTuple):
optimization_type: str
optimization_target: int | float
layers_to_optimize: list[str] | None
+ dataset: Path | None = None
@staticmethod
def create_from(
optimizer_params: list[tuple[str, float]],
layers_to_optimize: list[str] | None = None,
+ dataset: Path | None = None,
) -> list[OptimizationSettings]:
"""Create optimization settings from the provided parameters."""
return [
@@ -37,6 +44,7 @@ class OptimizationSettings(NamedTuple):
optimization_type=opt_type,
optimization_target=opt_target,
layers_to_optimize=layers_to_optimize,
+ dataset=dataset,
)
for opt_type, opt_target in optimizer_params
]
@@ -64,6 +72,14 @@ class OptimizationSettings(NamedTuple):
self.optimization_type, next_target, self.layers_to_optimize
)
+ if self.optimization_type == "rewrite":
+ return OptimizationSettings(
+ self.optimization_type,
+ self.optimization_target,
+ self.layers_to_optimize,
+ self.dataset,
+ )
+
raise ValueError(f"Optimization type {self.optimization_type} is unknown.")
@@ -83,7 +99,7 @@ class MultiStageOptimizer(Optimizer):
"""Return string representation of the optimization config."""
return " - ".join(str(opt) for opt in self.optimizations)
- def get_model(self) -> tf.keras.Model:
+ def get_model(self) -> Any:
"""Return optimized model."""
return self.model
@@ -96,19 +112,25 @@ class MultiStageOptimizer(Optimizer):
def get_optimizer(
- model: tf.keras.Model | KerasModel,
+ model: tf.keras.Model | KerasModel | TFLiteModel,
config: OptimizerConfiguration | OptimizationSettings | list[OptimizationSettings],
) -> Optimizer:
"""Get optimizer for provided configuration."""
if isinstance(model, KerasModel):
model = model.get_keras_model()
+ if isinstance(model, TFLiteModel):
+ model = model.model_path
+
if isinstance(config, PruningConfiguration):
return Pruner(model, config)
if isinstance(config, ClusteringConfiguration):
return Clusterer(model, config)
+ if isinstance(config, RewriteConfiguration):
+ return Rewriter(model, config) # type: ignore
+
if isinstance(config, OptimizationSettings) or is_list_of(
config, OptimizationSettings
):
@@ -118,18 +140,18 @@ def get_optimizer(
def _get_optimizer(
- model: tf.keras.Model,
+ model: tf.keras.Model | Path,
optimization_settings: OptimizationSettings | list[OptimizationSettings],
) -> Optimizer:
if isinstance(optimization_settings, OptimizationSettings):
optimization_settings = [optimization_settings]
optimizer_configs = []
- for opt_type, opt_target, layers_to_optimize in optimization_settings:
+ for opt_type, opt_target, layers_to_optimize, dataset in optimization_settings:
_check_optimizer_params(opt_type, opt_target)
opt_config = _get_optimizer_configuration(
- opt_type, opt_target, layers_to_optimize
+ opt_type, opt_target, layers_to_optimize, dataset
)
optimizer_configs.append(opt_config)
@@ -141,15 +163,16 @@ def _get_optimizer(
def _get_optimizer_configuration(
optimization_type: str,
- optimization_target: int | float,
+ optimization_target: int | float | str,
layers_to_optimize: list[str] | None = None,
+ dataset: Path | None = None,
) -> OptimizerConfiguration:
"""Get optimizer configuration for provided parameters."""
_check_optimizer_params(optimization_type, optimization_target)
opt_type = optimization_type.lower()
if opt_type == "pruning":
- return PruningConfiguration(optimization_target, layers_to_optimize)
+ return PruningConfiguration(float(optimization_target), layers_to_optimize)
if opt_type == "clustering":
# make sure an integer is given as clustering target
@@ -161,11 +184,23 @@ def _get_optimizer_configuration(
f"Optimization target provided: {optimization_target}"
)
+ if opt_type == "rewrite":
+ if isinstance(optimization_target, str):
+ return RewriteConfiguration( # type: ignore
+ str(optimization_target), layers_to_optimize, dataset
+ )
+
+ raise ConfigurationError(
+ "Optimization target should be a string indicating a"
+ "choice from rewrite library. "
+ f"Optimization target provided: {optimization_target}"
+ )
+
raise ConfigurationError(f"Unsupported optimization type: {optimization_type}")
def _check_optimizer_params(
- optimization_type: str, optimization_target: int | float
+ optimization_type: str, optimization_target: int | float | str
) -> None:
"""Check optimizer params."""
if not optimization_target:
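For completeness, a sketch of how a rewrite expressed as OptimizationSettings flows through the updated dispatch. This is an assumption-laden illustration: the model path, node names and dataset are placeholders, and the actual rewrite work still depends on the placeholder Rewriter above.

```python
from pathlib import Path

from mlia.nn.select import get_optimizer
from mlia.nn.select import OptimizationSettings
from mlia.nn.tensorflow.config import TFLiteModel

settings = OptimizationSettings(
    optimization_type="rewrite",
    optimization_target="fully_connected",
    layers_to_optimize=["start_node", "end_node"],  # placeholder node names
    dataset=Path("input.tfrec"),  # placeholder dataset, carried via the new field
)

# get_optimizer() unwraps the TFLiteModel to its model path and routes the
# "rewrite" settings into a RewriteConfiguration via the helpers above.
optimizer = get_optimizer(TFLiteModel(Path("model.tflite")), settings)
optimizer.apply_optimization()
optimized_model = optimizer.get_model()
```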
diff --git a/src/mlia/nn/tensorflow/optimizations/clustering.py b/src/mlia/nn/tensorflow/optimizations/clustering.py
index 4aaa33e..f9018b3 100644
--- a/src/mlia/nn/tensorflow/optimizations/clustering.py
+++ b/src/mlia/nn/tensorflow/optimizations/clustering.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""
Contains class Clusterer that clusters unique weights per layer to a specified number.
@@ -18,8 +18,8 @@ from tensorflow_model_optimization.python.core.clustering.keras.experimental imp
cluster as experimental_cluster,
)
-from mlia.nn.tensorflow.optimizations.common import Optimizer
-from mlia.nn.tensorflow.optimizations.common import OptimizerConfiguration
+from mlia.nn.common import Optimizer
+from mlia.nn.common import OptimizerConfiguration
@dataclass
diff --git a/src/mlia/nn/tensorflow/optimizations/pruning.py b/src/mlia/nn/tensorflow/optimizations/pruning.py
index 2d5ef0e..a30b301 100644
--- a/src/mlia/nn/tensorflow/optimizations/pruning.py
+++ b/src/mlia/nn/tensorflow/optimizations/pruning.py
@@ -24,8 +24,8 @@ from tensorflow_model_optimization.python.core.sparsity.keras import ( # pylint
pruning_wrapper,
)
-from mlia.nn.tensorflow.optimizations.common import Optimizer
-from mlia.nn.tensorflow.optimizations.common import OptimizerConfiguration
+from mlia.nn.common import Optimizer
+from mlia.nn.common import OptimizerConfiguration
logger = logging.getLogger(__name__)
diff --git a/src/mlia/target/ethos_u/advice_generation.py b/src/mlia/target/ethos_u/advice_generation.py
index a9f9eac..351082a 100644
--- a/src/mlia/target/ethos_u/advice_generation.py
+++ b/src/mlia/target/ethos_u/advice_generation.py
@@ -11,7 +11,7 @@ from mlia.core.advice_generation import ContextAwareAdviceProducer
from mlia.core.advice_generation import FactBasedAdviceProducer
from mlia.core.common import AdviceCategory
from mlia.core.common import DataItem
-from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+from mlia.nn.select import OptimizationSettings
from mlia.target.common.reporters import handle_model_is_not_tflite_compatible_common
from mlia.target.common.reporters import handle_tflite_check_failed_common
from mlia.target.common.reporters import ModelIsNotTFLiteCompatible
diff --git a/src/mlia/target/ethos_u/advisor.py b/src/mlia/target/ethos_u/advisor.py
index d2c308a..321734c 100644
--- a/src/mlia/target/ethos_u/advisor.py
+++ b/src/mlia/target/ethos_u/advisor.py
@@ -54,8 +54,20 @@ class EthosUInferenceAdvisor(DefaultInferenceAdvisor):
if is_tflite_model(model):
# TensorFlow Lite models do not support optimization (only performance)!
if context.category_enabled(AdviceCategory.OPTIMIZATION):
- raise RuntimeError(
- "Optimizations are not supported for TensorFlow Lite files."
+ optimization_settings = self._get_optimization_settings(context)
+
+ optimization_types = {
+ opt["optimization_type"] for opt in optimization_settings[0]
+ }
+ if optimization_types != {"rewrite"}:
+ raise RuntimeError(
+ "Only 'rewrite' is supported for TensorFlow Lite files."
+ )
+
+ collectors.append(
+ EthosUOptimizationPerformance(
+ model, target_config, optimization_settings, backends
+ )
)
if context.category_enabled(AdviceCategory.PERFORMANCE):
collectors.append(EthosUPerformance(model, target_config, backends))
diff --git a/src/mlia/target/ethos_u/data_analysis.py b/src/mlia/target/ethos_u/data_analysis.py
index 3df4bff..5c6080f 100644
--- a/src/mlia/target/ethos_u/data_analysis.py
+++ b/src/mlia/target/ethos_u/data_analysis.py
@@ -10,7 +10,7 @@ from mlia.backend.vela.compat import Operators
from mlia.core.common import DataItem
from mlia.core.data_analysis import Fact
from mlia.core.data_analysis import FactExtractor
-from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+from mlia.nn.select import OptimizationSettings
from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
from mlia.target.common.reporters import analyze_tflite_compatibility_common
from mlia.target.ethos_u.performance import OptimizationPerformanceMetrics
diff --git a/src/mlia/target/ethos_u/data_collection.py b/src/mlia/target/ethos_u/data_collection.py
index 0654143..0f3a8d2 100644
--- a/src/mlia/target/ethos_u/data_collection.py
+++ b/src/mlia/target/ethos_u/data_collection.py
@@ -5,6 +5,7 @@ from __future__ import annotations
import logging
from pathlib import Path
+from typing import Any
from mlia.backend.vela.compat import Operators
from mlia.backend.vela.compat import supported_operators
@@ -12,15 +13,17 @@ from mlia.core.context import Context
from mlia.core.data_collection import ContextAwareDataCollector
from mlia.core.errors import FunctionalityNotSupportedError
from mlia.core.performance import estimate_performance
+from mlia.nn.select import get_optimizer
+from mlia.nn.select import OptimizationSettings
from mlia.nn.tensorflow.config import get_keras_model
from mlia.nn.tensorflow.config import get_tflite_model
from mlia.nn.tensorflow.config import KerasModel
-from mlia.nn.tensorflow.optimizations.select import get_optimizer
-from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+from mlia.nn.tensorflow.config import TFLiteModel
from mlia.nn.tensorflow.tflite_compat import TFLiteChecker
from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
from mlia.nn.tensorflow.utils import is_tflite_model
from mlia.nn.tensorflow.utils import save_keras_model
+from mlia.nn.tensorflow.utils import save_tflite_model
from mlia.target.ethos_u.config import EthosUConfiguration
from mlia.target.ethos_u.performance import EthosUPerformanceEstimator
from mlia.target.ethos_u.performance import OptimizationPerformanceMetrics
@@ -103,7 +106,7 @@ class OptimizeModel:
self.context = context
self.opt_settings = opt_settings
- def __call__(self, keras_model: KerasModel) -> KerasModel:
+ def __call__(self, keras_model: KerasModel) -> Any:
"""Run optimization."""
optimizer = get_optimizer(keras_model, self.opt_settings)
@@ -112,9 +115,19 @@ class OptimizeModel:
optimizer.apply_optimization()
model = optimizer.get_model()
+
+ if isinstance(model, Path):
+ return model
+
+ if isinstance(model, TFLiteModel):
+ model_path = self.context.get_model_path("optimized_model.tflite")
+ with open(model.model_path, "rb") as file_handle:
+ model_data = bytearray(file_handle.read())
+ save_tflite_model(model_data, model_path)
+ return TFLiteModel(model_path)
+
model_path = self.context.get_model_path("optimized_model.h5")
save_keras_model(model, model_path)
-
return KerasModel(model_path)
@@ -146,14 +159,17 @@ class EthosUOptimizationPerformance(ContextAwareDataCollector):
opt_settings = self._parse_optimization_params(self.optimizations)
- try:
- keras_model = get_keras_model(self.model, self.context)
- except NotImplementedError as err:
- raise FunctionalityNotSupportedError(
- reason="Unable to run model optimizations",
- description=f"{self.model} is not a Keras model and "
- "could not be converted to a Keras model",
- ) from err
+ if opt_settings[0][0].optimization_type != "rewrite":
+ try:
+ model = get_keras_model(self.model, self.context)
+ except NotImplementedError as err:
+ raise FunctionalityNotSupportedError(
+ reason="Unable to run model optimizations",
+ description=f"{self.model} is not a Keras model and "
+ "could not be converted to a Keras model",
+ ) from err
+ else:
+ model = self.model # type: ignore
optimizers = [OptimizeModel(self.context, opts) for opts in opt_settings]
@@ -163,7 +179,7 @@ class EthosUOptimizationPerformance(ContextAwareDataCollector):
self.backends,
)
original_metrics, *optimized_metrics = estimate_performance(
- keras_model, estimator, optimizers # type: ignore
+ model, estimator, optimizers # type: ignore
)
result = OptimizationPerformanceMetrics(
diff --git a/src/mlia/target/ethos_u/performance.py b/src/mlia/target/ethos_u/performance.py
index f7f9a8c..a0526e4 100644
--- a/src/mlia/target/ethos_u/performance.py
+++ b/src/mlia/target/ethos_u/performance.py
@@ -15,9 +15,9 @@ from mlia.backend.corstone import is_corstone_backend
from mlia.backend.corstone.performance import estimate_performance
from mlia.core.context import Context
from mlia.core.performance import PerformanceEstimator
+from mlia.nn.select import OptimizationSettings
from mlia.nn.tensorflow.config import get_tflite_model
from mlia.nn.tensorflow.config import ModelConfiguration
-from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
from mlia.target.ethos_u.config import EthosUConfiguration
from mlia.target.registry import supported_backends
from mlia.utils.logging import log_action
diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py
index f3213c4..6765a53 100644
--- a/tests/test_cli_commands.py
+++ b/tests/test_cli_commands.py
@@ -3,6 +3,7 @@
"""Tests for cli.commands module."""
from __future__ import annotations
+from contextlib import ExitStack as does_not_raise
from pathlib import Path
from typing import Any
from unittest.mock import call
@@ -49,37 +50,148 @@ def test_performance_unknown_target(
@pytest.mark.parametrize(
- "target_profile, pruning, clustering, pruning_target, clustering_target",
+ "target_profile, pruning, clustering, pruning_target, clustering_target, "
+ "rewrite, rewrite_target, rewrite_start, rewrite_end, expected_error",
[
- ["ethos-u55-256", True, False, 0.5, None],
- ["ethos-u65-512", False, True, 0.5, 32],
- ["ethos-u55-256", True, True, 0.5, None],
- ["ethos-u55-256", False, False, 0.5, None],
- ["ethos-u55-256", False, True, "invalid", 32],
+ [
+ "ethos-u55-256",
+ True,
+ False,
+ 0.5,
+ None,
+ False,
+ None,
+ "node_a",
+ "node_b",
+ does_not_raise(),
+ ],
+ [
+ "ethos-u55-256",
+ False,
+ False,
+ None,
+ None,
+ True,
+ "fully_connected",
+ "node_a",
+ "node_b",
+ does_not_raise(),
+ ],
+ [
+ "ethos-u55-256",
+ True,
+ False,
+ 0.5,
+ None,
+ True,
+ "fully_connected",
+ "node_a",
+ "node_b",
+ pytest.raises(
+ Exception,
+ match=(r"Only 'rewrite' is supported for TensorFlow Lite files."),
+ ),
+ ],
+ [
+ "ethos-u65-512",
+ False,
+ True,
+ 0.5,
+ 32,
+ False,
+ None,
+ None,
+ None,
+ does_not_raise(),
+ ],
+ [
+ "ethos-u55-256",
+ False,
+ False,
+ 0.5,
+ None,
+ True,
+ "random",
+ "node_x",
+ "node_y",
+ pytest.raises(
+ Exception,
+ match=(r"Currently only remove and fully_connected are supported."),
+ ),
+ ],
+ [
+ "ethos-u55-256",
+ False,
+ False,
+ 0.5,
+ None,
+ True,
+ None,
+ "node_m",
+ "node_n",
+ pytest.raises(
+ Exception,
+ match=(
+ r"To perform rewrite, rewrite-target, "
+ r"rewrite-start and rewrite-end must be set."
+ ),
+ ),
+ ],
+ [
+ "ethos-u55-256",
+ False,
+ False,
+ "invalid",
+ None,
+ True,
+ "remove",
+ None,
+ "node_end",
+ pytest.raises(
+ Exception,
+ match=(
+ r"To perform rewrite, rewrite-target, "
+ r"rewrite-start and rewrite-end must be set."
+ ),
+ ),
+ ],
],
)
-def test_opt_valid_optimization_target(
+def test_opt_valid_optimization_target( # pylint: disable=too-many-arguments
target_profile: str,
sample_context: ExecutionContext,
pruning: bool,
clustering: bool,
pruning_target: float | None,
clustering_target: int | None,
+ rewrite: bool,
+ rewrite_target: str | None,
+ rewrite_start: str | None,
+ rewrite_end: str | None,
+ expected_error: Any,
monkeypatch: pytest.MonkeyPatch,
test_keras_model: Path,
+ test_tflite_model: Path,
) -> None:
"""Test that command should not fail with valid optimization targets."""
mock_performance_estimation(monkeypatch)
- optimize(
- ctx=sample_context,
- target_profile=target_profile,
- model=str(test_keras_model),
- pruning=pruning,
- clustering=clustering,
- pruning_target=pruning_target,
- clustering_target=clustering_target,
- )
+ model_type = test_tflite_model if rewrite else test_keras_model
+
+ with expected_error:
+ optimize(
+ ctx=sample_context,
+ target_profile=target_profile,
+ model=str(model_type),
+ pruning=pruning,
+ clustering=clustering,
+ pruning_target=pruning_target,
+ clustering_target=clustering_target,
+ rewrite=rewrite,
+ rewrite_target=rewrite_target,
+ rewrite_start=rewrite_start,
+ rewrite_end=rewrite_end,
+ )
def mock_performance_estimation(monkeypatch: pytest.MonkeyPatch) -> None:
diff --git a/tests/test_cli_helpers.py b/tests/test_cli_helpers.py
index 6d19207..494ed89 100644
--- a/tests/test_cli_helpers.py
+++ b/tests/test_cli_helpers.py
@@ -10,7 +10,7 @@ import pytest
from mlia.cli.helpers import CLIActionResolver
from mlia.cli.helpers import copy_profile_file_to_output_dir
-from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+from mlia.nn.select import OptimizationSettings
class TestCliActionResolver:
diff --git a/tests/test_cli_main.py b/tests/test_cli_main.py
index 2f89268..e415284 100644
--- a/tests/test_cli_main.py
+++ b/tests/test_cli_main.py
@@ -165,6 +165,11 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
pruning_target=None,
clustering_target=None,
backend=None,
+ rewrite=False,
+ rewrite_target=None,
+ rewrite_start=None,
+ rewrite_end=None,
+ dataset=None,
),
],
[
@@ -189,6 +194,11 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
pruning_target=0.5,
clustering_target=32,
backend=None,
+ rewrite=False,
+ rewrite_target=None,
+ rewrite_start=None,
+ rewrite_end=None,
+ dataset=None,
),
],
[
@@ -210,6 +220,11 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
pruning_target=None,
clustering_target=None,
backend=["some_backend"],
+ rewrite=False,
+ rewrite_target=None,
+ rewrite_start=None,
+ rewrite_end=None,
+ dataset=None,
),
],
[
diff --git a/tests/test_cli_options.py b/tests/test_cli_options.py
index c02ef89..75ace0b 100644
--- a/tests/test_cli_options.py
+++ b/tests/test_cli_options.py
@@ -30,6 +30,7 @@ from mlia.core.typing import OutputFormat
"optimization_type": "pruning",
"optimization_target": 0.5,
"layers_to_optimize": None,
+ "dataset": None,
}
],
],
@@ -44,6 +45,7 @@ from mlia.core.typing import OutputFormat
"optimization_type": "pruning",
"optimization_target": 0.5,
"layers_to_optimize": None,
+ "dataset": None,
}
],
],
@@ -58,6 +60,7 @@ from mlia.core.typing import OutputFormat
"optimization_type": "clustering",
"optimization_target": 32,
"layers_to_optimize": None,
+ "dataset": None,
}
],
],
@@ -72,11 +75,13 @@ from mlia.core.typing import OutputFormat
"optimization_type": "pruning",
"optimization_target": 0.5,
"layers_to_optimize": None,
+ "dataset": None,
},
{
"optimization_type": "clustering",
"optimization_target": 32,
"layers_to_optimize": None,
+ "dataset": None,
},
],
],
@@ -91,6 +96,7 @@ from mlia.core.typing import OutputFormat
"optimization_type": "pruning",
"optimization_target": 0.4,
"layers_to_optimize": None,
+ "dataset": None,
}
],
],
@@ -117,6 +123,7 @@ from mlia.core.typing import OutputFormat
"optimization_type": "clustering",
"optimization_target": 32.2,
"layers_to_optimize": None,
+ "dataset": None,
}
],
],
diff --git a/tests/test_nn_tensorflow_optimizations_select.py b/tests/test_nn_select.py
index f5ba6f0..31628d2 100644
--- a/tests/test_nn_tensorflow_optimizations_select.py
+++ b/tests/test_nn_select.py
@@ -11,13 +11,13 @@ import pytest
import tensorflow as tf
from mlia.core.errors import ConfigurationError
+from mlia.nn.select import get_optimizer
+from mlia.nn.select import MultiStageOptimizer
+from mlia.nn.select import OptimizationSettings
from mlia.nn.tensorflow.optimizations.clustering import Clusterer
from mlia.nn.tensorflow.optimizations.clustering import ClusteringConfiguration
from mlia.nn.tensorflow.optimizations.pruning import Pruner
from mlia.nn.tensorflow.optimizations.pruning import PruningConfiguration
-from mlia.nn.tensorflow.optimizations.select import get_optimizer
-from mlia.nn.tensorflow.optimizations.select import MultiStageOptimizer
-from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
@pytest.mark.parametrize(
diff --git a/tests/test_target_ethos_u_advice_generation.py b/tests/test_target_ethos_u_advice_generation.py
index 772fc56..ac4e5e9 100644
--- a/tests/test_target_ethos_u_advice_generation.py
+++ b/tests/test_target_ethos_u_advice_generation.py
@@ -12,7 +12,7 @@ from mlia.core.common import DataItem
from mlia.core.context import ExecutionContext
from mlia.core.helpers import ActionResolver
from mlia.core.helpers import APIActionResolver
-from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+from mlia.nn.select import OptimizationSettings
from mlia.target.ethos_u.advice_generation import EthosUAdviceProducer
from mlia.target.ethos_u.advice_generation import EthosUStaticAdviceProducer
from mlia.target.ethos_u.data_analysis import AllOperatorsSupportedOnNPU
diff --git a/tests/test_target_ethos_u_advisor.py b/tests/test_target_ethos_u_advisor.py
index 11aefc7..20131d2 100644
--- a/tests/test_target_ethos_u_advisor.py
+++ b/tests/test_target_ethos_u_advisor.py
@@ -1,7 +1,11 @@
# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for Ethos-U MLIA module."""
+from __future__ import annotations
+
+from contextlib import ExitStack as does_not_raise
from pathlib import Path
+from typing import Any
import pytest
@@ -16,16 +20,53 @@ def test_advisor_metadata() -> None:
assert EthosUInferenceAdvisor.name() == "ethos_u_inference_advisor"
-def test_unsupported_advice_categories(tmp_path: Path, test_tflite_model: Path) -> None:
+@pytest.mark.parametrize(
+ "optimization_targets, expected_error",
+ [
+ [
+ [
+ {
+ "optimization_type": "pruning",
+ "optimization_target": 0.5,
+ "layers_to_optimize": None,
+ }
+ ],
+ pytest.raises(
+ Exception,
+ match="Only 'rewrite' is supported for TensorFlow Lite files.",
+ ),
+ ],
+ [
+ [
+ {
+ "optimization_type": "rewrite",
+ "optimization_target": "fully_connected",
+ "layers_to_optimize": [
+ "MobileNet/avg_pool/AvgPool",
+ "MobileNet/fc1/BiasAdd",
+ ],
+ }
+ ],
+ does_not_raise(),
+ ],
+ ],
+)
+def test_unsupported_advice_categories(
+ tmp_path: Path,
+ test_tflite_model: Path,
+ optimization_targets: list[dict[str, Any]],
+ expected_error: Any,
+) -> None:
"""Test that advisor should throw an exception for unsupported categories."""
- with pytest.raises(
- Exception, match="Optimizations are not supported for TensorFlow Lite files."
- ):
+ with expected_error:
ctx = ExecutionContext(
output_dir=tmp_path, advice_category={AdviceCategory.OPTIMIZATION}
)
advisor = configure_and_get_ethosu_advisor(
- ctx, "ethos-u55-256", str(test_tflite_model)
+ ctx,
+ "ethos-u55-256",
+ str(test_tflite_model),
+ optimization_targets=optimization_targets,
)
advisor.configure(ctx)
diff --git a/tests/test_target_ethos_u_data_analysis.py b/tests/test_target_ethos_u_data_analysis.py
index 80f0603..713e8ef 100644
--- a/tests/test_target_ethos_u_data_analysis.py
+++ b/tests/test_target_ethos_u_data_analysis.py
@@ -12,7 +12,7 @@ from mlia.backend.vela.compat import Operator
from mlia.backend.vela.compat import Operators
from mlia.core.common import DataItem
from mlia.core.data_analysis import Fact
-from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+from mlia.nn.select import OptimizationSettings
from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo
from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityStatus
from mlia.nn.tensorflow.tflite_compat import TFLiteConversionError
diff --git a/tests/test_target_ethos_u_data_collection.py b/tests/test_target_ethos_u_data_collection.py
index fd824ae..6244f8b 100644
--- a/tests/test_target_ethos_u_data_collection.py
+++ b/tests/test_target_ethos_u_data_collection.py
@@ -10,7 +10,7 @@ from mlia.backend.vela.compat import Operators
from mlia.core.context import Context
from mlia.core.data_collection import DataCollector
from mlia.core.errors import FunctionalityNotSupportedError
-from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+from mlia.nn.select import OptimizationSettings
from mlia.target.ethos_u.config import EthosUConfiguration
from mlia.target.ethos_u.data_collection import EthosUOperatorCompatibility
from mlia.target.ethos_u.data_collection import EthosUOptimizationPerformance