aboutsummaryrefslogtreecommitdiff
path: root/tests/test_nn_rewrite_core_rewrite.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_nn_rewrite_core_rewrite.py')
-rw-r--r--tests/test_nn_rewrite_core_rewrite.py32
1 files changed, 31 insertions, 1 deletions
diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py
index 487784d..363d614 100644
--- a/tests/test_nn_rewrite_core_rewrite.py
+++ b/tests/test_nn_rewrite_core_rewrite.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for module mlia.nn.rewrite.core.rewrite."""
from __future__ import annotations
@@ -7,6 +7,7 @@ from contextlib import ExitStack as does_not_raise
from pathlib import Path
from typing import Any
from typing import cast
+from unittest.mock import MagicMock
import pytest
@@ -16,6 +17,8 @@ from mlia.nn.rewrite.core.rewrite import RewriteCallable
from mlia.nn.rewrite.core.rewrite import RewriteConfiguration
from mlia.nn.rewrite.core.rewrite import RewriteRegistry
from mlia.nn.rewrite.core.rewrite import RewritingOptimizer
+from mlia.nn.rewrite.core.rewrite import TrainingParameters
+from mlia.nn.rewrite.core.train import train_in_dir
from mlia.nn.tensorflow.config import TFLiteModel
from tests.utils.rewrite import MockTrainingParameters
@@ -129,3 +132,30 @@ def test_rewrite_function_autoload_fail() -> None:
"Unable to load rewrite function 'invalid_module.invalid_function'"
" for 'mock_rewrite'."
)
+
+
+def test_rewrite_configuration_train_params(
+ test_tflite_model_fp32: Path,
+ test_tfrecord_fp32: Path,
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """Test if we pass training parameters to the
+ rewrite configuration function they are passed to train_in_dir."""
+ train_params = TrainingParameters(
+ batch_size=64, steps=24000, learning_rate=1e-5, show_progress=True
+ )
+
+ config_obj = RewriteConfiguration(
+ "fully_connected",
+ ["sequential/flatten/Reshape", "StatefulPartitionedCall:0"],
+ test_tfrecord_fp32,
+ train_params=train_params,
+ )
+
+ rewriter_obj = RewritingOptimizer(test_tflite_model_fp32, config_obj)
+ train_mock = MagicMock(side_effect=train_in_dir)
+ monkeypatch.setattr("mlia.nn.rewrite.core.train.train_in_dir", train_mock)
+ rewriter_obj.apply_optimization()
+
+ train_mock.assert_called_once()
+ assert train_mock.call_args.kwargs["train_params"] == train_params