author     Nathan Bailey <nathan.bailey@arm.com>  2024-06-03 09:58:31 +0100
committer  Nathan Bailey <nathan.bailey@arm.com>  2024-06-13 13:18:44 +0100
commit     09b5122bab771161377321e3f17e05465171ad06 (patch)
tree       74676eb296c110a925996448d86cc9dcde28b002 /src/mlia/nn/rewrite/library
parent     9896c7e97da38cdaa14953fdce81a29397d1fca3 (diff)
feat: Unstructured Sparsity Rewrites for Fully Connected and Conv2D Layers
Adds support for unstructured polynomial decay pruning rewrites.

Resolves: MLIA-1171
Signed-off-by: Nathan Bailey <nathan.bailey@arm.com>
Change-Id: I9e753f35f8afe53aa24b87d794ff6986a571168f
Diffstat (limited to 'src/mlia/nn/rewrite/library')
-rw-r--r--  src/mlia/nn/rewrite/library/sparsity.py  74
1 file changed, 74 insertions(+), 0 deletions(-)
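The new helpers below only construct prune-wrapped Keras models; training is left to the caller (within MLIA, the surrounding rewrite machinery). As context, the following sketch, which is not part of this commit, shows one way the output of fc_sparsity_unstructured_rewrite could be exercised with the standard TFMOT pruning workflow; the shapes, random data, optimizer and step counts are illustrative assumptions only.

# Illustrative usage sketch (not part of this commit); shapes, data and
# hyper-parameters below are assumptions chosen only for demonstration.
import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot

from mlia.nn.rewrite.library.sparsity import fc_sparsity_unstructured_rewrite

# Build a prune-wrapped fully connected model: sparsity ramps from 25% to 75%
# between steps 0 and 1000 following the PolynomialDecay schedule.
model = fc_sparsity_unstructured_rewrite(
    input_shape=(28, 28),
    output_shape=10,
    initial_sparsity=0.25,
    final_sparsity=0.75,
    begin_step=0,
    end_step=1000,
)
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)

# Random stand-in data; a real rewrite would fine-tune on the original dataset.
x = np.random.rand(512, 28, 28).astype("float32")
y = np.random.randint(0, 10, size=(512,))

# UpdatePruningStep advances the pruning schedule once per training batch.
model.fit(x, y, epochs=2, callbacks=[tfmot.sparsity.keras.UpdatePruningStep()])

# strip_pruning removes the pruning wrappers, leaving the sparse weights in place.
final_model = tfmot.sparsity.keras.strip_pruning(model)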
diff --git a/src/mlia/nn/rewrite/library/sparsity.py b/src/mlia/nn/rewrite/library/sparsity.py
index 0937b13..1e53254 100644
--- a/src/mlia/nn/rewrite/library/sparsity.py
+++ b/src/mlia/nn/rewrite/library/sparsity.py
@@ -1,6 +1,8 @@
# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Rewrite functions used to return layers ready for sparse pruning."""
+from __future__ import annotations
+
from typing import Any

import tensorflow_model_optimization as tfmot
@@ -10,6 +12,78 @@ from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters
from mlia.nn.rewrite.library.helper_functions import get_activation_function


+def fc_sparsity_unstructured_rewrite(
+ input_shape: Any,
+ output_shape: Any,
+ initial_sparsity: float = 0.5,
+ final_sparsity: float = 0.5,
+ begin_step: int = 0,
+ end_step: int = 48000,
+) -> keras.Model:
+ """Fully connected TensorFlow Lite model ready for unstructured sparse pruning."""
+ model = tfmot.sparsity.keras.prune_low_magnitude(
+ to_prune=keras.Sequential(
+ [
+ keras.layers.InputLayer(input_shape=input_shape),
+ keras.layers.Reshape([-1]),
+ keras.layers.Dense(output_shape),
+ ]
+ ),
+ pruning_schedule=tfmot.sparsity.keras.PolynomialDecay(
+ initial_sparsity=initial_sparsity,
+ final_sparsity=final_sparsity,
+ begin_step=begin_step,
+ end_step=end_step,
+ ),
+ )
+
+ return model
+
+
+def conv2d_sparsity_unstructured_rewrite( # pylint: disable=dangerous-default-value
+ input_shape: Any,
+ output_shape: Any,
+ initial_sparsity: float = 0.5,
+ final_sparsity: float = 0.5,
+ begin_step: int = 0,
+ end_step: int = 48000,
+ activation: str = "relu",
+ kernel_size: list[int] = [3, 3],
+) -> keras.Model:
+ """Conv2d TensorFlow Lite model ready for unstructured sparse pruning."""
+ conv2d_parameters = compute_conv2d_parameters(
+ input_shape=input_shape,
+ output_shape=output_shape,
+ kernel_size_input=kernel_size,
+ )
+ activation_function, activation_function_extra_args = get_activation_function(
+ activation
+ )
+ activation_func_found = (
+ [activation_function(**activation_function_extra_args)]
+ if activation_function
+ else []
+ )
+ model = tfmot.sparsity.keras.prune_low_magnitude(
+ to_prune=keras.Sequential(
+ [
+ keras.layers.InputLayer(input_shape=input_shape),
+ keras.layers.Conv2D(**conv2d_parameters),
+ keras.layers.BatchNormalization(),
+ *activation_func_found,
+ ]
+ ),
+ pruning_schedule=tfmot.sparsity.keras.PolynomialDecay(
+ initial_sparsity=initial_sparsity,
+ final_sparsity=final_sparsity,
+ begin_step=begin_step,
+ end_step=end_step,
+ ),
+ )
+
+ return model
+
+
def fc_sparsity_rewrite(
input_shape: Any, output_shape: Any, sparsity_m: int = 2, sparsity_n: int = 4
) -> keras.Model: