aboutsummaryrefslogtreecommitdiff
path: root/tests/mlia/test_nn_tensorflow_optimizations_select.py
diff options
context:
space:
mode:
authorDiego Russo <diego.russo@arm.com>2022-05-30 13:34:14 +0100
committerDiego Russo <diego.russo@arm.com>2022-05-30 13:34:14 +0100
commit0efca3cadbad5517a59884576ddb90cfe7ac30f8 (patch)
treeabed6cb6fbf3c439fc8d947f505b6a53d5daeb1e /tests/mlia/test_nn_tensorflow_optimizations_select.py
parent0777092695c143c3a54680b5748287d40c914c35 (diff)
downloadmlia-0efca3cadbad5517a59884576ddb90cfe7ac30f8.tar.gz
Add MLIA codebase0.3.0-rc.1
Add MLIA codebase including sources and tests. Change-Id: Id41707559bd721edd114793618d12ccd188d8dbd
Diffstat (limited to 'tests/mlia/test_nn_tensorflow_optimizations_select.py')
-rw-r--r--tests/mlia/test_nn_tensorflow_optimizations_select.py240
1 file changed, 240 insertions, 0 deletions
diff --git a/tests/mlia/test_nn_tensorflow_optimizations_select.py b/tests/mlia/test_nn_tensorflow_optimizations_select.py
new file mode 100644
index 0000000..5cac8ba
--- /dev/null
+++ b/tests/mlia/test_nn_tensorflow_optimizations_select.py
@@ -0,0 +1,240 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module select."""
+from contextlib import ExitStack as does_not_raise
+from pathlib import Path
+from typing import Any
+from typing import List
+from typing import Tuple
+
+import pytest
+import tensorflow as tf
+
+from mlia.nn.tensorflow.optimizations.clustering import Clusterer
+from mlia.nn.tensorflow.optimizations.clustering import ClusteringConfiguration
+from mlia.nn.tensorflow.optimizations.pruning import Pruner
+from mlia.nn.tensorflow.optimizations.pruning import PruningConfiguration
+from mlia.nn.tensorflow.optimizations.select import get_optimizer
+from mlia.nn.tensorflow.optimizations.select import MultiStageOptimizer
+from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+
+
@pytest.mark.parametrize(
    "config, expected_error, expected_type, expected_config",
    [
        # Valid pruning settings resolve to a Pruner.
        (
            OptimizationSettings(
                optimization_type="pruning",
                optimization_target=0.5,
                layers_to_optimize=None,
            ),
            does_not_raise(),
            Pruner,
            "pruning: 0.5",
        ),
        # A raw PruningConfiguration is accepted directly.
        (
            PruningConfiguration(0.5),
            does_not_raise(),
            Pruner,
            "pruning: 0.5",
        ),
        # Valid clustering settings resolve to a Clusterer.
        (
            OptimizationSettings(
                optimization_type="clustering",
                optimization_target=32,
                layers_to_optimize=None,
            ),
            does_not_raise(),
            Clusterer,
            "clustering: 32",
        ),
        # Clustering target must be a positive integer, not a float.
        (
            OptimizationSettings(
                optimization_type="clustering",
                optimization_target=0.5,
                layers_to_optimize=None,
            ),
            pytest.raises(
                Exception,
                match="Optimization target should be a "
                "positive integer. "
                "Optimization target provided: 0.5",
            ),
            None,
            None,
        ),
        # A raw ClusteringConfiguration is accepted directly.
        (
            ClusteringConfiguration(32),
            does_not_raise(),
            Clusterer,
            "clustering: 32",
        ),
        # Unknown optimization type is rejected.
        (
            OptimizationSettings(
                optimization_type="superoptimization",
                optimization_target="supertarget",  # type: ignore
                layers_to_optimize="all",  # type: ignore
            ),
            pytest.raises(
                Exception,
                match="Unsupported optimization type: superoptimization",
            ),
            None,
            None,
        ),
        # Empty optimization type is rejected.
        (
            OptimizationSettings(
                optimization_type="",
                optimization_target=0.5,
                layers_to_optimize=None,
            ),
            pytest.raises(
                Exception,
                match="Optimization type is not provided",
            ),
            None,
            None,
        ),
        # A config object of an unexpected type is rejected.
        (
            "wrong_config",
            pytest.raises(
                Exception,
                match="Unknown optimization configuration wrong_config",
            ),
            None,
            None,
        ),
        # Missing optimization target is rejected.
        (
            OptimizationSettings(
                optimization_type="pruning",
                optimization_target=None,  # type: ignore
                layers_to_optimize=None,
            ),
            pytest.raises(
                Exception,
                match="Optimization target is not provided",
            ),
            None,
            None,
        ),
        # A list of settings resolves to a MultiStageOptimizer.
        (
            [
                OptimizationSettings(
                    optimization_type="pruning",
                    optimization_target=0.5,
                    layers_to_optimize=None,
                ),
                OptimizationSettings(
                    optimization_type="clustering",
                    optimization_target=32,
                    layers_to_optimize=None,
                ),
            ],
            does_not_raise(),
            MultiStageOptimizer,
            "pruning: 0.5 - clustering: 32",
        ),
    ],
)
def test_get_optimizer(
    config: Any,
    expected_error: Any,
    expected_type: type,
    expected_config: str,
    test_keras_model: Path,
) -> None:
    """Test function get_optimizer.

    Check that get_optimizer maps each supported configuration onto the
    expected optimizer type (and reported config string), and raises for
    invalid or unsupported configurations.
    """
    model = tf.keras.models.load_model(str(test_keras_model))

    with expected_error:
        optimizer = get_optimizer(model, config)
        assert isinstance(optimizer, expected_type)
        assert optimizer.optimization_config() == expected_config
+
+
@pytest.mark.parametrize(
    "params, expected_result",
    [
        # No parsed parameters -> no settings.
        ([], []),
        # A single (type, target) pair maps to one settings object.
        (
            [("pruning", 0.5)],
            [OptimizationSettings("pruning", 0.5, None)],
        ),
        # Multiple pairs keep their order.
        (
            [("pruning", 0.5), ("clustering", 32)],
            [
                OptimizationSettings("pruning", 0.5, None),
                OptimizationSettings("clustering", 32, None),
            ],
        ),
    ],
)
def test_optimization_settings_create_from(
    params: List[Tuple[str, float]], expected_result: List[OptimizationSettings]
) -> None:
    """Check that parsed (type, target) pairs convert into settings objects."""
    assert OptimizationSettings.create_from(params) == expected_result
+
+
@pytest.mark.parametrize(
    "settings, expected_next_target, expected_error",
    [
        # Clustering halves the target value.
        (
            OptimizationSettings(
                optimization_type="clustering",
                optimization_target=32,
                layers_to_optimize=None,
            ),
            OptimizationSettings(
                optimization_type="clustering",
                optimization_target=16,
                layers_to_optimize=None,
            ),
            does_not_raise(),
        ),
        # Clustering target does not go below its lower bound.
        (
            OptimizationSettings(
                optimization_type="clustering",
                optimization_target=4,
                layers_to_optimize=None,
            ),
            OptimizationSettings(
                optimization_type="clustering",
                optimization_target=4,
                layers_to_optimize=None,
            ),
            does_not_raise(),
        ),
        # Non power-of-two clustering target still steps down.
        (
            OptimizationSettings(
                optimization_type="clustering",
                optimization_target=10,
                layers_to_optimize=None,
            ),
            OptimizationSettings(
                optimization_type="clustering",
                optimization_target=8,
                layers_to_optimize=None,
            ),
            does_not_raise(),
        ),
        # Pruning increases the target value.
        (
            OptimizationSettings(
                optimization_type="pruning",
                optimization_target=0.5,
                layers_to_optimize=None,
            ),
            OptimizationSettings(
                optimization_type="pruning",
                optimization_target=0.6,
                layers_to_optimize=None,
            ),
            does_not_raise(),
        ),
        # Pruning target does not go above its upper bound.
        (
            OptimizationSettings(
                optimization_type="pruning",
                optimization_target=0.9,
                layers_to_optimize=None,
            ),
            OptimizationSettings(
                optimization_type="pruning",
                optimization_target=0.9,
                layers_to_optimize=None,
            ),
            does_not_raise(),
        ),
        # Unknown optimization type raises.
        (
            OptimizationSettings(
                optimization_type="super_optimization",
                optimization_target=42,
                layers_to_optimize=None,
            ),
            None,
            pytest.raises(
                Exception, match="Unknown optimization type super_optimization"
            ),
        ),
    ],
)
def test_optimization_settings_next_target(
    settings: OptimizationSettings,
    expected_next_target: OptimizationSettings,
    expected_error: Any,
) -> None:
    """Check stepping a settings object to its next optimization target."""
    with expected_error:
        assert settings.next_target() == expected_next_target