From 17813ba5be09f0e11fc0748afa4ccf2da02881b6 Mon Sep 17 00:00:00 2001 From: Madeleine Dunn Date: Mon, 13 Nov 2023 15:40:21 +0000 Subject: feat: Implement fp32 sparsity 2:4 rewrite - Update the existing placeholder with code to prune the given model Resolves: MLIA-1002 Signed-off-by: Madeleine Dunn Change-Id: I76b0e0bfe81be5e57d518cd7bb588eef76a11641 --- src/mlia/nn/rewrite/core/rewrite.py | 55 ++++++++++++++++++---- src/mlia/nn/rewrite/core/train.py | 19 ++++---- src/mlia/nn/rewrite/library/fc_sparsity24_layer.py | 21 ++++++--- tests/test_nn_rewrite_core_rewrite.py | 16 ++++++- tests/test_nn_rewrite_core_train.py | 9 +++- 5 files changed, 94 insertions(+), 26 deletions(-) diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py index a4d47c4..2a7b432 100644 --- a/src/mlia/nn/rewrite/core/rewrite.py +++ b/src/mlia/nn/rewrite/core/rewrite.py @@ -12,6 +12,7 @@ from typing import Any from typing import Callable from typing import cast +import tensorflow_model_optimization as tfmot from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 from mlia.core.errors import ConfigurationError @@ -22,6 +23,10 @@ from mlia.nn.common import Optimizer from mlia.nn.common import OptimizerConfiguration from mlia.nn.rewrite.core.train import train from mlia.nn.rewrite.core.train import TrainingParameters +from mlia.nn.rewrite.library.fc_layer import get_keras_model as fc_rewrite +from mlia.nn.rewrite.library.fc_sparsity24_layer import ( + get_keras_model as fc_rewrite_sparsity24, +) from mlia.nn.tensorflow.config import TFLiteModel from mlia.utils.registry import Registry @@ -45,6 +50,43 @@ class Rewrite: except Exception as ex: raise RuntimeError(f"Rewrite '{self.name}' failed.") from ex + def quantize(self, model: keras.Model, model_is_quantized: bool) -> keras.Model: + """Return a quantized model if required.""" + if model_is_quantized: + model = tfmot.quantization.keras.quantize_model(model) + return model + + def training_callbacks(self) -> list: + """Return default rewrite callbacks.""" + return [] + + def post_process(self, model: keras.Model) -> keras.Model: + """Return default post-processing rewrite options.""" + return model + + +class PruningRewrite(Rewrite): + """Derived Rewrite class with pruning-specific logic.""" + + pruning_callback = tfmot.sparsity.keras.UpdatePruningStep + + strip_pruning_wrapper = staticmethod(tfmot.sparsity.keras.strip_pruning) + + def quantize(self, model: keras.Model, model_is_quantized: bool) -> keras.Model: + """Return a quantized model if required.""" + if model_is_quantized: + # placeholder for PQAT + pass + return model + + def training_callbacks(self) -> list: + """Return pruning-specific rewrite callback.""" + return [self.pruning_callback()] + + def post_process(self, model: keras.Model) -> keras.Model: + """Pruning-specific post-processing rewrite options.""" + return self.strip_pruning_wrapper(model) + @dataclass class DynamicallyLoadedRewrite(Rewrite): @@ -113,14 +155,8 @@ class RewritingOptimizer(Optimizer): registry = RewriteRegistry( [ - DynamicallyLoadedRewrite( - "fully-connected", - "mlia.nn.rewrite.library.fc_layer.get_keras_model", - ), - DynamicallyLoadedRewrite( - "fully-connected-sparsity24", - "mlia.nn.rewrite.library.fc_sparsity24_layer.get_keras_model24", - ), + Rewrite("fully-connected", fc_rewrite), + PruningRewrite("fully-connected-sparsity24", fc_rewrite_sparsity24), ] ) @@ -142,7 +178,6 @@ class RewritingOptimizer(Optimizer): rewrite = RewritingOptimizer.registry.items[ self.optimizer_configuration.optimization_target ] - use_unmodified_model = True tflite_model = self.model.model_path tfrecord = str(self.optimizer_configuration.dataset) @@ -161,7 +196,7 @@ class RewritingOptimizer(Optimizer): unmodified_model=tflite_model if use_unmodified_model else None, output_model=str(tmp_output), input_tfrec=str(tfrecord), - replace_fn=rewrite, + rewrite=rewrite, input_tensors=[self.optimizer_configuration.layers_to_optimize[0]], output_tensors=[self.optimizer_configuration.layers_to_optimize[1]], train_params=self.optimizer_configuration.train_params, diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py index e0b3c75..89de880 100644 --- a/src/mlia/nn/rewrite/core/train.py +++ b/src/mlia/nn/rewrite/core/train.py @@ -22,7 +22,6 @@ from typing import Literal import numpy as np import tensorflow as tf -import tensorflow_model_optimization as tfmot from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 from numpy.random import Generator @@ -78,7 +77,7 @@ def train( unmodified_model: Any, output_model: str, input_tfrec: str, - replace_fn: Callable, + rewrite: Callable, input_tensors: list, output_tensors: list, train_params: TrainingParameters = TrainingParameters(), @@ -118,7 +117,7 @@ def train( train_dir=train_dir, baseline_dir=unmodified_model_dir_path, output_filename=Path(train_dir, "new.tflite"), - replace_fn=replace_fn, + rewrite=rewrite, train_params=train_params, ) @@ -345,7 +344,7 @@ def train_in_dir( train_dir: str, baseline_dir: Any, output_filename: Path, - replace_fn: Callable, + rewrite: Callable, train_params: TrainingParameters = TrainingParameters(), ) -> list[str]: """Train a replacement for replace.tflite using the input.tfrec \ @@ -381,13 +380,12 @@ def train_in_dir( input_shape = teacher.shape_from_name[input_name][1:] output_shape = teacher.shape_from_name[output_name][1:] - - model = replace_fn(input_shape, output_shape) + model = rewrite(input_shape, output_shape) optimizer = keras.optimizers.Nadam(learning_rate=train_params.learning_rate) loss_fn = keras.losses.MeanSquaredError() - if model_is_quantized: - model = tfmot.quantization.keras.quantize_model(model) + + model = rewrite.quantize(model, model_is_quantized) # type: ignore[attr-defined] model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"]) logger.info(model.summary()) @@ -428,6 +426,8 @@ def train_in_dir( elif train_params.learning_rate_schedule == "constant": callbacks = [] + callbacks.extend(rewrite.training_callbacks()) # type: ignore[attr-defined] + output_filenames = [] checkpoints = (train_params.checkpoint_at if train_params.checkpoint_at else []) + [ train_params.steps @@ -463,6 +463,9 @@ def train_in_dir( with log_action( f"{steps_so_far}/{train_params.steps}: Saved as {checkpoint_filename}" ): + if steps_so_far == train_params.steps: + model = rewrite.post_process(model) # type: ignore[attr-defined] + save_as_tflite( model, checkpoint_filename, diff --git a/src/mlia/nn/rewrite/library/fc_sparsity24_layer.py b/src/mlia/nn/rewrite/library/fc_sparsity24_layer.py index 1b17522..531b34a 100644 --- a/src/mlia/nn/rewrite/library/fc_sparsity24_layer.py +++ b/src/mlia/nn/rewrite/library/fc_sparsity24_layer.py @@ -1,14 +1,23 @@ -# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 -"""PLACEHOLDER for example rewrite with one fully connected 2:4 sparsity layer.""" +"""Example rewrite with one fully connected 2:4 sparsity layer.""" from typing import Any +import tensorflow_model_optimization as tfmot from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 -from .fc_layer import get_keras_model - -def get_keras_model24(input_shape: Any, output_shape: Any) -> keras.Model: +def get_keras_model(input_shape: Any, output_shape: Any) -> keras.Model: """Generate TensorFlow Lite model for rewrite.""" - model = get_keras_model(input_shape, output_shape) + model = tfmot.sparsity.keras.prune_low_magnitude( + to_prune=keras.Sequential( + [ + keras.layers.InputLayer(input_shape=input_shape), + keras.layers.Reshape([-1]), + keras.layers.Dense(output_shape), + ] + ), + sparsity_m_by_n=(2, 4), + ) + return model diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py index 1d0100a..e614cad 100644 --- a/tests/test_nn_rewrite_core_rewrite.py +++ b/tests/test_nn_rewrite_core_rewrite.py @@ -38,6 +38,20 @@ def test_rewrite() -> None: rewrite((1, 2), (1, 2)) +@pytest.mark.parametrize( + "rewrite_name, callbacks_length", + [ + ("fully-connected", 0), + ("fully-connected-sparsity24", 1), + ], +) +def test_rewrite_selection(rewrite_name: str, callbacks_length: int) -> None: + """Test that the correct rewrite class is instantiated.""" + rewrite = RewritingOptimizer.registry.items[rewrite_name] + assert rewrite.name == rewrite_name + assert len(rewrite.training_callbacks()) == callbacks_length + + @pytest.mark.parametrize( "rewrite_name, expected_error", [ @@ -89,7 +103,7 @@ def test_rewriting_optimizer( def test_register_rewrite_function() -> None: - """Test adding rewrite functions and verify the are reported via the registry.""" + """Test adding rewrite functions and verify they are reported via the registry.""" registry = RewriteRegistry() rewrite1 = Rewrite("r1", cast(RewriteCallable, lambda: 1)) diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py index 624c5ed..34b9543 100644 --- a/tests/test_nn_rewrite_core_train.py +++ b/tests/test_nn_rewrite_core_train.py @@ -14,6 +14,7 @@ import pytest import tensorflow as tf from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 +from mlia.nn.rewrite.core.rewrite import DynamicallyLoadedRewrite from mlia.nn.rewrite.core.train import augment_fn_twins from mlia.nn.rewrite.core.train import AUGMENTATION_PRESETS from mlia.nn.rewrite.core.train import LearningRateSchedule @@ -53,12 +54,18 @@ def check_train( """Test the train() function.""" with TemporaryDirectory() as tmp_dir: output_file = Path(tmp_dir, "out.tflite") + mock_rewrite = DynamicallyLoadedRewrite( + name="replace", + function_name=( + "tests.test_nn_rewrite_core_train.replace_fully_connected_with_conv" + ), + ) result = train( source_model=str(tflite_model), unmodified_model=str(tflite_model) if use_unmodified_model else None, output_model=str(output_file), input_tfrec=str(tfrecord), - replace_fn=replace_fully_connected_with_conv, + rewrite=mock_rewrite, input_tensors=["sequential/flatten/Reshape"], output_tensors=["StatefulPartitionedCall:0"], train_params=train_params, -- cgit v1.2.1