aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/cli/options.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/cli/options.py')
-rw-r--r--src/mlia/cli/options.py64
1 files changed, 62 insertions, 2 deletions
diff --git a/src/mlia/cli/options.py b/src/mlia/cli/options.py
index fe177eb..7b3b373 100644
--- a/src/mlia/cli/options.py
+++ b/src/mlia/cli/options.py
@@ -12,6 +12,7 @@ from typing import Sequence
from mlia.backend.corstone import is_corstone_backend
from mlia.backend.manager import get_available_backends
from mlia.core.common import AdviceCategory
+from mlia.core.errors import ConfigurationError
from mlia.core.typing import OutputFormat
from mlia.target.registry import builtin_profile_names
from mlia.target.registry import registry as target_registry
@@ -90,6 +91,10 @@ def add_multi_optimization_options(parser: argparse.ArgumentParser) -> None:
)
multi_optimization_group.add_argument(
+ "--rewrite", action="store_true", help="Apply rewrite optimization."
+ )
+
+ multi_optimization_group.add_argument(
"--pruning-target",
type=float,
help="Sparsity to be reached during optimization "
@@ -103,6 +108,24 @@ def add_multi_optimization_options(parser: argparse.ArgumentParser) -> None:
f"(default: {DEFAULT_CLUSTERING_TARGET})",
)
+ multi_optimization_group.add_argument(
+ "--rewrite-target",
+ type=str,
+ help="Type of rewrite to apply to the subgraph/layer.",
+ )
+
+ multi_optimization_group.add_argument(
+ "--rewrite-start",
+ type=str,
+ help="Starting node in the graph of the subgraph to be rewritten.",
+ )
+
+ multi_optimization_group.add_argument(
+ "--rewrite-end",
+ type=str,
+ help="Ending node in the graph of the subgraph to be rewritten.",
+ )
+
def add_model_options(parser: argparse.ArgumentParser) -> None:
"""Add model specific options."""
@@ -131,6 +154,16 @@ def add_debug_options(parser: argparse.ArgumentParser) -> None:
)
+def add_dataset_options(parser: argparse.ArgumentParser) -> None:
+ """Addd dataset options."""
+ dataset_group = parser.add_argument_group("dataset options")
+ dataset_group.add_argument(
+ "--dataset",
+ type=Path,
+ help="The path of input tfrec file",
+ )
+
+
def add_keras_model_options(parser: argparse.ArgumentParser) -> None:
"""Add model specific options."""
model_group = parser.add_argument_group("Keras model options")
@@ -239,12 +272,17 @@ def add_output_directory(parser: argparse.ArgumentParser) -> None:
)
-def parse_optimization_parameters(
+def parse_optimization_parameters( # pylint: disable=too-many-arguments
pruning: bool = False,
clustering: bool = False,
pruning_target: float | None = None,
clustering_target: int | None = None,
+ rewrite: bool | None = False,
+ rewrite_target: str | None = None,
+ rewrite_start: str | None = None,
+ rewrite_end: str | None = None,
layers_to_optimize: list[str] | None = None,
+ dataset: Path | None = None,
) -> list[dict[str, Any]]:
"""Parse provided optimization parameters."""
opt_types = []
@@ -263,7 +301,14 @@ def parse_optimization_parameters(
if not clustering_target:
clustering_target = DEFAULT_CLUSTERING_TARGET
- if (pruning is False and clustering is False) or pruning:
+ if rewrite:
+ if not rewrite_target or not rewrite_start or not rewrite_end:
+ raise ConfigurationError(
+ "To perform rewrite, rewrite-target, rewrite-start and "
+ "rewrite-end must be set."
+ )
+
+ if not any((pruning, clustering, rewrite)) or pruning:
opt_types.append("pruning")
opt_targets.append(pruning_target)
@@ -276,10 +321,25 @@ def parse_optimization_parameters(
"optimization_type": opt_type.strip(),
"optimization_target": float(opt_target),
"layers_to_optimize": layers_to_optimize,
+ "dataset": dataset,
}
for opt_type, opt_target in zip(opt_types, opt_targets)
]
+ if rewrite:
+ if rewrite_target not in ["remove", "fully_connected"]:
+ raise ConfigurationError(
+ "Currently only remove and fully_connected are supported."
+ )
+ optimizer_params.append(
+ {
+ "optimization_type": "rewrite",
+ "optimization_target": rewrite_target,
+ "layers_to_optimize": [rewrite_start, rewrite_end],
+ "dataset": dataset,
+ }
+ )
+
return optimizer_params