Diffstat (limited to 'src/mlia/nn')
-rw-r--r--  src/mlia/nn/common.py (renamed from src/mlia/nn/tensorflow/optimizations/common.py)  |  6
-rw-r--r--  src/mlia/nn/rewrite/core/rewrite.py                                                   | 45
-rw-r--r--  src/mlia/nn/select.py (renamed from src/mlia/nn/tensorflow/optimizations/select.py)  | 55
-rw-r--r--  src/mlia/nn/tensorflow/optimizations/clustering.py                                    |  6
-rw-r--r--  src/mlia/nn/tensorflow/optimizations/pruning.py                                       |  4
5 files changed, 98 insertions(+), 18 deletions(-)
diff --git a/src/mlia/nn/tensorflow/optimizations/common.py b/src/mlia/nn/common.py
index 1dce0b2..5a8d685 100644
--- a/src/mlia/nn/tensorflow/optimizations/common.py
+++ b/src/mlia/nn/common.py
@@ -1,11 +1,11 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Common items for the optimizations module."""
from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
-import tensorflow as tf
+from mlia.nn.tensorflow.config import ModelConfiguration
@dataclass
@@ -17,7 +17,7 @@ class Optimizer(ABC):
"""Abstract class for the optimizer."""
@abstractmethod
- def get_model(self) -> tf.keras.Model:
+ def get_model(self) -> ModelConfiguration:
"""Abstract method to return the model instance from the optimizer."""
@abstractmethod
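Taken together, the two hunks above relocate the optimizer base interface from mlia.nn.tensorflow.optimizations.common to mlia.nn.common and decouple it from TensorFlow: get_model now returns a ModelConfiguration rather than a tf.keras.Model. A minimal sketch of the resulting module, reconstructed from these hunks (the apply_optimization and optimization_config signatures are inferred from the Rewriter subclass added below, not shown in this diff):

# Sketch of src/mlia/nn/common.py after this patch; reconstructed, not verbatim.
from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass

from mlia.nn.tensorflow.config import ModelConfiguration


@dataclass
class OptimizerConfiguration:
    """Base class for optimizer configurations."""


class Optimizer(ABC):
    """Abstract class for the optimizer."""

    @abstractmethod
    def get_model(self) -> ModelConfiguration:
        """Abstract method to return the model instance from the optimizer."""

    @abstractmethod
    def apply_optimization(self) -> None:
        """Abstract method to apply the optimization to the model."""

    @abstractmethod
    def optimization_config(self) -> str:
        """Abstract method to return string representation of the config."""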
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py
new file mode 100644
index 0000000..d4f61c5
--- /dev/null
+++ b/src/mlia/nn/rewrite/core/rewrite.py
@@ -0,0 +1,45 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Contains class Rewriter to replace a subgraph/layer of a model."""
+from __future__ import annotations
+
+from dataclasses import dataclass
+from pathlib import Path
+
+from mlia.nn.common import Optimizer
+from mlia.nn.common import OptimizerConfiguration
+from mlia.nn.tensorflow.config import TFLiteModel
+
+
+@dataclass
+class RewriteConfiguration(OptimizerConfiguration):
+    """Rewrite configuration."""
+
+    optimization_target: str
+    layers_to_optimize: list[str] | None = None
+    dataset: Path | None = None
+
+    def __str__(self) -> str:
+        """Return string representation of the configuration."""
+        return f"rewrite: {self.optimization_target}"
+
+
+class Rewriter(Optimizer):
+    """Rewriter class for basic rewrite flow."""
+
+    def __init__(
+        self, tflite_model_path: Path, optimizer_configuration: RewriteConfiguration
+    ):
+        """Init Rewriter instance."""
+        self.model = TFLiteModel(tflite_model_path)
+        self.optimizer_configuration = optimizer_configuration
+
+    def apply_optimization(self) -> None:
+        """Apply the rewrite flow."""
+
+    def get_model(self) -> TFLiteModel:
+        """Return optimized model."""
+        return self.model
+
+    def optimization_config(self) -> str:
+        """Optimization configurations."""
diff --git a/src/mlia/nn/tensorflow/optimizations/select.py b/src/mlia/nn/select.py
index a78df12..7a25e47 100644
--- a/src/mlia/nn/tensorflow/optimizations/select.py
+++ b/src/mlia/nn/select.py
@@ -4,16 +4,21 @@
from __future__ import annotations
import math
+from pathlib import Path
+from typing import Any
from typing import NamedTuple
import tensorflow as tf
from mlia.core.errors import ConfigurationError
+from mlia.nn.common import Optimizer
+from mlia.nn.common import OptimizerConfiguration
+from mlia.nn.rewrite.core.rewrite import RewriteConfiguration
+from mlia.nn.rewrite.core.rewrite import Rewriter
from mlia.nn.tensorflow.config import KerasModel
+from mlia.nn.tensorflow.config import TFLiteModel
from mlia.nn.tensorflow.optimizations.clustering import Clusterer
from mlia.nn.tensorflow.optimizations.clustering import ClusteringConfiguration
-from mlia.nn.tensorflow.optimizations.common import Optimizer
-from mlia.nn.tensorflow.optimizations.common import OptimizerConfiguration
from mlia.nn.tensorflow.optimizations.pruning import Pruner
from mlia.nn.tensorflow.optimizations.pruning import PruningConfiguration
from mlia.utils.types import is_list_of
@@ -25,11 +30,13 @@ class OptimizationSettings(NamedTuple):
    optimization_type: str
    optimization_target: int | float
    layers_to_optimize: list[str] | None
+    dataset: Path | None = None

    @staticmethod
    def create_from(
        optimizer_params: list[tuple[str, float]],
        layers_to_optimize: list[str] | None = None,
+        dataset: Path | None = None,
    ) -> list[OptimizationSettings]:
        """Create optimization settings from the provided parameters."""
        return [
@@ -37,6 +44,7 @@ class OptimizationSettings(NamedTuple):
                optimization_type=opt_type,
                optimization_target=opt_target,
                layers_to_optimize=layers_to_optimize,
+                dataset=dataset,
            )
            for opt_type, opt_target in optimizer_params
        ]
@@ -64,6 +72,14 @@ class OptimizationSettings(NamedTuple):
                self.optimization_type, next_target, self.layers_to_optimize
            )

+        if self.optimization_type == "rewrite":
+            return OptimizationSettings(
+                self.optimization_type,
+                self.optimization_target,
+                self.layers_to_optimize,
+                self.dataset,
+            )
+
        raise ValueError(f"Optimization type {self.optimization_type} is unknown.")
@@ -83,7 +99,7 @@ class MultiStageOptimizer(Optimizer):
"""Return string representation of the optimization config."""
return " - ".join(str(opt) for opt in self.optimizations)
- def get_model(self) -> tf.keras.Model:
+ def get_model(self) -> Any:
"""Return optimized model."""
return self.model
@@ -96,19 +112,25 @@ class MultiStageOptimizer(Optimizer):
def get_optimizer(
-    model: tf.keras.Model | KerasModel,
+    model: tf.keras.Model | KerasModel | TFLiteModel,
    config: OptimizerConfiguration | OptimizationSettings | list[OptimizationSettings],
) -> Optimizer:
    """Get optimizer for provided configuration."""
    if isinstance(model, KerasModel):
        model = model.get_keras_model()

+    if isinstance(model, TFLiteModel):
+        model = model.model_path
+
    if isinstance(config, PruningConfiguration):
        return Pruner(model, config)

    if isinstance(config, ClusteringConfiguration):
        return Clusterer(model, config)

+    if isinstance(config, RewriteConfiguration):
+        return Rewriter(model, config)  # type: ignore
+
    if isinstance(config, OptimizationSettings) or is_list_of(
        config, OptimizationSettings
    ):
@@ -118,18 +140,18 @@ def get_optimizer(
def _get_optimizer(
-    model: tf.keras.Model,
+    model: tf.keras.Model | Path,
    optimization_settings: OptimizationSettings | list[OptimizationSettings],
) -> Optimizer:
    if isinstance(optimization_settings, OptimizationSettings):
        optimization_settings = [optimization_settings]

    optimizer_configs = []
-    for opt_type, opt_target, layers_to_optimize in optimization_settings:
+    for opt_type, opt_target, layers_to_optimize, dataset in optimization_settings:
        _check_optimizer_params(opt_type, opt_target)

        opt_config = _get_optimizer_configuration(
-            opt_type, opt_target, layers_to_optimize
+            opt_type, opt_target, layers_to_optimize, dataset
        )
        optimizer_configs.append(opt_config)
@@ -141,15 +163,16 @@ def _get_optimizer(
def _get_optimizer_configuration(
    optimization_type: str,
-    optimization_target: int | float,
+    optimization_target: int | float | str,
    layers_to_optimize: list[str] | None = None,
+    dataset: Path | None = None,
) -> OptimizerConfiguration:
    """Get optimizer configuration for provided parameters."""
    _check_optimizer_params(optimization_type, optimization_target)

    opt_type = optimization_type.lower()
    if opt_type == "pruning":
-        return PruningConfiguration(optimization_target, layers_to_optimize)
+        return PruningConfiguration(float(optimization_target), layers_to_optimize)

    if opt_type == "clustering":
        # make sure an integer is given as clustering target
@@ -161,11 +184,23 @@ def _get_optimizer_configuration(
f"Optimization target provided: {optimization_target}"
)
+ if opt_type == "rewrite":
+ if isinstance(optimization_target, str):
+ return RewriteConfiguration( # type: ignore
+ str(optimization_target), layers_to_optimize, dataset
+ )
+
+ raise ConfigurationError(
+ "Optimization target should be a string indicating a"
+ "choice from rewrite library. "
+ f"Optimization target provided: {optimization_target}"
+ )
+
raise ConfigurationError(f"Unsupported optimization type: {optimization_type}")
def _check_optimizer_params(
- optimization_type: str, optimization_target: int | float
+ optimization_type: str, optimization_target: int | float | str
) -> None:
"""Check optimizer params."""
if not optimization_target:
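With these changes get_optimizer accepts TFLiteModel inputs and a RewriteConfiguration alongside the existing pruning and clustering paths, and OptimizationSettings gains an optional dataset field that is threaded through to the rewrite configuration. A sketch of the new dispatch (model and target values are assumptions for illustration):

# Illustrative dispatch through the updated select module; values are assumed.
from pathlib import Path

from mlia.nn.rewrite.core.rewrite import RewriteConfiguration
from mlia.nn.rewrite.core.rewrite import Rewriter
from mlia.nn.select import get_optimizer
from mlia.nn.tensorflow.config import TFLiteModel

model = TFLiteModel(Path("models/model.tflite"))    # hypothetical path
config = RewriteConfiguration(
    optimization_target="fully_connected",          # assumed rewrite name
    layers_to_optimize=None,
    dataset=Path("datasets/calibration.tfrec"),     # hypothetical dataset
)

# The TFLiteModel is unwrapped to its model_path, then the new
# RewriteConfiguration branch returns a Rewriter instead of a Pruner/Clusterer.
optimizer = get_optimizer(model, config)
assert isinstance(optimizer, Rewriter)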
diff --git a/src/mlia/nn/tensorflow/optimizations/clustering.py b/src/mlia/nn/tensorflow/optimizations/clustering.py
index 4aaa33e..f9018b3 100644
--- a/src/mlia/nn/tensorflow/optimizations/clustering.py
+++ b/src/mlia/nn/tensorflow/optimizations/clustering.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""
Contains class Clusterer that clusters unique weights per layer to a specified number.
@@ -18,8 +18,8 @@ from tensorflow_model_optimization.python.core.clustering.keras.experimental imp
    cluster as experimental_cluster,
)

-from mlia.nn.tensorflow.optimizations.common import Optimizer
-from mlia.nn.tensorflow.optimizations.common import OptimizerConfiguration
+from mlia.nn.common import Optimizer
+from mlia.nn.common import OptimizerConfiguration
@dataclass
diff --git a/src/mlia/nn/tensorflow/optimizations/pruning.py b/src/mlia/nn/tensorflow/optimizations/pruning.py
index 2d5ef0e..a30b301 100644
--- a/src/mlia/nn/tensorflow/optimizations/pruning.py
+++ b/src/mlia/nn/tensorflow/optimizations/pruning.py
@@ -24,8 +24,8 @@ from tensorflow_model_optimization.python.core.sparsity.keras import ( # pylint
    pruning_wrapper,
)

-from mlia.nn.tensorflow.optimizations.common import Optimizer
-from mlia.nn.tensorflow.optimizations.common import OptimizerConfiguration
+from mlia.nn.common import Optimizer
+from mlia.nn.common import OptimizerConfiguration
logger = logging.getLogger(__name__)