aboutsummaryrefslogtreecommitdiff
path: root/tests/test_nn_select.py
diff options
context:
space:
mode:
authorNathan Bailey <nathan.bailey@arm.com>2024-02-15 14:50:58 +0000
committerNathan Bailey <nathan.bailey@arm.com>2024-03-14 15:45:40 +0000
commit0b552d2ae47da4fb9c16d2a59d6ebe12c8307771 (patch)
tree09b40b939acbe0bcf02dcc77a7ed7ce4aba94322 /tests/test_nn_select.py
parent09b272be6e88d84a30cb89fb71f3fc3c64d20d2e (diff)
downloadmlia-0b552d2ae47da4fb9c16d2a59d6ebe12c8307771.tar.gz
feat: Enable rewrite parameterisation
Enables user to provide a toml or default profile to change training settings for rewrite optimization Resolves: MLIA-1004 Signed-off-by: Nathan Bailey <nathan.bailey@arm.com> Change-Id: I3bf9f44b9a2062fb71ef36eb32c9a69edcc48061
Diffstat (limited to 'tests/test_nn_select.py')
-rw-r--r--tests/test_nn_select.py69
1 files changed, 66 insertions, 3 deletions
diff --git a/tests/test_nn_select.py b/tests/test_nn_select.py
index 31628d2..92b7a3d 100644
--- a/tests/test_nn_select.py
+++ b/tests/test_nn_select.py
@@ -1,16 +1,21 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for module select."""
from __future__ import annotations
from contextlib import ExitStack as does_not_raise
+from dataclasses import asdict
from pathlib import Path
from typing import Any
+from typing import cast
import pytest
import tensorflow as tf
from mlia.core.errors import ConfigurationError
+from mlia.nn.rewrite.core.rewrite import RewriteConfiguration
+from mlia.nn.rewrite.core.rewrite import RewritingOptimizer
+from mlia.nn.rewrite.core.rewrite import TrainingParameters
from mlia.nn.select import get_optimizer
from mlia.nn.select import MultiStageOptimizer
from mlia.nn.select import OptimizationSettings
@@ -135,6 +140,23 @@ from mlia.nn.tensorflow.optimizations.pruning import PruningConfiguration
MultiStageOptimizer,
"pruning: 0.5 - clustering: 32",
),
+ (
+ OptimizationSettings(
+ optimization_type="rewrite",
+ optimization_target="fully_connected", # type: ignore
+ layers_to_optimize=None,
+ dataset=None,
+ ),
+ does_not_raise(),
+ RewritingOptimizer,
+ "rewrite: fully_connected",
+ ),
+ (
+ RewriteConfiguration("fully_connected"),
+ does_not_raise(),
+ RewritingOptimizer,
+ "rewrite: fully_connected",
+ ),
],
)
def test_get_optimizer(
@@ -143,17 +165,58 @@ def test_get_optimizer(
expected_type: type,
expected_config: str,
test_keras_model: Path,
+ test_tflite_model: Path,
) -> None:
"""Test function get_optimzer."""
- model = tf.keras.models.load_model(str(test_keras_model))
-
with expected_error:
+ if (
+ isinstance(config, OptimizationSettings)
+ and config.optimization_type == "rewrite"
+ ) or isinstance(config, RewriteConfiguration):
+ model = test_tflite_model
+ else:
+ model = tf.keras.models.load_model(str(test_keras_model))
optimizer = get_optimizer(model, config)
assert isinstance(optimizer, expected_type)
assert optimizer.optimization_config() == expected_config
@pytest.mark.parametrize(
+ "rewrite_parameters",
+ [[None], [{"batch_size": 64, "learning_rate": 0.003}]],
+)
+@pytest.mark.skip_set_training_steps
+def test_get_optimizer_training_parameters(
+ rewrite_parameters: list[dict], test_tflite_model: Path
+) -> None:
+ """Test function get_optimzer with various combinations of parameters."""
+ config = OptimizationSettings(
+ optimization_type="rewrite",
+ optimization_target="fully_connected", # type: ignore
+ layers_to_optimize=None,
+ dataset=None,
+ )
+ optimizer = cast(
+ RewritingOptimizer,
+ get_optimizer(test_tflite_model, config, list(rewrite_parameters)),
+ )
+
+ assert len(rewrite_parameters) == 1
+
+ assert isinstance(
+ optimizer.optimizer_configuration.train_params, TrainingParameters
+ )
+ if not rewrite_parameters[0]:
+ assert asdict(TrainingParameters()) == asdict(
+ optimizer.optimizer_configuration.train_params
+ )
+ else:
+ assert asdict(TrainingParameters()) | rewrite_parameters[0] == asdict(
+ optimizer.optimizer_configuration.train_params
+ )
+
+
+@pytest.mark.parametrize(
"params, expected_result",
[
(