From f3e6597dd50ec70f043d692b773f2d9fd31519ae Mon Sep 17 00:00:00 2001 From: Ruomei Yan Date: Thu, 20 Apr 2023 09:51:20 +0100 Subject: Implement first rewrite (proof of concept) * Define replacement function fully_connected layer * Define RewriteConfiguration and Rewriter to integrate rewrite module into mlia optimize command * Fix a bug in the ethos_u/data_collection.py file * Fix a bug in join.py * Remove diff_stats and use diff instead, added related changes around this to ensure e2e tests passing * Add unit tests for all changes * Fix bug in diff_stats function * The bug was caused by a dividing by numpy array of all zeros. The previous way of handling it did not consider the all zeros case but only dealt with partially zeros * unit tests added. * Fix the bug in rewrite/core/graph_edit/join.py * Remove the possibility of passing None to append_relabel function because it is immutable * The bug happened when empty dictionary was passed in the append_relabel function and the function overwrites the reference of operator_map which caused the dictionary was not updated after the function call Resolves: MLIA-749, MLIA-864, MLIA-866 Change-Id: I1ab426996232f182345e6e98033d5dcb32aea08c Signed-off-by: Benjamin Klimczak --- tests/test_cli_commands.py | 17 +++++---- tests/test_nn_rewrite_core_graph_edit_diff.py | 45 ++++++++++++++++++++++ tests/test_nn_rewrite_core_graph_edit_join.py | 25 ++++++++++++ tests/test_nn_rewrite_core_rewrite.py | 55 +++++++++++++++++++++++++++ 4 files changed, 135 insertions(+), 7 deletions(-) create mode 100644 tests/test_nn_rewrite_core_graph_edit_diff.py create mode 100644 tests/test_nn_rewrite_core_rewrite.py (limited to 'tests') diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py index 6765a53..e4bbe91 100644 --- a/tests/test_cli_commands.py +++ b/tests/test_cli_commands.py @@ -73,8 +73,8 @@ def test_performance_unknown_target( None, True, "fully_connected", - "node_a", - "node_b", + "sequential/flatten/Reshape", + "StatefulPartitionedCall:0", does_not_raise(), ], [ @@ -85,8 +85,8 @@ def test_performance_unknown_target( None, True, "fully_connected", - "node_a", - "node_b", + "sequential/flatten/Reshape", + "StatefulPartitionedCall:0", pytest.raises( Exception, match=(r"Only 'rewrite' is supported for TensorFlow Lite files."), @@ -157,7 +157,7 @@ def test_performance_unknown_target( ], ], ) -def test_opt_valid_optimization_target( # pylint: disable=too-many-arguments +def test_opt_valid_optimization_target( # pylint: disable=too-many-locals,too-many-arguments target_profile: str, sample_context: ExecutionContext, pruning: bool, @@ -171,12 +171,14 @@ def test_opt_valid_optimization_target( # pylint: disable=too-many-arguments expected_error: Any, monkeypatch: pytest.MonkeyPatch, test_keras_model: Path, - test_tflite_model: Path, + test_tflite_model_fp32: Path, + test_tfrecord_fp32: Path, ) -> None: """Test that command should not fail with valid optimization targets.""" mock_performance_estimation(monkeypatch) - model_type = test_tflite_model if rewrite else test_keras_model + model_type = test_tflite_model_fp32 if rewrite else test_keras_model + data = test_tfrecord_fp32 if rewrite else None with expected_error: optimize( @@ -191,6 +193,7 @@ def test_opt_valid_optimization_target( # pylint: disable=too-many-arguments rewrite_target=rewrite_target, rewrite_start=rewrite_start, rewrite_end=rewrite_end, + dataset=data, ) diff --git a/tests/test_nn_rewrite_core_graph_edit_diff.py b/tests/test_nn_rewrite_core_graph_edit_diff.py new file mode 100644 index 0000000..fdda04f --- /dev/null +++ b/tests/test_nn_rewrite_core_graph_edit_diff.py @@ -0,0 +1,45 @@ +# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for module mlia.nn.rewrite.graph_edit.join.""" +from pathlib import Path + +import numpy as np +import pytest + +from mlia.nn.rewrite.core.graph_edit.diff import calc_nrmse +from mlia.nn.rewrite.core.graph_edit.diff import diff_stats + + +def assert_two_dictionaries_with_numpy_arrays(dict1: dict, dict2: dict) -> None: + """Use numpy assertions to check whether two dictionaries are equal.""" + key1, val1 = list(dict1.keys()), list(dict1.values()) + key2, val2 = list(dict2.keys()), list(dict2.values()) + np.testing.assert_equal(key1, key2) + np.testing.assert_almost_equal(val1, val2) + + +@pytest.mark.parametrize( + "rmse, scale, expected_result", + [ + ( + {"test1": np.ndarray((3,), buffer=np.array([1.0, 2.0, 3.3]))}, + {"test1": np.ndarray((3,), buffer=np.array([4.0, 4.0, 0.0]))}, + {"test1": np.ndarray((3,), buffer=np.array([0.5, 1.0, 3.3]))}, + ), + ( + {"test1": np.ndarray((3,), buffer=np.array([1.0, 2.0, 3.3]))}, + {"test1": np.ndarray((3,), buffer=np.array([0.0, 0.0, 0.0]))}, + {"test1": np.ndarray((3,), buffer=np.array([1.0, 2.0, 3.3]))}, + ), + ], +) +def test_calc_nrmse(rmse: dict, scale: dict, expected_result: dict) -> None: + """Test calc_nrmse() function.""" + assert_two_dictionaries_with_numpy_arrays(calc_nrmse(rmse, scale), expected_result) + + +def test_diff_stats(test_tfrecord_fp32: Path) -> None: + """Test diff_stats() function.""" + res1, res2 = diff_stats(test_tfrecord_fp32, test_tfrecord_fp32) + assert res1 == 0.0 + assert res2 == 0.0 diff --git a/tests/test_nn_rewrite_core_graph_edit_join.py b/tests/test_nn_rewrite_core_graph_edit_join.py index cbbbeba..cb3e4e2 100644 --- a/tests/test_nn_rewrite_core_graph_edit_join.py +++ b/tests/test_nn_rewrite_core_graph_edit_join.py @@ -1,9 +1,14 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for module mlia.nn.rewrite.graph_edit.join.""" +from contextlib import ExitStack as does_not_raise from pathlib import Path +from typing import Any + +import pytest from mlia.nn.rewrite.core.graph_edit.cut import cut_model +from mlia.nn.rewrite.core.graph_edit.join import append_relabel from mlia.nn.rewrite.core.graph_edit.join import join_models from mlia.nn.rewrite.core.utils.utils import load from tests.utils.rewrite import models_are_equal @@ -48,3 +53,23 @@ def test_join_model(test_tflite_model: Path, tmp_path: Path) -> None: joined_model = load(str(joined_file)) assert models_are_equal(orig_model, joined_model) + + +@pytest.mark.parametrize( + "src, dst, op_map, expected_error", + [ + ([1, 2, 3], [4, 5, 6], {}, does_not_raise()), + ( + [1, 2, 3], + [4, 5, 6], + None, + pytest.raises(Exception, match="The input operator map cannot be None!"), + ), + ], +) +def test_append_relabel( + src: list, dst: list, op_map: dict, expected_error: Any +) -> None: + """Test passing by reference of the object in function append_relabel.""" + with expected_error: + append_relabel(src, dst, op_map) diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py new file mode 100644 index 0000000..b98971e --- /dev/null +++ b/tests/test_nn_rewrite_core_rewrite.py @@ -0,0 +1,55 @@ +# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for module mlia.nn.rewrite.core.rewrite.""" +from __future__ import annotations + +from contextlib import ExitStack as does_not_raise +from pathlib import Path +from typing import Any + +import pytest + +from mlia.nn.rewrite.core.rewrite import RewriteConfiguration +from mlia.nn.rewrite.core.rewrite import Rewriter +from mlia.nn.tensorflow.config import TFLiteModel + + +@pytest.mark.parametrize( + "rewrite_name, expected_error", + [ + ("fully_connected", does_not_raise()), + ("random", does_not_raise()), + ], +) +def test_rewrite_configuration( + test_tflite_model_fp32: Path, rewrite_name: str, expected_error: Any +) -> None: + """Test get_rewrite function only supports rewrite type fully_connected.""" + with expected_error: + config_obj = RewriteConfiguration( + rewrite_name, + ["sample_node_start", "sample_node_end"], + None, + ) + + rewriter_obj = Rewriter(test_tflite_model_fp32, config_obj) + assert rewriter_obj.optimizer_configuration.optimization_target == rewrite_name + assert isinstance(rewriter_obj, Rewriter) + + +def test_rewriter( + test_tflite_model_fp32: Path, + test_tfrecord_fp32: Path, +) -> None: + """Test fc_layer rewrite process with rewrite type fully_connected.""" + config_obj = RewriteConfiguration( + "fully_connected", + ["sequential/flatten/Reshape", "StatefulPartitionedCall:0"], + test_tfrecord_fp32, + ) + + test_obj = Rewriter(test_tflite_model_fp32, config_obj) + test_obj.apply_optimization() + trained_model = test_obj.get_model() + + assert isinstance(trained_model, TFLiteModel) -- cgit v1.2.1