# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for module select."""
from __future__ import annotations

from contextlib import ExitStack as does_not_raise
from dataclasses import asdict
from pathlib import Path
from typing import Any
from typing import cast

import pytest
from keras.api._v2 import keras  # Temporary workaround for now: MLIA-1107

from mlia.core.errors import ConfigurationError
from mlia.nn.rewrite.core.rewrite import RewriteConfiguration
from mlia.nn.rewrite.core.rewrite import RewritingOptimizer
from mlia.nn.rewrite.core.rewrite import TrainingParameters
from mlia.nn.select import get_optimizer
from mlia.nn.select import MultiStageOptimizer
from mlia.nn.select import OptimizationSettings
from mlia.nn.tensorflow.optimizations.clustering import Clusterer
from mlia.nn.tensorflow.optimizations.clustering import ClusteringConfiguration
from mlia.nn.tensorflow.optimizations.pruning import Pruner
from mlia.nn.tensorflow.optimizations.pruning import PruningConfiguration


@pytest.mark.parametrize(
    "config, expected_error, expected_type, expected_config",
    [
        (
            OptimizationSettings(
                optimization_type="pruning",
                optimization_target=0.5,
                layers_to_optimize=None,
            ),
            does_not_raise(),
            Pruner,
            "pruning: 0.5",
        ),
        (
            PruningConfiguration(0.5),
            does_not_raise(),
            Pruner,
            "pruning: 0.5",
        ),
        (
            OptimizationSettings(
                optimization_type="clustering",
                optimization_target=32,
                layers_to_optimize=None,
            ),
            does_not_raise(),
            Clusterer,
            "clustering: 32",
        ),
        (
            OptimizationSettings(
                optimization_type="clustering",
                optimization_target=0.5,
                layers_to_optimize=None,
            ),
            pytest.raises(
                ConfigurationError,
                match="Optimization target should be a "
                "positive integer. "
                "Optimization target provided: 0.5",
            ),
            None,
            None,
        ),
        (
            ClusteringConfiguration(32),
            does_not_raise(),
            Clusterer,
            "clustering: 32",
        ),
        (
            OptimizationSettings(
                optimization_type="superoptimization",
                optimization_target="supertarget",  # type: ignore
                layers_to_optimize="all",  # type: ignore
            ),
            pytest.raises(
                ConfigurationError,
                match="Unsupported optimization type: superoptimization",
            ),
            None,
            None,
        ),
        (
            OptimizationSettings(
                optimization_type="",
                optimization_target=0.5,
                layers_to_optimize=None,
            ),
            pytest.raises(
                ConfigurationError,
                match="Optimization type is not provided",
            ),
            None,
            None,
        ),
        (
            "wrong_config",
            pytest.raises(
                Exception,
                match="Unknown optimization configuration wrong_config",
            ),
            None,
            None,
        ),
        (
            OptimizationSettings(
                optimization_type="pruning",
                optimization_target=None,  # type: ignore
                layers_to_optimize=None,
            ),
            pytest.raises(
                Exception,
                match="Optimization target is not provided",
            ),
            None,
            None,
        ),
        (
            [
                OptimizationSettings(
                    optimization_type="pruning",
                    optimization_target=0.5,
                    layers_to_optimize=None,
                ),
                OptimizationSettings(
                    optimization_type="clustering",
                    optimization_target=32,
                    layers_to_optimize=None,
                ),
            ],
            does_not_raise(),
            MultiStageOptimizer,
            "pruning: 0.5 - clustering: 32",
        ),
        (
            OptimizationSettings(
                optimization_type="rewrite",
                optimization_target="fully-connected",  # type: ignore
                layers_to_optimize=None,
                dataset=None,
            ),
            does_not_raise(),
            RewritingOptimizer,
            "rewrite: fully-connected",
        ),
        (
            RewriteConfiguration("fully-connected"),
            does_not_raise(),
            RewritingOptimizer,
            "rewrite: fully-connected",
        ),
    ],
)
def test_get_optimizer(
    config: Any,
    expected_error: Any,
    expected_type: type,
    expected_config: str,
    test_keras_model: Path,
    test_tflite_model: Path,
) -> None:
    """Test function get_optimizer."""
    with expected_error:
        if (
            isinstance(config, OptimizationSettings)
            and config.optimization_type == "rewrite"
        ) or isinstance(config, RewriteConfiguration):
            model = test_tflite_model
        else:
            model = keras.models.load_model(str(test_keras_model))

        optimizer = get_optimizer(model, config)
        assert isinstance(optimizer, expected_type)
        assert optimizer.optimization_config() == expected_config


@pytest.mark.parametrize(
    "rewrite_parameters",
    [None, {"batch_size": 64, "learning_rate": 0.003}],
)
@pytest.mark.skip_set_training_steps
def test_get_optimizer_training_parameters(
    rewrite_parameters: dict | None, test_tflite_model: Path
) -> None:
    """Test function get_optimizer with various combinations of parameters."""
    config = OptimizationSettings(
        optimization_type="rewrite",
        optimization_target="fully-connected",  # type: ignore
        layers_to_optimize=None,
        dataset=None,
    )
    optimizer = cast(
        RewritingOptimizer,
        get_optimizer(test_tflite_model, config, rewrite_parameters),
    )

    assert isinstance(
        optimizer.optimizer_configuration.train_params, TrainingParameters
    )
    if not rewrite_parameters:
        assert asdict(TrainingParameters()) == asdict(
            optimizer.optimizer_configuration.train_params
        )
    else:
        assert asdict(TrainingParameters()) | rewrite_parameters == asdict(
            optimizer.optimizer_configuration.train_params
        )


@pytest.mark.parametrize(
    "params, expected_result",
    [
        (
            [],
            [],
        ),
        (
            [("pruning", 0.5)],
            [
                OptimizationSettings(
                    optimization_type="pruning",
                    optimization_target=0.5,
                    layers_to_optimize=None,
                )
            ],
        ),
        (
            [("pruning", 0.5), ("clustering", 32)],
            [
                OptimizationSettings(
                    optimization_type="pruning",
                    optimization_target=0.5,
                    layers_to_optimize=None,
                ),
                OptimizationSettings(
                    optimization_type="clustering",
                    optimization_target=32,
                    layers_to_optimize=None,
                ),
            ],
        ),
    ],
)
def test_optimization_settings_create_from(
    params: list[tuple[str, float]], expected_result: list[OptimizationSettings]
) -> None:
    """Test creating settings from parsed params."""
    assert OptimizationSettings.create_from(params) == expected_result


@pytest.mark.parametrize(
    "settings, expected_next_target, expected_error",
    [
        [
            OptimizationSettings("clustering", 32, None),
            OptimizationSettings("clustering", 16, None),
            does_not_raise(),
        ],
        [
            OptimizationSettings("clustering", 4, None),
            OptimizationSettings("clustering", 4, None),
            does_not_raise(),
        ],
        [
            OptimizationSettings("clustering", 10, None),
            OptimizationSettings("clustering", 8, None),
            does_not_raise(),
        ],
        [
            OptimizationSettings("pruning", 0.5, None),
            OptimizationSettings("pruning", 0.6, None),
            does_not_raise(),
        ],
        [
            OptimizationSettings("pruning", 0.9, None),
            OptimizationSettings("pruning", 0.9, None),
            does_not_raise(),
        ],
        [
            OptimizationSettings("super_optimization", 42, None),
            None,
            pytest.raises(
                Exception, match="Optimization type super_optimization is unknown."
            ),
        ],
    ],
)
def test_optimization_settings_next_target(
    settings: OptimizationSettings,
    expected_next_target: OptimizationSettings,
    expected_error: Any,
) -> None:
    """Test getting next optimization target."""
    with expected_error:
        assert settings.next_target() == expected_next_target