From f3e6597dd50ec70f043d692b773f2d9fd31519ae Mon Sep 17 00:00:00 2001 From: Ruomei Yan Date: Thu, 20 Apr 2023 09:51:20 +0100 Subject: Implement first rewrite (proof of concept) * Define replacement function fully_connected layer * Define RewriteConfiguration and Rewriter to integrate rewrite module into mlia optimize command * Fix a bug in the ethos_u/data_collection.py file * Fix a bug in join.py * Remove diff_stats and use diff instead, added related changes around this to ensure e2e tests passing * Add unit tests for all changes * Fix bug in diff_stats function * The bug was caused by a dividing by numpy array of all zeros. The previous way of handling it did not consider the all zeros case but only dealt with partially zeros * unit tests added. * Fix the bug in rewrite/core/graph_edit/join.py * Remove the possibility of passing None to append_relabel function because it is immutable * The bug happened when empty dictionary was passed in the append_relabel function and the function overwrites the reference of operator_map which caused the dictionary was not updated after the function call Resolves: MLIA-749, MLIA-864, MLIA-866 Change-Id: I1ab426996232f182345e6e98033d5dcb32aea08c Signed-off-by: Benjamin Klimczak --- src/mlia/nn/rewrite/core/graph_edit/diff.py | 23 ++++++--- src/mlia/nn/rewrite/core/graph_edit/join.py | 14 +++--- src/mlia/nn/rewrite/core/rewrite.py | 69 +++++++++++++++++++++++++++ src/mlia/nn/rewrite/core/train.py | 1 + src/mlia/nn/rewrite/library/__init__.py | 3 ++ src/mlia/nn/rewrite/library/fc_layer.py | 18 +++++++ src/mlia/nn/select.py | 1 + src/mlia/target/ethos_u/data_collection.py | 3 +- tests/test_cli_commands.py | 17 ++++--- tests/test_nn_rewrite_core_graph_edit_diff.py | 45 +++++++++++++++++ tests/test_nn_rewrite_core_graph_edit_join.py | 25 ++++++++++ tests/test_nn_rewrite_core_rewrite.py | 55 +++++++++++++++++++++ 12 files changed, 252 insertions(+), 22 deletions(-) create mode 100644 src/mlia/nn/rewrite/library/__init__.py create mode 100644 src/mlia/nn/rewrite/library/fc_layer.py create mode 100644 tests/test_nn_rewrite_core_graph_edit_diff.py create mode 100644 tests/test_nn_rewrite_core_rewrite.py diff --git a/src/mlia/nn/rewrite/core/graph_edit/diff.py b/src/mlia/nn/rewrite/core/graph_edit/diff.py index 198e47e..7fa2a72 100644 --- a/src/mlia/nn/rewrite/core/graph_edit/diff.py +++ b/src/mlia/nn/rewrite/core/graph_edit/diff.py @@ -31,6 +31,21 @@ def add_total(name: str, key: str, values: list, totals: dict) -> None: totals[name][key] += values +def _handle_zeros_in_denominator(denominator: np.ndarray) -> np.ndarray: + """Handle zeros in the denominator in nrmse to avoid dividing by zero(s).""" + denominator[denominator == 0.0] = 1.0 + return denominator + + +def calc_nrmse(rmse: dict, dataset1_var: dict) -> dict: + """Divide rmse by target standard deviation.""" + nrmse = { + k: v / _handle_zeros_in_denominator(np.sqrt(dataset1_var[k])) + for k, v in rmse.items() + } + return nrmse + + def diff_stats( file1: str | Path, file2: str | Path, per_tensor_and_channel: bool = False ) -> tuple: @@ -80,14 +95,8 @@ def diff_stats( mse = per_tensor_mean("se") rmse = {k: np.sqrt(v) for k, v in mse.items()} dataset1_var = per_tensor_mean("dataset1_variance") - is_nonzero = {k: dataset1_var[k] > 0 for k in dataset1_var} - # Divide by target standard deviation to get the per-channel nrmse for each - # tensor where possible - nrmse = { - k: v[is_nonzero[k]] / np.sqrt(dataset1_var[k][is_nonzero[k]]) - for k, v in rmse.items() - } + nrmse = calc_nrmse(rmse, dataset1_var) if per_tensor_and_channel: return mae, nrmse diff --git a/src/mlia/nn/rewrite/core/graph_edit/join.py b/src/mlia/nn/rewrite/core/graph_edit/join.py index 14a7347..2530ec8 100644 --- a/src/mlia/nn/rewrite/core/graph_edit/join.py +++ b/src/mlia/nn/rewrite/core/graph_edit/join.py @@ -22,8 +22,8 @@ def join_models( input_src: str | Path, input_dst: str | Path, output_file: str | Path, - subgraph_src: SubGraphT = 0, - subgraph_dst: SubGraphT = 0, + subgraph_src: int = 0, + subgraph_dst: int = 0, ) -> None: """Join two models and save the result into a given model file path.""" src_model = load(input_src) @@ -150,12 +150,12 @@ def join_subgraphs( dst_subgraph.outputs = list(set(src_subgraph.outputs).union(dst_subgraph.outputs)) -def append_relabel(src: list, dst: list, operator_map: dict | None = None) -> dict: - """Return a map over relabeled tensors in a subgraph.""" - if not operator_map: - operator_map = {} +def append_relabel(src: list, dst: list, operator_map: dict) -> None: + """Update the operator map over relabeled tensors in a subgraph.""" + if operator_map is None: + raise ValueError("The input operator map cannot be None!") + for i, x in enumerate(src): # pylint: disable=invalid-name if i not in operator_map: operator_map[i] = len(dst) dst.append(x) - return operator_map diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py index ab34b47..0d182df 100644 --- a/src/mlia/nn/rewrite/core/rewrite.py +++ b/src/mlia/nn/rewrite/core/rewrite.py @@ -3,11 +3,19 @@ """Contains class Rewriter to replace a subgraph/layer of a model.""" from __future__ import annotations +import importlib +import tempfile from dataclasses import dataclass from pathlib import Path +from typing import Any +from mlia.core.errors import ConfigurationError from mlia.nn.common import Optimizer from mlia.nn.common import OptimizerConfiguration +from mlia.nn.rewrite.core.train import eval_in_dir +from mlia.nn.rewrite.core.train import join_in_dir +from mlia.nn.rewrite.core.train import train +from mlia.nn.rewrite.core.train import train_in_dir from mlia.nn.tensorflow.config import TFLiteModel @@ -33,10 +41,71 @@ class Rewriter(Optimizer): """Init Rewriter instance.""" self.model = TFLiteModel(tflite_model_path) self.optimizer_configuration = optimizer_configuration + self.train_dir = "" def apply_optimization(self) -> None: """Apply the rewrite flow.""" + def get_function(arg: str) -> Any: + module_name = ".".join(arg.split(".")[:-1]) + fn_name = arg.split(".")[-1] + module = importlib.import_module(module_name) + return getattr(module, fn_name) + + if self.optimizer_configuration.optimization_target == "fully_connected": + replace_function = "mlia.nn.rewrite.library.fc_layer.get_keras_model" + else: + raise ConfigurationError( + "Only fully_connected replacement is supported in rewrite module." + ) + + replace_fn = get_function(replace_function) + + augmentation_preset = (None, None) + use_unmodified_model = True + tflite_model = self.model.model_path + tfrecord = str(self.optimizer_configuration.dataset) + + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_output = Path(tmp_dir, "output.tflite") + + if self.train_dir: + tmp_new = Path(tmp_dir, "new.tflite") + new_part = train_in_dir( + train_dir=self.train_dir, + baseline_dir=None, + output_filename=tmp_new, + replace_fn=replace_fn, + augmentations=augmentation_preset, + steps=32, + learning_rate=1e-3, + batch_size=1, + verbose=True, + show_progress=True, + ) + eval_in_dir(self.train_dir, new_part[0]) + join_in_dir(self.train_dir, new_part[0], str(tmp_output)) + else: + if not self.optimizer_configuration.layers_to_optimize: + raise ConfigurationError( + "Input and output tensor names need to be set for rewrite." + ) + train( + source_model=tflite_model, + unmodified_model=tflite_model if use_unmodified_model else None, + output_model=str(tmp_output), + input_tfrec=str(tfrecord), + replace_fn=replace_fn, + input_tensors=[self.optimizer_configuration.layers_to_optimize[0]], + output_tensors=[self.optimizer_configuration.layers_to_optimize[1]], + augment=augmentation_preset, + steps=32, + learning_rate=1e-3, + batch_size=1, + verbose=True, + show_progress=True, + ) + def get_model(self) -> TFLiteModel: """Return optimized model.""" return self.model diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py index f837964..c8497a4 100644 --- a/src/mlia/nn/rewrite/core/train.py +++ b/src/mlia/nn/rewrite/core/train.py @@ -33,6 +33,7 @@ from mlia.nn.rewrite.core.utils.utils import load from mlia.nn.rewrite.core.utils.utils import save from mlia.utils.logging import log_action + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) logger = logging.getLogger(__name__) diff --git a/src/mlia/nn/rewrite/library/__init__.py b/src/mlia/nn/rewrite/library/__init__.py new file mode 100644 index 0000000..2988554 --- /dev/null +++ b/src/mlia/nn/rewrite/library/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Rewrite functions as library.""" diff --git a/src/mlia/nn/rewrite/library/fc_layer.py b/src/mlia/nn/rewrite/library/fc_layer.py new file mode 100644 index 0000000..8704154 --- /dev/null +++ b/src/mlia/nn/rewrite/library/fc_layer.py @@ -0,0 +1,18 @@ +# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Example rewrite with one fully connected layer.""" +from typing import Any + +import tensorflow as tf + + +def get_keras_model(input_shape: Any, output_shape: Any) -> tf.keras.Model: + """Generate tflite model for rewrite.""" + input_tensor = tf.keras.layers.Input( + shape=input_shape, name="MbileNet/avg_pool/AvgPool" + ) + output_tensor = tf.keras.layers.Dense(output_shape, name="MobileNet/fc1/BiasAdd")( + input_tensor + ) + model = tf.keras.Model(input_tensor, output_tensor) + return model diff --git a/src/mlia/nn/select.py b/src/mlia/nn/select.py index 5e223fa..5a7f289 100644 --- a/src/mlia/nn/select.py +++ b/src/mlia/nn/select.py @@ -135,6 +135,7 @@ def get_optimizer( if isinstance(config, OptimizationSettings): return _get_optimizer(model, cast(OptimizationSettings, config)) + if is_list_of(config, OptimizationSettings): return _get_optimizer(model, cast(List[OptimizationSettings], config)) diff --git a/src/mlia/target/ethos_u/data_collection.py b/src/mlia/target/ethos_u/data_collection.py index 0f3a8d2..ba8b0fe 100644 --- a/src/mlia/target/ethos_u/data_collection.py +++ b/src/mlia/target/ethos_u/data_collection.py @@ -201,7 +201,8 @@ class EthosUOptimizationPerformance(ContextAwareDataCollector): OptimizationSettings( item.get("optimization_type"), # type: ignore item.get("optimization_target"), # type: ignore - item.get("layers_to_optimized"), + item.get("layers_to_optimize"), + item.get("dataset"), ) for item in opt_configuration ] diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py index 6765a53..e4bbe91 100644 --- a/tests/test_cli_commands.py +++ b/tests/test_cli_commands.py @@ -73,8 +73,8 @@ def test_performance_unknown_target( None, True, "fully_connected", - "node_a", - "node_b", + "sequential/flatten/Reshape", + "StatefulPartitionedCall:0", does_not_raise(), ], [ @@ -85,8 +85,8 @@ def test_performance_unknown_target( None, True, "fully_connected", - "node_a", - "node_b", + "sequential/flatten/Reshape", + "StatefulPartitionedCall:0", pytest.raises( Exception, match=(r"Only 'rewrite' is supported for TensorFlow Lite files."), @@ -157,7 +157,7 @@ def test_performance_unknown_target( ], ], ) -def test_opt_valid_optimization_target( # pylint: disable=too-many-arguments +def test_opt_valid_optimization_target( # pylint: disable=too-many-locals,too-many-arguments target_profile: str, sample_context: ExecutionContext, pruning: bool, @@ -171,12 +171,14 @@ def test_opt_valid_optimization_target( # pylint: disable=too-many-arguments expected_error: Any, monkeypatch: pytest.MonkeyPatch, test_keras_model: Path, - test_tflite_model: Path, + test_tflite_model_fp32: Path, + test_tfrecord_fp32: Path, ) -> None: """Test that command should not fail with valid optimization targets.""" mock_performance_estimation(monkeypatch) - model_type = test_tflite_model if rewrite else test_keras_model + model_type = test_tflite_model_fp32 if rewrite else test_keras_model + data = test_tfrecord_fp32 if rewrite else None with expected_error: optimize( @@ -191,6 +193,7 @@ def test_opt_valid_optimization_target( # pylint: disable=too-many-arguments rewrite_target=rewrite_target, rewrite_start=rewrite_start, rewrite_end=rewrite_end, + dataset=data, ) diff --git a/tests/test_nn_rewrite_core_graph_edit_diff.py b/tests/test_nn_rewrite_core_graph_edit_diff.py new file mode 100644 index 0000000..fdda04f --- /dev/null +++ b/tests/test_nn_rewrite_core_graph_edit_diff.py @@ -0,0 +1,45 @@ +# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for module mlia.nn.rewrite.graph_edit.join.""" +from pathlib import Path + +import numpy as np +import pytest + +from mlia.nn.rewrite.core.graph_edit.diff import calc_nrmse +from mlia.nn.rewrite.core.graph_edit.diff import diff_stats + + +def assert_two_dictionaries_with_numpy_arrays(dict1: dict, dict2: dict) -> None: + """Use numpy assertions to check whether two dictionaries are equal.""" + key1, val1 = list(dict1.keys()), list(dict1.values()) + key2, val2 = list(dict2.keys()), list(dict2.values()) + np.testing.assert_equal(key1, key2) + np.testing.assert_almost_equal(val1, val2) + + +@pytest.mark.parametrize( + "rmse, scale, expected_result", + [ + ( + {"test1": np.ndarray((3,), buffer=np.array([1.0, 2.0, 3.3]))}, + {"test1": np.ndarray((3,), buffer=np.array([4.0, 4.0, 0.0]))}, + {"test1": np.ndarray((3,), buffer=np.array([0.5, 1.0, 3.3]))}, + ), + ( + {"test1": np.ndarray((3,), buffer=np.array([1.0, 2.0, 3.3]))}, + {"test1": np.ndarray((3,), buffer=np.array([0.0, 0.0, 0.0]))}, + {"test1": np.ndarray((3,), buffer=np.array([1.0, 2.0, 3.3]))}, + ), + ], +) +def test_calc_nrmse(rmse: dict, scale: dict, expected_result: dict) -> None: + """Test calc_nrmse() function.""" + assert_two_dictionaries_with_numpy_arrays(calc_nrmse(rmse, scale), expected_result) + + +def test_diff_stats(test_tfrecord_fp32: Path) -> None: + """Test diff_stats() function.""" + res1, res2 = diff_stats(test_tfrecord_fp32, test_tfrecord_fp32) + assert res1 == 0.0 + assert res2 == 0.0 diff --git a/tests/test_nn_rewrite_core_graph_edit_join.py b/tests/test_nn_rewrite_core_graph_edit_join.py index cbbbeba..cb3e4e2 100644 --- a/tests/test_nn_rewrite_core_graph_edit_join.py +++ b/tests/test_nn_rewrite_core_graph_edit_join.py @@ -1,9 +1,14 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for module mlia.nn.rewrite.graph_edit.join.""" +from contextlib import ExitStack as does_not_raise from pathlib import Path +from typing import Any + +import pytest from mlia.nn.rewrite.core.graph_edit.cut import cut_model +from mlia.nn.rewrite.core.graph_edit.join import append_relabel from mlia.nn.rewrite.core.graph_edit.join import join_models from mlia.nn.rewrite.core.utils.utils import load from tests.utils.rewrite import models_are_equal @@ -48,3 +53,23 @@ def test_join_model(test_tflite_model: Path, tmp_path: Path) -> None: joined_model = load(str(joined_file)) assert models_are_equal(orig_model, joined_model) + + +@pytest.mark.parametrize( + "src, dst, op_map, expected_error", + [ + ([1, 2, 3], [4, 5, 6], {}, does_not_raise()), + ( + [1, 2, 3], + [4, 5, 6], + None, + pytest.raises(Exception, match="The input operator map cannot be None!"), + ), + ], +) +def test_append_relabel( + src: list, dst: list, op_map: dict, expected_error: Any +) -> None: + """Test passing by reference of the object in function append_relabel.""" + with expected_error: + append_relabel(src, dst, op_map) diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py new file mode 100644 index 0000000..b98971e --- /dev/null +++ b/tests/test_nn_rewrite_core_rewrite.py @@ -0,0 +1,55 @@ +# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for module mlia.nn.rewrite.core.rewrite.""" +from __future__ import annotations + +from contextlib import ExitStack as does_not_raise +from pathlib import Path +from typing import Any + +import pytest + +from mlia.nn.rewrite.core.rewrite import RewriteConfiguration +from mlia.nn.rewrite.core.rewrite import Rewriter +from mlia.nn.tensorflow.config import TFLiteModel + + +@pytest.mark.parametrize( + "rewrite_name, expected_error", + [ + ("fully_connected", does_not_raise()), + ("random", does_not_raise()), + ], +) +def test_rewrite_configuration( + test_tflite_model_fp32: Path, rewrite_name: str, expected_error: Any +) -> None: + """Test get_rewrite function only supports rewrite type fully_connected.""" + with expected_error: + config_obj = RewriteConfiguration( + rewrite_name, + ["sample_node_start", "sample_node_end"], + None, + ) + + rewriter_obj = Rewriter(test_tflite_model_fp32, config_obj) + assert rewriter_obj.optimizer_configuration.optimization_target == rewrite_name + assert isinstance(rewriter_obj, Rewriter) + + +def test_rewriter( + test_tflite_model_fp32: Path, + test_tfrecord_fp32: Path, +) -> None: + """Test fc_layer rewrite process with rewrite type fully_connected.""" + config_obj = RewriteConfiguration( + "fully_connected", + ["sequential/flatten/Reshape", "StatefulPartitionedCall:0"], + test_tfrecord_fp32, + ) + + test_obj = Rewriter(test_tflite_model_fp32, config_obj) + test_obj.apply_optimization() + trained_model = test_obj.get_model() + + assert isinstance(trained_model, TFLiteModel) -- cgit v1.2.1