author     Ruomei Yan <ruomei.yan@arm.com>    2023-04-20 09:51:20 +0100
committer  Benjamin Klimczak <benjamin.klimczak@arm.com>    2023-10-11 15:44:51 +0100
commit     f3e6597dd50ec70f043d692b773f2d9fd31519ae (patch)
tree       322ccb75e0cc594c57308288cae333a72401979e
parent     867f37d643e66c0223457c28f5345f2f21db97f2 (diff)
download   mlia-f3e6597dd50ec70f043d692b773f2d9fd31519ae.tar.gz
Implement first rewrite (proof of concept)
* Define replacement function for the fully_connected layer
* Define RewriteConfiguration and Rewriter to integrate the rewrite module
  into the mlia optimize command
* Fix a bug in the ethos_u/data_collection.py file
* Fix a bug in join.py
* Remove diff_stats and use diff instead; add related changes around this
  to keep the e2e tests passing
* Add unit tests for all changes
* Fix a bug in the diff_stats function
  * The bug was caused by dividing by a numpy array of all zeros. The
    previous handling did not consider the all-zeros case and only dealt
    with partially zero arrays.
  * Unit tests added.
* Fix the bug in rewrite/core/graph_edit/join.py
  * Remove the possibility of passing None to the append_relabel function,
    because None is immutable and cannot be updated in place
  * The bug happened when an empty dictionary was passed to append_relabel:
    the function rebound the operator_map reference, so the caller's
    dictionary was not updated after the call (see the sketch below)

Resolves: MLIA-749, MLIA-864, MLIA-866

Change-Id: I1ab426996232f182345e6e98033d5dcb32aea08c
Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com>
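A minimal standalone sketch of the rebinding pitfall described in the last bullet; the function and variable names below are illustrative and not taken from the patch:

    # Rebinding a parameter inside a function does not update the caller's dict.
    def fill_map_buggy(operator_map=None):
        if not operator_map:
            operator_map = {}  # rebinds the local name; the caller's dict is untouched
        operator_map[0] = "tensor_0"
        return operator_map

    caller_map: dict = {}
    fill_map_buggy(caller_map)
    print(caller_map)  # {} -- the empty dict passed in was never updated

    # The fixed contract mutates the mapping in place, as append_relabel now does.
    def fill_map_fixed(operator_map: dict) -> None:
        if operator_map is None:
            raise ValueError("The input operator map cannot be None!")
        operator_map[0] = "tensor_0"

    fill_map_fixed(caller_map)
    print(caller_map)  # {0: 'tensor_0'}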
-rw-r--r--  src/mlia/nn/rewrite/core/graph_edit/diff.py      23
-rw-r--r--  src/mlia/nn/rewrite/core/graph_edit/join.py      14
-rw-r--r--  src/mlia/nn/rewrite/core/rewrite.py              69
-rw-r--r--  src/mlia/nn/rewrite/core/train.py                 1
-rw-r--r--  src/mlia/nn/rewrite/library/__init__.py           3
-rw-r--r--  src/mlia/nn/rewrite/library/fc_layer.py          18
-rw-r--r--  src/mlia/nn/select.py                             1
-rw-r--r--  src/mlia/target/ethos_u/data_collection.py        3
-rw-r--r--  tests/test_cli_commands.py                       17
-rw-r--r--  tests/test_nn_rewrite_core_graph_edit_diff.py    45
-rw-r--r--  tests/test_nn_rewrite_core_graph_edit_join.py    25
-rw-r--r--  tests/test_nn_rewrite_core_rewrite.py            55
12 files changed, 252 insertions, 22 deletions
diff --git a/src/mlia/nn/rewrite/core/graph_edit/diff.py b/src/mlia/nn/rewrite/core/graph_edit/diff.py
index 198e47e..7fa2a72 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/diff.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/diff.py
@@ -31,6 +31,21 @@ def add_total(name: str, key: str, values: list, totals: dict) -> None:
        totals[name][key] += values


+def _handle_zeros_in_denominator(denominator: np.ndarray) -> np.ndarray:
+    """Handle zeros in the denominator in nrmse to avoid dividing by zero(s)."""
+    denominator[denominator == 0.0] = 1.0
+    return denominator
+
+
+def calc_nrmse(rmse: dict, dataset1_var: dict) -> dict:
+    """Divide rmse by target standard deviation."""
+    nrmse = {
+        k: v / _handle_zeros_in_denominator(np.sqrt(dataset1_var[k]))
+        for k, v in rmse.items()
+    }
+    return nrmse
+
+
def diff_stats(
    file1: str | Path, file2: str | Path, per_tensor_and_channel: bool = False
) -> tuple:
@@ -80,14 +95,8 @@ def diff_stats(
    mse = per_tensor_mean("se")
    rmse = {k: np.sqrt(v) for k, v in mse.items()}
    dataset1_var = per_tensor_mean("dataset1_variance")
-    is_nonzero = {k: dataset1_var[k] > 0 for k in dataset1_var}
-    # Divide by target standard deviation to get the per-channel nrmse for each
-    # tensor where possible
-    nrmse = {
-        k: v[is_nonzero[k]] / np.sqrt(dataset1_var[k][is_nonzero[k]])
-        for k, v in rmse.items()
-    }
+    nrmse = calc_nrmse(rmse, dataset1_var)

    if per_tensor_and_channel:
        return mae, nrmse
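As a quick illustration of the new behaviour (a hedged sketch; the tensor name "conv1" and the values are made up), channels with zero variance now fall back to the raw RMSE instead of being dropped:

    import numpy as np
    from mlia.nn.rewrite.core.graph_edit.diff import calc_nrmse

    rmse = {"conv1": np.array([1.0, 2.0])}
    dataset1_var = {"conv1": np.array([4.0, 0.0])}  # second channel has zero variance

    nrmse = calc_nrmse(rmse, dataset1_var)
    print(nrmse["conv1"])  # [0.5 2. ] -- the zero-variance channel keeps its raw RMSE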
diff --git a/src/mlia/nn/rewrite/core/graph_edit/join.py b/src/mlia/nn/rewrite/core/graph_edit/join.py
index 14a7347..2530ec8 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/join.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/join.py
@@ -22,8 +22,8 @@ def join_models(
    input_src: str | Path,
    input_dst: str | Path,
    output_file: str | Path,
-    subgraph_src: SubGraphT = 0,
-    subgraph_dst: SubGraphT = 0,
+    subgraph_src: int = 0,
+    subgraph_dst: int = 0,
) -> None:
    """Join two models and save the result into a given model file path."""
    src_model = load(input_src)
@@ -150,12 +150,12 @@ def join_subgraphs(
    dst_subgraph.outputs = list(set(src_subgraph.outputs).union(dst_subgraph.outputs))


-def append_relabel(src: list, dst: list, operator_map: dict | None = None) -> dict:
-    """Return a map over relabeled tensors in a subgraph."""
-    if not operator_map:
-        operator_map = {}
+def append_relabel(src: list, dst: list, operator_map: dict) -> None:
+    """Update the operator map over relabeled tensors in a subgraph."""
+    if operator_map is None:
+        raise ValueError("The input operator map cannot be None!")
+
    for i, x in enumerate(src):  # pylint: disable=invalid-name
        if i not in operator_map:
            operator_map[i] = len(dst)
            dst.append(x)
-    return operator_map
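A short usage sketch of the updated append_relabel contract (the operator placeholders are illustrative): the caller now creates the map, it is filled in place, and passing None raises ValueError.

    from mlia.nn.rewrite.core.graph_edit.join import append_relabel

    src_ops = ["op_a", "op_b"]
    dst_ops = ["op_x"]
    operator_map: dict = {}

    append_relabel(src_ops, dst_ops, operator_map)
    print(operator_map)  # {0: 1, 1: 2} -- source index -> position appended in dst_ops
    print(dst_ops)       # ['op_x', 'op_a', 'op_b']

    # append_relabel(src_ops, dst_ops, None)  # would raise ValueError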
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py
index ab34b47..0d182df 100644
--- a/src/mlia/nn/rewrite/core/rewrite.py
+++ b/src/mlia/nn/rewrite/core/rewrite.py
@@ -3,11 +3,19 @@
"""Contains class Rewriter to replace a subgraph/layer of a model."""
from __future__ import annotations
+import importlib
+import tempfile
from dataclasses import dataclass
from pathlib import Path
+from typing import Any
+from mlia.core.errors import ConfigurationError
from mlia.nn.common import Optimizer
from mlia.nn.common import OptimizerConfiguration
+from mlia.nn.rewrite.core.train import eval_in_dir
+from mlia.nn.rewrite.core.train import join_in_dir
+from mlia.nn.rewrite.core.train import train
+from mlia.nn.rewrite.core.train import train_in_dir
from mlia.nn.tensorflow.config import TFLiteModel
@@ -33,10 +41,71 @@ class Rewriter(Optimizer):
        """Init Rewriter instance."""
        self.model = TFLiteModel(tflite_model_path)
        self.optimizer_configuration = optimizer_configuration
+        self.train_dir = ""

    def apply_optimization(self) -> None:
        """Apply the rewrite flow."""
+        def get_function(arg: str) -> Any:
+            module_name = ".".join(arg.split(".")[:-1])
+            fn_name = arg.split(".")[-1]
+            module = importlib.import_module(module_name)
+            return getattr(module, fn_name)
+
+        if self.optimizer_configuration.optimization_target == "fully_connected":
+            replace_function = "mlia.nn.rewrite.library.fc_layer.get_keras_model"
+        else:
+            raise ConfigurationError(
+                "Only fully_connected replacement is supported in rewrite module."
+            )
+
+        replace_fn = get_function(replace_function)
+
+        augmentation_preset = (None, None)
+        use_unmodified_model = True
+        tflite_model = self.model.model_path
+        tfrecord = str(self.optimizer_configuration.dataset)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            tmp_output = Path(tmp_dir, "output.tflite")
+
+            if self.train_dir:
+                tmp_new = Path(tmp_dir, "new.tflite")
+                new_part = train_in_dir(
+                    train_dir=self.train_dir,
+                    baseline_dir=None,
+                    output_filename=tmp_new,
+                    replace_fn=replace_fn,
+                    augmentations=augmentation_preset,
+                    steps=32,
+                    learning_rate=1e-3,
+                    batch_size=1,
+                    verbose=True,
+                    show_progress=True,
+                )
+                eval_in_dir(self.train_dir, new_part[0])
+                join_in_dir(self.train_dir, new_part[0], str(tmp_output))
+            else:
+                if not self.optimizer_configuration.layers_to_optimize:
+                    raise ConfigurationError(
+                        "Input and output tensor names need to be set for rewrite."
+                    )
+                train(
+                    source_model=tflite_model,
+                    unmodified_model=tflite_model if use_unmodified_model else None,
+                    output_model=str(tmp_output),
+                    input_tfrec=str(tfrecord),
+                    replace_fn=replace_fn,
+                    input_tensors=[self.optimizer_configuration.layers_to_optimize[0]],
+                    output_tensors=[self.optimizer_configuration.layers_to_optimize[1]],
+                    augment=augmentation_preset,
+                    steps=32,
+                    learning_rate=1e-3,
+                    batch_size=1,
+                    verbose=True,
+                    show_progress=True,
+                )
+
    def get_model(self) -> TFLiteModel:
        """Return optimized model."""
        return self.model
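The replacement function is resolved from a dotted path at runtime. A standalone sketch of that lookup pattern, using a standard-library function instead of the MLIA rewrite library (resolve_function is an illustrative name for the nested get_function helper above):

    import importlib
    from typing import Any, Callable

    def resolve_function(dotted_path: str) -> Callable[..., Any]:
        """Import the module part of the dotted path and return the named attribute."""
        module_name, _, fn_name = dotted_path.rpartition(".")
        module = importlib.import_module(module_name)
        return getattr(module, fn_name)

    sqrt = resolve_function("math.sqrt")  # e.g. "mlia.nn.rewrite.library.fc_layer.get_keras_model"
    print(sqrt(16.0))  # 4.0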
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index f837964..c8497a4 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -33,6 +33,7 @@ from mlia.nn.rewrite.core.utils.utils import load
from mlia.nn.rewrite.core.utils.utils import save
from mlia.utils.logging import log_action
+
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
logger = logging.getLogger(__name__)
diff --git a/src/mlia/nn/rewrite/library/__init__.py b/src/mlia/nn/rewrite/library/__init__.py
new file mode 100644
index 0000000..2988554
--- /dev/null
+++ b/src/mlia/nn/rewrite/library/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Rewrite functions as library."""
diff --git a/src/mlia/nn/rewrite/library/fc_layer.py b/src/mlia/nn/rewrite/library/fc_layer.py
new file mode 100644
index 0000000..8704154
--- /dev/null
+++ b/src/mlia/nn/rewrite/library/fc_layer.py
@@ -0,0 +1,18 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Example rewrite with one fully connected layer."""
+from typing import Any
+
+import tensorflow as tf
+
+
+def get_keras_model(input_shape: Any, output_shape: Any) -> tf.keras.Model:
+    """Generate a Keras model for the rewrite."""
+    input_tensor = tf.keras.layers.Input(
+        shape=input_shape, name="MobileNet/avg_pool/AvgPool"
+    )
+    output_tensor = tf.keras.layers.Dense(output_shape, name="MobileNet/fc1/BiasAdd")(
+        input_tensor
+    )
+    model = tf.keras.Model(input_tensor, output_tensor)
+    return model
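A hedged usage sketch for the new library function; the shapes below are illustrative only, since in the real flow they come from the tensors of the cut-out subgraph:

    from mlia.nn.rewrite.library.fc_layer import get_keras_model

    replacement = get_keras_model(input_shape=(1024,), output_shape=10)
    replacement.summary()  # Input -> single Dense layer that stands in for the cut subgraph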
diff --git a/src/mlia/nn/select.py b/src/mlia/nn/select.py
index 5e223fa..5a7f289 100644
--- a/src/mlia/nn/select.py
+++ b/src/mlia/nn/select.py
@@ -135,6 +135,7 @@ def get_optimizer(
    if isinstance(config, OptimizationSettings):
        return _get_optimizer(model, cast(OptimizationSettings, config))
+
    if is_list_of(config, OptimizationSettings):
        return _get_optimizer(model, cast(List[OptimizationSettings], config))
diff --git a/src/mlia/target/ethos_u/data_collection.py b/src/mlia/target/ethos_u/data_collection.py
index 0f3a8d2..ba8b0fe 100644
--- a/src/mlia/target/ethos_u/data_collection.py
+++ b/src/mlia/target/ethos_u/data_collection.py
@@ -201,7 +201,8 @@ class EthosUOptimizationPerformance(ContextAwareDataCollector):
            OptimizationSettings(
                item.get("optimization_type"),  # type: ignore
                item.get("optimization_target"),  # type: ignore
-                item.get("layers_to_optimized"),
+                item.get("layers_to_optimize"),
+                item.get("dataset"),
            )
            for item in opt_configuration
        ]
diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py
index 6765a53..e4bbe91 100644
--- a/tests/test_cli_commands.py
+++ b/tests/test_cli_commands.py
@@ -73,8 +73,8 @@ def test_performance_unknown_target(
            None,
            True,
            "fully_connected",
-            "node_a",
-            "node_b",
+            "sequential/flatten/Reshape",
+            "StatefulPartitionedCall:0",
            does_not_raise(),
        ],
        [
@@ -85,8 +85,8 @@ def test_performance_unknown_target(
            None,
            True,
            "fully_connected",
-            "node_a",
-            "node_b",
+            "sequential/flatten/Reshape",
+            "StatefulPartitionedCall:0",
            pytest.raises(
                Exception,
                match=(r"Only 'rewrite' is supported for TensorFlow Lite files."),
@@ -157,7 +157,7 @@ def test_performance_unknown_target(
        ],
    ],
)
-def test_opt_valid_optimization_target(  # pylint: disable=too-many-arguments
+def test_opt_valid_optimization_target(  # pylint: disable=too-many-locals,too-many-arguments
    target_profile: str,
    sample_context: ExecutionContext,
    pruning: bool,
@@ -171,12 +171,14 @@ def test_opt_valid_optimization_target(  # pylint: disable=too-many-arguments
    expected_error: Any,
    monkeypatch: pytest.MonkeyPatch,
    test_keras_model: Path,
-    test_tflite_model: Path,
+    test_tflite_model_fp32: Path,
+    test_tfrecord_fp32: Path,
) -> None:
    """Test that command should not fail with valid optimization targets."""
    mock_performance_estimation(monkeypatch)

-    model_type = test_tflite_model if rewrite else test_keras_model
+    model_type = test_tflite_model_fp32 if rewrite else test_keras_model
+    data = test_tfrecord_fp32 if rewrite else None

    with expected_error:
        optimize(
@@ -191,6 +193,7 @@ def test_opt_valid_optimization_target(  # pylint: disable=too-many-arguments
            rewrite_target=rewrite_target,
            rewrite_start=rewrite_start,
            rewrite_end=rewrite_end,
+            dataset=data,
        )
diff --git a/tests/test_nn_rewrite_core_graph_edit_diff.py b/tests/test_nn_rewrite_core_graph_edit_diff.py
new file mode 100644
index 0000000..fdda04f
--- /dev/null
+++ b/tests/test_nn_rewrite_core_graph_edit_diff.py
@@ -0,0 +1,45 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module mlia.nn.rewrite.graph_edit.diff."""
+from pathlib import Path
+
+import numpy as np
+import pytest
+
+from mlia.nn.rewrite.core.graph_edit.diff import calc_nrmse
+from mlia.nn.rewrite.core.graph_edit.diff import diff_stats
+
+
+def assert_two_dictionaries_with_numpy_arrays(dict1: dict, dict2: dict) -> None:
+    """Use numpy assertions to check whether two dictionaries are equal."""
+    key1, val1 = list(dict1.keys()), list(dict1.values())
+    key2, val2 = list(dict2.keys()), list(dict2.values())
+    np.testing.assert_equal(key1, key2)
+    np.testing.assert_almost_equal(val1, val2)
+
+
+@pytest.mark.parametrize(
+    "rmse, scale, expected_result",
+    [
+        (
+            {"test1": np.ndarray((3,), buffer=np.array([1.0, 2.0, 3.3]))},
+            {"test1": np.ndarray((3,), buffer=np.array([4.0, 4.0, 0.0]))},
+            {"test1": np.ndarray((3,), buffer=np.array([0.5, 1.0, 3.3]))},
+        ),
+        (
+            {"test1": np.ndarray((3,), buffer=np.array([1.0, 2.0, 3.3]))},
+            {"test1": np.ndarray((3,), buffer=np.array([0.0, 0.0, 0.0]))},
+            {"test1": np.ndarray((3,), buffer=np.array([1.0, 2.0, 3.3]))},
+        ),
+    ],
+)
+def test_calc_nrmse(rmse: dict, scale: dict, expected_result: dict) -> None:
+    """Test calc_nrmse() function."""
+    assert_two_dictionaries_with_numpy_arrays(calc_nrmse(rmse, scale), expected_result)
+
+
+def test_diff_stats(test_tfrecord_fp32: Path) -> None:
+    """Test diff_stats() function."""
+    res1, res2 = diff_stats(test_tfrecord_fp32, test_tfrecord_fp32)
+    assert res1 == 0.0
+    assert res2 == 0.0
diff --git a/tests/test_nn_rewrite_core_graph_edit_join.py b/tests/test_nn_rewrite_core_graph_edit_join.py
index cbbbeba..cb3e4e2 100644
--- a/tests/test_nn_rewrite_core_graph_edit_join.py
+++ b/tests/test_nn_rewrite_core_graph_edit_join.py
@@ -1,9 +1,14 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for module mlia.nn.rewrite.graph_edit.join."""
+from contextlib import ExitStack as does_not_raise
from pathlib import Path
+from typing import Any
+
+import pytest
from mlia.nn.rewrite.core.graph_edit.cut import cut_model
+from mlia.nn.rewrite.core.graph_edit.join import append_relabel
from mlia.nn.rewrite.core.graph_edit.join import join_models
from mlia.nn.rewrite.core.utils.utils import load
from tests.utils.rewrite import models_are_equal
@@ -48,3 +53,23 @@ def test_join_model(test_tflite_model: Path, tmp_path: Path) -> None:
    joined_model = load(str(joined_file))
    assert models_are_equal(orig_model, joined_model)
+
+
+@pytest.mark.parametrize(
+    "src, dst, op_map, expected_error",
+    [
+        ([1, 2, 3], [4, 5, 6], {}, does_not_raise()),
+        (
+            [1, 2, 3],
+            [4, 5, 6],
+            None,
+            pytest.raises(Exception, match="The input operator map cannot be None!"),
+        ),
+    ],
+)
+def test_append_relabel(
+    src: list, dst: list, op_map: dict, expected_error: Any
+) -> None:
+    """Test passing by reference of the object in function append_relabel."""
+    with expected_error:
+        append_relabel(src, dst, op_map)
diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py
new file mode 100644
index 0000000..b98971e
--- /dev/null
+++ b/tests/test_nn_rewrite_core_rewrite.py
@@ -0,0 +1,55 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module mlia.nn.rewrite.core.rewrite."""
+from __future__ import annotations
+
+from contextlib import ExitStack as does_not_raise
+from pathlib import Path
+from typing import Any
+
+import pytest
+
+from mlia.nn.rewrite.core.rewrite import RewriteConfiguration
+from mlia.nn.rewrite.core.rewrite import Rewriter
+from mlia.nn.tensorflow.config import TFLiteModel
+
+
+@pytest.mark.parametrize(
+    "rewrite_name, expected_error",
+    [
+        ("fully_connected", does_not_raise()),
+        ("random", does_not_raise()),
+    ],
+)
+def test_rewrite_configuration(
+    test_tflite_model_fp32: Path, rewrite_name: str, expected_error: Any
+) -> None:
+    """Test get_rewrite function only supports rewrite type fully_connected."""
+    with expected_error:
+        config_obj = RewriteConfiguration(
+            rewrite_name,
+            ["sample_node_start", "sample_node_end"],
+            None,
+        )
+
+        rewriter_obj = Rewriter(test_tflite_model_fp32, config_obj)
+        assert rewriter_obj.optimizer_configuration.optimization_target == rewrite_name
+        assert isinstance(rewriter_obj, Rewriter)
+
+
+def test_rewriter(
+    test_tflite_model_fp32: Path,
+    test_tfrecord_fp32: Path,
+) -> None:
+    """Test fc_layer rewrite process with rewrite type fully_connected."""
+    config_obj = RewriteConfiguration(
+        "fully_connected",
+        ["sequential/flatten/Reshape", "StatefulPartitionedCall:0"],
+        test_tfrecord_fp32,
+    )
+
+    test_obj = Rewriter(test_tflite_model_fp32, config_obj)
+    test_obj.apply_optimization()
+    trained_model = test_obj.get_model()
+
+    assert isinstance(trained_model, TFLiteModel)