diff options
Diffstat (limited to 'src/mlia/nn/rewrite/library')
-rw-r--r-- | src/mlia/nn/rewrite/library/clustering.py | 5 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/library/fc_layer.py | 18 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/library/helper_functions.py | 2 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/library/layers.py | 53 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/library/sparsity.py | 80 |
5 files changed, 135 insertions, 23 deletions
diff --git a/src/mlia/nn/rewrite/library/clustering.py b/src/mlia/nn/rewrite/library/clustering.py index 48914dc..9247457 100644 --- a/src/mlia/nn/rewrite/library/clustering.py +++ b/src/mlia/nn/rewrite/library/clustering.py @@ -45,6 +45,7 @@ def conv2d_clustering_rewrite( # pylint: disable=dangerous-default-value ), activation: str = "relu", kernel_size: list[int] = [3, 3], + layer_type: type[keras.layers.Layer] = keras.layers.Conv2D, ) -> keras.Model: """Conv2d TensorFlow Lite model ready for clustering.""" rewrite_params = { @@ -59,7 +60,7 @@ def conv2d_clustering_rewrite( # pylint: disable=dangerous-default-value activation_function, activation_function_extra_args = get_activation_function( activation ) - activation_func_found = ( + activation_func_found = ( # pylint: disable=duplicate-code [activation_function(**activation_function_extra_args)] if activation_function else [] @@ -68,7 +69,7 @@ def conv2d_clustering_rewrite( # pylint: disable=dangerous-default-value to_cluster=keras.Sequential( [ keras.layers.InputLayer(input_shape=input_shape), - keras.layers.Conv2D(**conv2d_parameters), + layer_type(**conv2d_parameters), keras.layers.BatchNormalization(), *activation_func_found, ] diff --git a/src/mlia/nn/rewrite/library/fc_layer.py b/src/mlia/nn/rewrite/library/fc_layer.py deleted file mode 100644 index 92195d1..0000000 --- a/src/mlia/nn/rewrite/library/fc_layer.py +++ /dev/null @@ -1,18 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2023-2024, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Rewrite function used to return regular layers.""" -from typing import Any - -from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 - - -def fc_rewrite(input_shape: Any, output_shape: Any) -> keras.Model: - """Fully connected TensorFlow Lite model for rewrite.""" - model = keras.Sequential( - ( - keras.layers.InputLayer(input_shape=input_shape), - keras.layers.Reshape([-1]), - keras.layers.Dense(output_shape), - ) - ) - return model diff --git a/src/mlia/nn/rewrite/library/helper_functions.py b/src/mlia/nn/rewrite/library/helper_functions.py index 1237c17..fd5993b 100644 --- a/src/mlia/nn/rewrite/library/helper_functions.py +++ b/src/mlia/nn/rewrite/library/helper_functions.py @@ -45,7 +45,7 @@ def compute_conv2d_parameters( # pylint: disable=dangerous-default-value assert len(input_shape) == 3 assert len(output_shape) == 3 kernel_size = tuple(kernel_size_input) - num_filters = (output_shape[-1] - input_shape[-1]) + input_shape[-1] + num_filters = output_shape[-1] padding = "valid" stride_h = round(input_shape[0] / output_shape[0]) check_output_size_h = math.floor((input_shape[0] - kernel_size[0]) / stride_h) + 1 diff --git a/src/mlia/nn/rewrite/library/layers.py b/src/mlia/nn/rewrite/library/layers.py new file mode 100644 index 0000000..abf0a4c --- /dev/null +++ b/src/mlia/nn/rewrite/library/layers.py @@ -0,0 +1,53 @@ +# SPDX-FileCopyrightText: Copyright 2023-2024, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Rewrite function used to return regular layers.""" +from typing import Any + +from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 + +from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters +from mlia.nn.rewrite.library.helper_functions import get_activation_function + + +def fc_rewrite(input_shape: Any, output_shape: Any) -> keras.Model: + """Fully connected TensorFlow Lite model for rewrite.""" + model = keras.Sequential( + ( + keras.layers.InputLayer(input_shape=input_shape), + keras.layers.Reshape([-1]), + keras.layers.Dense(output_shape), + ) + ) + return model + + +def conv2d_rewrite( # pylint: disable=dangerous-default-value + input_shape: Any, + output_shape: Any, + activation: str = "relu", + kernel_size: list[int] = [3, 3], + layer_type: type[keras.layers.Layer] = keras.layers.Conv2D, +) -> keras.Model: + """Fully connected TensorFlow Lite model for rewrite.""" + conv2d_parameters = compute_conv2d_parameters( + input_shape=input_shape, + output_shape=output_shape, + kernel_size_input=kernel_size, + ) + activation_function, activation_function_extra_args = get_activation_function( + activation + ) + activation_func_found = ( # pylint: disable=duplicate-code + [activation_function(**activation_function_extra_args)] + if activation_function + else [] + ) + model = keras.Sequential( + ( + keras.layers.InputLayer(input_shape=input_shape), + layer_type(**conv2d_parameters), + keras.layers.BatchNormalization(), + *activation_func_found, + ) + ) + return model diff --git a/src/mlia/nn/rewrite/library/sparsity.py b/src/mlia/nn/rewrite/library/sparsity.py index 0937b13..5102094 100644 --- a/src/mlia/nn/rewrite/library/sparsity.py +++ b/src/mlia/nn/rewrite/library/sparsity.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Rewrite functions used to return layers ready for sparse pruning.""" +from __future__ import annotations + from typing import Any import tensorflow_model_optimization as tfmot @@ -10,6 +12,79 @@ from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters from mlia.nn.rewrite.library.helper_functions import get_activation_function +def fc_sparsity_unstructured_rewrite( + input_shape: Any, + output_shape: Any, + initial_sparsity: float = 0.5, + final_sparsity: float = 0.5, + begin_step: int = 0, + end_step: int = 48000, +) -> keras.Model: + """Fully connected TensorFlow Lite model ready for unstructured sparse pruning.""" + model = tfmot.sparsity.keras.prune_low_magnitude( + to_prune=keras.Sequential( + [ + keras.layers.InputLayer(input_shape=input_shape), + keras.layers.Reshape([-1]), + keras.layers.Dense(output_shape), + ] + ), + pruning_schedule=tfmot.sparsity.keras.PolynomialDecay( + initial_sparsity=initial_sparsity, + final_sparsity=final_sparsity, + begin_step=begin_step, + end_step=end_step, + ), + ) + + return model + + +def conv2d_sparsity_unstructured_rewrite( # pylint: disable=dangerous-default-value, too-many-arguments + input_shape: Any, + output_shape: Any, + initial_sparsity: float = 0.5, + final_sparsity: float = 0.5, + begin_step: int = 0, + end_step: int = 48000, + activation: str = "relu", + kernel_size: list[int] = [3, 3], + layer_type: type[keras.layers.Layer] = keras.layers.Conv2D, +) -> keras.Model: + """Conv2d TensorFlow Lite model ready for unstructured sparse pruning.""" + conv2d_parameters = compute_conv2d_parameters( + input_shape=input_shape, + output_shape=output_shape, + kernel_size_input=kernel_size, + ) + activation_function, activation_function_extra_args = get_activation_function( + activation + ) + activation_func_found = ( + [activation_function(**activation_function_extra_args)] + if activation_function + else [] + ) + model = tfmot.sparsity.keras.prune_low_magnitude( + to_prune=keras.Sequential( + [ + keras.layers.InputLayer(input_shape=input_shape), + layer_type(**conv2d_parameters), + keras.layers.BatchNormalization(), + *activation_func_found, + ] + ), + pruning_schedule=tfmot.sparsity.keras.PolynomialDecay( + initial_sparsity=initial_sparsity, + final_sparsity=final_sparsity, + begin_step=begin_step, + end_step=end_step, + ), + ) + + return model + + def fc_sparsity_rewrite( input_shape: Any, output_shape: Any, sparsity_m: int = 2, sparsity_n: int = 4 ) -> keras.Model: @@ -38,6 +113,7 @@ def conv2d_sparsity_rewrite( # pylint: disable=dangerous-default-value sparsity_n: int = 4, activation: str = "relu", kernel_size: list[int] = [3, 3], + layer_type: type[keras.layers.Layer] = keras.layers.Conv2D, ) -> keras.Model: """Conv2d TensorFlow Lite model ready for sparse pruning.""" conv2d_parameters = compute_conv2d_parameters( @@ -48,7 +124,7 @@ def conv2d_sparsity_rewrite( # pylint: disable=dangerous-default-value activation_function, activation_function_extra_args = get_activation_function( activation ) - activation_func_found = ( + activation_func_found = ( # pylint: disable=duplicate-code [activation_function(**activation_function_extra_args)] if activation_function else [] @@ -57,7 +133,7 @@ def conv2d_sparsity_rewrite( # pylint: disable=dangerous-default-value to_prune=keras.Sequential( [ keras.layers.InputLayer(input_shape=input_shape), - keras.layers.Conv2D(**conv2d_parameters), + layer_type(**conv2d_parameters), keras.layers.BatchNormalization(), *activation_func_found, ] |