aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/core/rewrite.py
diff options
context:
space:
mode:
authorRuomei Yan <ruomei.yan@arm.com>2023-04-20 09:51:20 +0100
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2023-10-11 15:44:51 +0100
commitf3e6597dd50ec70f043d692b773f2d9fd31519ae (patch)
tree322ccb75e0cc594c57308288cae333a72401979e /src/mlia/nn/rewrite/core/rewrite.py
parent867f37d643e66c0223457c28f5345f2f21db97f2 (diff)
downloadmlia-f3e6597dd50ec70f043d692b773f2d9fd31519ae.tar.gz
Implement first rewrite (proof of concept)
* Define replacement function fully_connected layer * Define RewriteConfiguration and Rewriter to integrate rewrite module into mlia optimize command * Fix a bug in the ethos_u/data_collection.py file * Fix a bug in join.py * Remove diff_stats and use diff instead, added related changes around this to ensure e2e tests passing * Add unit tests for all changes * Fix bug in diff_stats function * The bug was caused by a dividing by numpy array of all zeros. The previous way of handling it did not consider the all zeros case but only dealt with partially zeros * unit tests added. * Fix the bug in rewrite/core/graph_edit/join.py * Remove the possibility of passing None to append_relabel function because it is immutable * The bug happened when empty dictionary was passed in the append_relabel function and the function overwrites the reference of operator_map which caused the dictionary was not updated after the function call Resolves: MLIA-749, MLIA-864, MLIA-866 Change-Id: I1ab426996232f182345e6e98033d5dcb32aea08c Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com>
Diffstat (limited to 'src/mlia/nn/rewrite/core/rewrite.py')
-rw-r--r--src/mlia/nn/rewrite/core/rewrite.py69
1 files changed, 69 insertions, 0 deletions
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py
index ab34b47..0d182df 100644
--- a/src/mlia/nn/rewrite/core/rewrite.py
+++ b/src/mlia/nn/rewrite/core/rewrite.py
@@ -3,11 +3,19 @@
"""Contains class Rewriter to replace a subgraph/layer of a model."""
from __future__ import annotations
+import importlib
+import tempfile
from dataclasses import dataclass
from pathlib import Path
+from typing import Any
+from mlia.core.errors import ConfigurationError
from mlia.nn.common import Optimizer
from mlia.nn.common import OptimizerConfiguration
+from mlia.nn.rewrite.core.train import eval_in_dir
+from mlia.nn.rewrite.core.train import join_in_dir
+from mlia.nn.rewrite.core.train import train
+from mlia.nn.rewrite.core.train import train_in_dir
from mlia.nn.tensorflow.config import TFLiteModel
@@ -33,10 +41,71 @@ class Rewriter(Optimizer):
"""Init Rewriter instance."""
self.model = TFLiteModel(tflite_model_path)
self.optimizer_configuration = optimizer_configuration
+ self.train_dir = ""
def apply_optimization(self) -> None:
"""Apply the rewrite flow."""
+ def get_function(arg: str) -> Any:
+ module_name = ".".join(arg.split(".")[:-1])
+ fn_name = arg.split(".")[-1]
+ module = importlib.import_module(module_name)
+ return getattr(module, fn_name)
+
+ if self.optimizer_configuration.optimization_target == "fully_connected":
+ replace_function = "mlia.nn.rewrite.library.fc_layer.get_keras_model"
+ else:
+ raise ConfigurationError(
+ "Only fully_connected replacement is supported in rewrite module."
+ )
+
+ replace_fn = get_function(replace_function)
+
+ augmentation_preset = (None, None)
+ use_unmodified_model = True
+ tflite_model = self.model.model_path
+ tfrecord = str(self.optimizer_configuration.dataset)
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ tmp_output = Path(tmp_dir, "output.tflite")
+
+ if self.train_dir:
+ tmp_new = Path(tmp_dir, "new.tflite")
+ new_part = train_in_dir(
+ train_dir=self.train_dir,
+ baseline_dir=None,
+ output_filename=tmp_new,
+ replace_fn=replace_fn,
+ augmentations=augmentation_preset,
+ steps=32,
+ learning_rate=1e-3,
+ batch_size=1,
+ verbose=True,
+ show_progress=True,
+ )
+ eval_in_dir(self.train_dir, new_part[0])
+ join_in_dir(self.train_dir, new_part[0], str(tmp_output))
+ else:
+ if not self.optimizer_configuration.layers_to_optimize:
+ raise ConfigurationError(
+ "Input and output tensor names need to be set for rewrite."
+ )
+ train(
+ source_model=tflite_model,
+ unmodified_model=tflite_model if use_unmodified_model else None,
+ output_model=str(tmp_output),
+ input_tfrec=str(tfrecord),
+ replace_fn=replace_fn,
+ input_tensors=[self.optimizer_configuration.layers_to_optimize[0]],
+ output_tensors=[self.optimizer_configuration.layers_to_optimize[1]],
+ augment=augmentation_preset,
+ steps=32,
+ learning_rate=1e-3,
+ batch_size=1,
+ verbose=True,
+ show_progress=True,
+ )
+
def get_model(self) -> TFLiteModel:
"""Return optimized model."""
return self.model