diff options
Diffstat (limited to 'tests/test_nn_rewrite_core_rewrite.py')
-rw-r--r-- | tests/test_nn_rewrite_core_rewrite.py | 32 |
1 files changed, 31 insertions, 1 deletions
diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py index 487784d..363d614 100644 --- a/tests/test_nn_rewrite_core_rewrite.py +++ b/tests/test_nn_rewrite_core_rewrite.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for module mlia.nn.rewrite.core.rewrite.""" from __future__ import annotations @@ -7,6 +7,7 @@ from contextlib import ExitStack as does_not_raise from pathlib import Path from typing import Any from typing import cast +from unittest.mock import MagicMock import pytest @@ -16,6 +17,8 @@ from mlia.nn.rewrite.core.rewrite import RewriteCallable from mlia.nn.rewrite.core.rewrite import RewriteConfiguration from mlia.nn.rewrite.core.rewrite import RewriteRegistry from mlia.nn.rewrite.core.rewrite import RewritingOptimizer +from mlia.nn.rewrite.core.rewrite import TrainingParameters +from mlia.nn.rewrite.core.train import train_in_dir from mlia.nn.tensorflow.config import TFLiteModel from tests.utils.rewrite import MockTrainingParameters @@ -129,3 +132,30 @@ def test_rewrite_function_autoload_fail() -> None: "Unable to load rewrite function 'invalid_module.invalid_function'" " for 'mock_rewrite'." ) + + +def test_rewrite_configuration_train_params( + test_tflite_model_fp32: Path, + test_tfrecord_fp32: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test if we pass training parameters to the + rewrite configuration function they are passed to train_in_dir.""" + train_params = TrainingParameters( + batch_size=64, steps=24000, learning_rate=1e-5, show_progress=True + ) + + config_obj = RewriteConfiguration( + "fully_connected", + ["sequential/flatten/Reshape", "StatefulPartitionedCall:0"], + test_tfrecord_fp32, + train_params=train_params, + ) + + rewriter_obj = RewritingOptimizer(test_tflite_model_fp32, config_obj) + train_mock = MagicMock(side_effect=train_in_dir) + monkeypatch.setattr("mlia.nn.rewrite.core.train.train_in_dir", train_mock) + rewriter_obj.apply_optimization() + + train_mock.assert_called_once() + assert train_mock.call_args.kwargs["train_params"] == train_params |