diff options
author | Ruomei Yan <ruomei.yan@arm.com> | 2023-04-20 09:51:20 +0100 |
---|---|---|
committer | Benjamin Klimczak <benjamin.klimczak@arm.com> | 2023-10-11 15:44:51 +0100 |
commit | f3e6597dd50ec70f043d692b773f2d9fd31519ae (patch) | |
tree | 322ccb75e0cc594c57308288cae333a72401979e /src/mlia/nn/rewrite/core/rewrite.py | |
parent | 867f37d643e66c0223457c28f5345f2f21db97f2 (diff) | |
download | mlia-f3e6597dd50ec70f043d692b773f2d9fd31519ae.tar.gz |
Implement first rewrite (proof of concept)
* Define replacement function fully_connected layer
* Define RewriteConfiguration and Rewriter to integrate
rewrite module into mlia optimize command
* Fix a bug in the ethos_u/data_collection.py file
* Fix a bug in join.py
* Remove diff_stats and use diff instead, added related
changes around this to ensure e2e tests passing
* Add unit tests for all changes
* Fix bug in diff_stats function
* The bug was caused by a dividing by numpy array
of all zeros. The previous way of handling it
did not consider the all zeros case but only
dealt with partially zeros
* unit tests added.
* Fix the bug in rewrite/core/graph_edit/join.py
* Remove the possibility of passing None to append_relabel
function because it is immutable
* The bug happened when empty dictionary was passed in the
append_relabel function and the function overwrites the
reference of operator_map which caused the dictionary
was not updated after the function call
Resolves: MLIA-749, MLIA-864, MLIA-866
Change-Id: I1ab426996232f182345e6e98033d5dcb32aea08c
Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com>
Diffstat (limited to 'src/mlia/nn/rewrite/core/rewrite.py')
-rw-r--r-- | src/mlia/nn/rewrite/core/rewrite.py | 69 |
1 files changed, 69 insertions, 0 deletions
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py index ab34b47..0d182df 100644 --- a/src/mlia/nn/rewrite/core/rewrite.py +++ b/src/mlia/nn/rewrite/core/rewrite.py @@ -3,11 +3,19 @@ """Contains class Rewriter to replace a subgraph/layer of a model.""" from __future__ import annotations +import importlib +import tempfile from dataclasses import dataclass from pathlib import Path +from typing import Any +from mlia.core.errors import ConfigurationError from mlia.nn.common import Optimizer from mlia.nn.common import OptimizerConfiguration +from mlia.nn.rewrite.core.train import eval_in_dir +from mlia.nn.rewrite.core.train import join_in_dir +from mlia.nn.rewrite.core.train import train +from mlia.nn.rewrite.core.train import train_in_dir from mlia.nn.tensorflow.config import TFLiteModel @@ -33,10 +41,71 @@ class Rewriter(Optimizer): """Init Rewriter instance.""" self.model = TFLiteModel(tflite_model_path) self.optimizer_configuration = optimizer_configuration + self.train_dir = "" def apply_optimization(self) -> None: """Apply the rewrite flow.""" + def get_function(arg: str) -> Any: + module_name = ".".join(arg.split(".")[:-1]) + fn_name = arg.split(".")[-1] + module = importlib.import_module(module_name) + return getattr(module, fn_name) + + if self.optimizer_configuration.optimization_target == "fully_connected": + replace_function = "mlia.nn.rewrite.library.fc_layer.get_keras_model" + else: + raise ConfigurationError( + "Only fully_connected replacement is supported in rewrite module." + ) + + replace_fn = get_function(replace_function) + + augmentation_preset = (None, None) + use_unmodified_model = True + tflite_model = self.model.model_path + tfrecord = str(self.optimizer_configuration.dataset) + + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_output = Path(tmp_dir, "output.tflite") + + if self.train_dir: + tmp_new = Path(tmp_dir, "new.tflite") + new_part = train_in_dir( + train_dir=self.train_dir, + baseline_dir=None, + output_filename=tmp_new, + replace_fn=replace_fn, + augmentations=augmentation_preset, + steps=32, + learning_rate=1e-3, + batch_size=1, + verbose=True, + show_progress=True, + ) + eval_in_dir(self.train_dir, new_part[0]) + join_in_dir(self.train_dir, new_part[0], str(tmp_output)) + else: + if not self.optimizer_configuration.layers_to_optimize: + raise ConfigurationError( + "Input and output tensor names need to be set for rewrite." + ) + train( + source_model=tflite_model, + unmodified_model=tflite_model if use_unmodified_model else None, + output_model=str(tmp_output), + input_tfrec=str(tfrecord), + replace_fn=replace_fn, + input_tensors=[self.optimizer_configuration.layers_to_optimize[0]], + output_tensors=[self.optimizer_configuration.layers_to_optimize[1]], + augment=augmentation_preset, + steps=32, + learning_rate=1e-3, + batch_size=1, + verbose=True, + show_progress=True, + ) + def get_model(self) -> TFLiteModel: """Return optimized model.""" return self.model |