aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/cli
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/cli')
-rw-r--r--src/mlia/cli/commands.py12
-rw-r--r--src/mlia/cli/helpers.py8
-rw-r--r--src/mlia/cli/main.py2
-rw-r--r--src/mlia/cli/options.py64
4 files changed, 81 insertions, 5 deletions
diff --git a/src/mlia/cli/commands.py b/src/mlia/cli/commands.py
index 1f339ee..7af41d9 100644
--- a/src/mlia/cli/commands.py
+++ b/src/mlia/cli/commands.py
@@ -96,7 +96,7 @@ def check(
)
-def optimize( # pylint: disable=too-many-arguments
+def optimize( # pylint: disable=too-many-locals,too-many-arguments
ctx: ExecutionContext,
target_profile: str,
model: str,
@@ -104,8 +104,13 @@ def optimize( # pylint: disable=too-many-arguments
clustering: bool,
pruning_target: float | None,
clustering_target: int | None,
+ rewrite: bool | None = None,
+ rewrite_target: str | None = None,
+ rewrite_start: str | None = None,
+ rewrite_end: str | None = None,
layers_to_optimize: list[str] | None = None,
backend: list[str] | None = None,
+ dataset: Path | None = None,
) -> None:
"""Show the performance improvements (if any) after applying the optimizations.
@@ -145,7 +150,12 @@ def optimize( # pylint: disable=too-many-arguments
clustering,
pruning_target,
clustering_target,
+ rewrite,
+ rewrite_target,
+ rewrite_start,
+ rewrite_end,
layers_to_optimize,
+ dataset,
)
)
diff --git a/src/mlia/cli/helpers.py b/src/mlia/cli/helpers.py
index abc6df0..824db1b 100644
--- a/src/mlia/cli/helpers.py
+++ b/src/mlia/cli/helpers.py
@@ -10,7 +10,7 @@ from typing import cast
from mlia.cli.options import get_target_profile_opts
from mlia.core.helpers import ActionResolver
-from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+from mlia.nn.select import OptimizationSettings
from mlia.nn.tensorflow.utils import is_keras_model
from mlia.target.config import get_builtin_profile_path
from mlia.target.config import is_builtin_profile
@@ -47,7 +47,11 @@ class CLIActionResolver(ActionResolver):
) -> list[str]:
"""Return specific optimization command description."""
opt_types = " ".join("--" + opt.optimization_type for opt in opt_settings)
- opt_targs_strings = ["--pruning-target", "--clustering-target"]
+ opt_targs_strings = [
+ "--pruning-target",
+ "--clustering-target",
+ "--rewrite-target",
+ ]
opt_targs = ",".join(
f"{opt_targs_strings[i]} {opt.optimization_target}"
for i, opt in enumerate(opt_settings)
diff --git a/src/mlia/cli/main.py b/src/mlia/cli/main.py
index 88258d5..9e1b7cd 100644
--- a/src/mlia/cli/main.py
+++ b/src/mlia/cli/main.py
@@ -23,6 +23,7 @@ from mlia.cli.options import add_backend_install_options
from mlia.cli.options import add_backend_options
from mlia.cli.options import add_backend_uninstall_options
from mlia.cli.options import add_check_category_options
+from mlia.cli.options import add_dataset_options
from mlia.cli.options import add_debug_options
from mlia.cli.options import add_keras_model_options
from mlia.cli.options import add_model_options
@@ -89,6 +90,7 @@ def get_commands() -> list[CommandInfo]:
add_multi_optimization_options,
add_output_options,
add_debug_options,
+ add_dataset_options,
],
),
]
diff --git a/src/mlia/cli/options.py b/src/mlia/cli/options.py
index fe177eb..7b3b373 100644
--- a/src/mlia/cli/options.py
+++ b/src/mlia/cli/options.py
@@ -12,6 +12,7 @@ from typing import Sequence
from mlia.backend.corstone import is_corstone_backend
from mlia.backend.manager import get_available_backends
from mlia.core.common import AdviceCategory
+from mlia.core.errors import ConfigurationError
from mlia.core.typing import OutputFormat
from mlia.target.registry import builtin_profile_names
from mlia.target.registry import registry as target_registry
@@ -90,6 +91,10 @@ def add_multi_optimization_options(parser: argparse.ArgumentParser) -> None:
)
multi_optimization_group.add_argument(
+ "--rewrite", action="store_true", help="Apply rewrite optimization."
+ )
+
+ multi_optimization_group.add_argument(
"--pruning-target",
type=float,
help="Sparsity to be reached during optimization "
@@ -103,6 +108,24 @@ def add_multi_optimization_options(parser: argparse.ArgumentParser) -> None:
f"(default: {DEFAULT_CLUSTERING_TARGET})",
)
+ multi_optimization_group.add_argument(
+ "--rewrite-target",
+ type=str,
+ help="Type of rewrite to apply to the subgraph/layer.",
+ )
+
+ multi_optimization_group.add_argument(
+ "--rewrite-start",
+ type=str,
+ help="Starting node in the graph of the subgraph to be rewritten.",
+ )
+
+ multi_optimization_group.add_argument(
+ "--rewrite-end",
+ type=str,
+ help="Ending node in the graph of the subgraph to be rewritten.",
+ )
+
def add_model_options(parser: argparse.ArgumentParser) -> None:
"""Add model specific options."""
@@ -131,6 +154,16 @@ def add_debug_options(parser: argparse.ArgumentParser) -> None:
)
+def add_dataset_options(parser: argparse.ArgumentParser) -> None:
+ """Addd dataset options."""
+ dataset_group = parser.add_argument_group("dataset options")
+ dataset_group.add_argument(
+ "--dataset",
+ type=Path,
+ help="The path of input tfrec file",
+ )
+
+
def add_keras_model_options(parser: argparse.ArgumentParser) -> None:
"""Add model specific options."""
model_group = parser.add_argument_group("Keras model options")
@@ -239,12 +272,17 @@ def add_output_directory(parser: argparse.ArgumentParser) -> None:
)
-def parse_optimization_parameters(
+def parse_optimization_parameters( # pylint: disable=too-many-arguments
pruning: bool = False,
clustering: bool = False,
pruning_target: float | None = None,
clustering_target: int | None = None,
+ rewrite: bool | None = False,
+ rewrite_target: str | None = None,
+ rewrite_start: str | None = None,
+ rewrite_end: str | None = None,
layers_to_optimize: list[str] | None = None,
+ dataset: Path | None = None,
) -> list[dict[str, Any]]:
"""Parse provided optimization parameters."""
opt_types = []
@@ -263,7 +301,14 @@ def parse_optimization_parameters(
if not clustering_target:
clustering_target = DEFAULT_CLUSTERING_TARGET
- if (pruning is False and clustering is False) or pruning:
+ if rewrite:
+ if not rewrite_target or not rewrite_start or not rewrite_end:
+ raise ConfigurationError(
+ "To perform rewrite, rewrite-target, rewrite-start and "
+ "rewrite-end must be set."
+ )
+
+ if not any((pruning, clustering, rewrite)) or pruning:
opt_types.append("pruning")
opt_targets.append(pruning_target)
@@ -276,10 +321,25 @@ def parse_optimization_parameters(
"optimization_type": opt_type.strip(),
"optimization_target": float(opt_target),
"layers_to_optimize": layers_to_optimize,
+ "dataset": dataset,
}
for opt_type, opt_target in zip(opt_types, opt_targets)
]
+ if rewrite:
+ if rewrite_target not in ["remove", "fully_connected"]:
+ raise ConfigurationError(
+ "Currently only remove and fully_connected are supported."
+ )
+ optimizer_params.append(
+ {
+ "optimization_type": "rewrite",
+ "optimization_target": rewrite_target,
+ "layers_to_optimize": [rewrite_start, rewrite_end],
+ "dataset": dataset,
+ }
+ )
+
return optimizer_params