diff options
author | Ruomei Yan <ruomei.yan@arm.com> | 2023-02-20 15:32:54 +0000 |
---|---|---|
committer | Benjamin Klimczak <benjamin.klimczak@arm.com> | 2023-10-11 15:42:28 +0100 |
commit | 446c379c92e15ad8f24ed0db853dd0fc9c271151 (patch) | |
tree | fb9e2b20fba15d3aa44054eb76d76fbdb1459006 /tests/test_nn_select.py | |
parent | f0b8ed75fed9dc69ab1f6313339f9f7e38bfc725 (diff) | |
download | mlia-446c379c92e15ad8f24ed0db853dd0fc9c271151.tar.gz |
Add a CLI component to enable rewrites
* Add flags for rewrite (--rewrite, --rewrite-start,
--rewrite-end, --rewrite-target)
* Refactor CLI interfaces to accept tflite models with optimize for
rewrite, keras models with optimize for clustering and pruning
* Refactor and move common.py and select.py out of the folder
nn/tensorflow/optimizations
* Add file nn/rewrite/core/rewrite.py as placeholder
* Update/add unit tests
* Refactor OptimizeModel in ethos_u/data_collection.py
for accepting tflite model case
* Extend the logic so that if "--rewrite" is specified, we don't add
pruning to also accept TFLite models.
* Update README.md
Resolves: MLIA-750, MLIA-854, MLIA-865
Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com>
Change-Id: I67d85f71fa253d2bad4efe304ad8225970b9622c
Diffstat (limited to 'tests/test_nn_select.py')
-rw-r--r-- | tests/test_nn_select.py | 241 |
1 files changed, 241 insertions, 0 deletions
diff --git a/tests/test_nn_select.py b/tests/test_nn_select.py new file mode 100644 index 0000000..31628d2 --- /dev/null +++ b/tests/test_nn_select.py @@ -0,0 +1,241 @@ +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for module select.""" +from __future__ import annotations + +from contextlib import ExitStack as does_not_raise +from pathlib import Path +from typing import Any + +import pytest +import tensorflow as tf + +from mlia.core.errors import ConfigurationError +from mlia.nn.select import get_optimizer +from mlia.nn.select import MultiStageOptimizer +from mlia.nn.select import OptimizationSettings +from mlia.nn.tensorflow.optimizations.clustering import Clusterer +from mlia.nn.tensorflow.optimizations.clustering import ClusteringConfiguration +from mlia.nn.tensorflow.optimizations.pruning import Pruner +from mlia.nn.tensorflow.optimizations.pruning import PruningConfiguration + + +@pytest.mark.parametrize( + "config, expected_error, expected_type, expected_config", + [ + ( + OptimizationSettings( + optimization_type="pruning", + optimization_target=0.5, + layers_to_optimize=None, + ), + does_not_raise(), + Pruner, + "pruning: 0.5", + ), + ( + PruningConfiguration(0.5), + does_not_raise(), + Pruner, + "pruning: 0.5", + ), + ( + OptimizationSettings( + optimization_type="clustering", + optimization_target=32, + layers_to_optimize=None, + ), + does_not_raise(), + Clusterer, + "clustering: 32", + ), + ( + OptimizationSettings( + optimization_type="clustering", + optimization_target=0.5, + layers_to_optimize=None, + ), + pytest.raises( + ConfigurationError, + match="Optimization target should be a " + "positive integer. " + "Optimization target provided: 0.5", + ), + None, + None, + ), + ( + ClusteringConfiguration(32), + does_not_raise(), + Clusterer, + "clustering: 32", + ), + ( + OptimizationSettings( + optimization_type="superoptimization", + optimization_target="supertarget", # type: ignore + layers_to_optimize="all", # type: ignore + ), + pytest.raises( + ConfigurationError, + match="Unsupported optimization type: superoptimization", + ), + None, + None, + ), + ( + OptimizationSettings( + optimization_type="", + optimization_target=0.5, + layers_to_optimize=None, + ), + pytest.raises( + ConfigurationError, + match="Optimization type is not provided", + ), + None, + None, + ), + ( + "wrong_config", + pytest.raises( + Exception, + match="Unknown optimization configuration wrong_config", + ), + None, + None, + ), + ( + OptimizationSettings( + optimization_type="pruning", + optimization_target=None, # type: ignore + layers_to_optimize=None, + ), + pytest.raises( + Exception, + match="Optimization target is not provided", + ), + None, + None, + ), + ( + [ + OptimizationSettings( + optimization_type="pruning", + optimization_target=0.5, + layers_to_optimize=None, + ), + OptimizationSettings( + optimization_type="clustering", + optimization_target=32, + layers_to_optimize=None, + ), + ], + does_not_raise(), + MultiStageOptimizer, + "pruning: 0.5 - clustering: 32", + ), + ], +) +def test_get_optimizer( + config: Any, + expected_error: Any, + expected_type: type, + expected_config: str, + test_keras_model: Path, +) -> None: + """Test function get_optimzer.""" + model = tf.keras.models.load_model(str(test_keras_model)) + + with expected_error: + optimizer = get_optimizer(model, config) + assert isinstance(optimizer, expected_type) + assert optimizer.optimization_config() == expected_config + + +@pytest.mark.parametrize( + "params, expected_result", + [ + ( + [], + [], + ), + ( + [("pruning", 0.5)], + [ + OptimizationSettings( + optimization_type="pruning", + optimization_target=0.5, + layers_to_optimize=None, + ) + ], + ), + ( + [("pruning", 0.5), ("clustering", 32)], + [ + OptimizationSettings( + optimization_type="pruning", + optimization_target=0.5, + layers_to_optimize=None, + ), + OptimizationSettings( + optimization_type="clustering", + optimization_target=32, + layers_to_optimize=None, + ), + ], + ), + ], +) +def test_optimization_settings_create_from( + params: list[tuple[str, float]], expected_result: list[OptimizationSettings] +) -> None: + """Test creating settings from parsed params.""" + assert OptimizationSettings.create_from(params) == expected_result + + +@pytest.mark.parametrize( + "settings, expected_next_target, expected_error", + [ + [ + OptimizationSettings("clustering", 32, None), + OptimizationSettings("clustering", 16, None), + does_not_raise(), + ], + [ + OptimizationSettings("clustering", 4, None), + OptimizationSettings("clustering", 4, None), + does_not_raise(), + ], + [ + OptimizationSettings("clustering", 10, None), + OptimizationSettings("clustering", 8, None), + does_not_raise(), + ], + [ + OptimizationSettings("pruning", 0.5, None), + OptimizationSettings("pruning", 0.6, None), + does_not_raise(), + ], + [ + OptimizationSettings("pruning", 0.9, None), + OptimizationSettings("pruning", 0.9, None), + does_not_raise(), + ], + [ + OptimizationSettings("super_optimization", 42, None), + None, + pytest.raises( + Exception, match="Optimization type super_optimization is unknown." + ), + ], + ], +) +def test_optimization_settings_next_target( + settings: OptimizationSettings, + expected_next_target: OptimizationSettings, + expected_error: Any, +) -> None: + """Test getting next optimization target.""" + with expected_error: + assert settings.next_target() == expected_next_target |