aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/cli/options.py
diff options
context:
space:
mode:
authorRuomei Yan <ruomei.yan@arm.com>2023-02-20 15:32:54 +0000
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2023-10-11 15:42:28 +0100
commit446c379c92e15ad8f24ed0db853dd0fc9c271151 (patch)
treefb9e2b20fba15d3aa44054eb76d76fbdb1459006 /src/mlia/cli/options.py
parentf0b8ed75fed9dc69ab1f6313339f9f7e38bfc725 (diff)
downloadmlia-446c379c92e15ad8f24ed0db853dd0fc9c271151.tar.gz
Add a CLI component to enable rewrites
* Add flags for rewrite (--rewrite, --rewrite-start, --rewrite-end, --rewrite-target) * Refactor CLI interfaces to accept tflite models with optimize for rewrite, keras models with optimize for clustering and pruning * Refactor and move common.py and select.py out of the folder nn/tensorflow/optimizations * Add file nn/rewrite/core/rewrite.py as placeholder * Update/add unit tests * Refactor OptimizeModel in ethos_u/data_collection.py for accepting tflite model case * Extend the logic so that if "--rewrite" is specified, we don't add pruning to also accept TFLite models. * Update README.md Resolves: MLIA-750, MLIA-854, MLIA-865 Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com> Change-Id: I67d85f71fa253d2bad4efe304ad8225970b9622c
Diffstat (limited to 'src/mlia/cli/options.py')
-rw-r--r--src/mlia/cli/options.py64
1 files changed, 62 insertions, 2 deletions
diff --git a/src/mlia/cli/options.py b/src/mlia/cli/options.py
index fe177eb..7b3b373 100644
--- a/src/mlia/cli/options.py
+++ b/src/mlia/cli/options.py
@@ -12,6 +12,7 @@ from typing import Sequence
from mlia.backend.corstone import is_corstone_backend
from mlia.backend.manager import get_available_backends
from mlia.core.common import AdviceCategory
+from mlia.core.errors import ConfigurationError
from mlia.core.typing import OutputFormat
from mlia.target.registry import builtin_profile_names
from mlia.target.registry import registry as target_registry
@@ -90,6 +91,10 @@ def add_multi_optimization_options(parser: argparse.ArgumentParser) -> None:
)
multi_optimization_group.add_argument(
+ "--rewrite", action="store_true", help="Apply rewrite optimization."
+ )
+
+ multi_optimization_group.add_argument(
"--pruning-target",
type=float,
help="Sparsity to be reached during optimization "
@@ -103,6 +108,24 @@ def add_multi_optimization_options(parser: argparse.ArgumentParser) -> None:
f"(default: {DEFAULT_CLUSTERING_TARGET})",
)
+ multi_optimization_group.add_argument(
+ "--rewrite-target",
+ type=str,
+ help="Type of rewrite to apply to the subgraph/layer.",
+ )
+
+ multi_optimization_group.add_argument(
+ "--rewrite-start",
+ type=str,
+ help="Starting node in the graph of the subgraph to be rewritten.",
+ )
+
+ multi_optimization_group.add_argument(
+ "--rewrite-end",
+ type=str,
+ help="Ending node in the graph of the subgraph to be rewritten.",
+ )
+
def add_model_options(parser: argparse.ArgumentParser) -> None:
"""Add model specific options."""
@@ -131,6 +154,16 @@ def add_debug_options(parser: argparse.ArgumentParser) -> None:
)
+def add_dataset_options(parser: argparse.ArgumentParser) -> None:
+ """Addd dataset options."""
+ dataset_group = parser.add_argument_group("dataset options")
+ dataset_group.add_argument(
+ "--dataset",
+ type=Path,
+ help="The path of input tfrec file",
+ )
+
+
def add_keras_model_options(parser: argparse.ArgumentParser) -> None:
"""Add model specific options."""
model_group = parser.add_argument_group("Keras model options")
@@ -239,12 +272,17 @@ def add_output_directory(parser: argparse.ArgumentParser) -> None:
)
-def parse_optimization_parameters(
+def parse_optimization_parameters( # pylint: disable=too-many-arguments
pruning: bool = False,
clustering: bool = False,
pruning_target: float | None = None,
clustering_target: int | None = None,
+ rewrite: bool | None = False,
+ rewrite_target: str | None = None,
+ rewrite_start: str | None = None,
+ rewrite_end: str | None = None,
layers_to_optimize: list[str] | None = None,
+ dataset: Path | None = None,
) -> list[dict[str, Any]]:
"""Parse provided optimization parameters."""
opt_types = []
@@ -263,7 +301,14 @@ def parse_optimization_parameters(
if not clustering_target:
clustering_target = DEFAULT_CLUSTERING_TARGET
- if (pruning is False and clustering is False) or pruning:
+ if rewrite:
+ if not rewrite_target or not rewrite_start or not rewrite_end:
+ raise ConfigurationError(
+ "To perform rewrite, rewrite-target, rewrite-start and "
+ "rewrite-end must be set."
+ )
+
+ if not any((pruning, clustering, rewrite)) or pruning:
opt_types.append("pruning")
opt_targets.append(pruning_target)
@@ -276,10 +321,25 @@ def parse_optimization_parameters(
"optimization_type": opt_type.strip(),
"optimization_target": float(opt_target),
"layers_to_optimize": layers_to_optimize,
+ "dataset": dataset,
}
for opt_type, opt_target in zip(opt_types, opt_targets)
]
+ if rewrite:
+ if rewrite_target not in ["remove", "fully_connected"]:
+ raise ConfigurationError(
+ "Currently only remove and fully_connected are supported."
+ )
+ optimizer_params.append(
+ {
+ "optimization_type": "rewrite",
+ "optimization_target": rewrite_target,
+ "layers_to_optimize": [rewrite_start, rewrite_end],
+ "dataset": dataset,
+ }
+ )
+
return optimizer_params