aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNathan Bailey <nathan.bailey@arm.com>2024-03-08 14:08:06 +0000
committerNathan Bailey <nathan.bailey@arm.com>2024-04-16 13:11:31 +0100
commit32405c279d2f98c2d40bdbbb7f7306ff12c86cd6 (patch)
tree42781ca219b822a9ec9f212a9ee516f65b184a27
parent427e02696f1ede596ef6dce82787a37e122efa78 (diff)
downloadmlia-32405c279d2f98c2d40bdbbb7f7306ff12c86cd6.tar.gz
feat: Implement the clustering rewrite for int8
Implements a clustering rewrite for fully connected layers for int8 models Resolves: MLIA-1080 Signed-off-by: Nathan Bailey <nathan.bailey@arm.com> Change-Id: If48efb22764187a382e5b84bbb5c3b75a6e71b75
-rw-r--r--setup.cfg2
-rw-r--r--src/mlia/nn/rewrite/core/rewrite.py132
-rw-r--r--src/mlia/nn/rewrite/core/train.py118
-rw-r--r--src/mlia/nn/rewrite/library/fc_clustering_layer.py4
-rw-r--r--tests/test_nn_rewrite_core_rewrite.py91
-rw-r--r--tests/test_nn_rewrite_core_train.py12
6 files changed, 264 insertions, 95 deletions
diff --git a/setup.cfg b/setup.cfg
index 6ddb576..0714caf 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-FileCopyrightText: Copyright (c) 2020 Troy Comi
# SPDX-License-Identifier: Apache-2.0 AND MIT
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py
index 6a3695a..e2c097c 100644
--- a/src/mlia/nn/rewrite/core/rewrite.py
+++ b/src/mlia/nn/rewrite/core/rewrite.py
@@ -12,6 +12,7 @@ from pathlib import Path
from typing import Any
from typing import Callable
+import numpy as np
import tensorflow_model_optimization as tfmot
from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
@@ -53,9 +54,9 @@ class Rewrite(ABC):
except Exception as ex:
raise RuntimeError(f"Rewrite '{self.name}' failed.") from ex
- @abstractmethod
def quantize(self, model: keras.Model) -> keras.Model:
"""Return a quantized model if required."""
+ return model
@abstractmethod
def training_callbacks(self) -> list:
@@ -65,60 +66,41 @@ class Rewrite(ABC):
def post_process(self, model: keras.Model) -> keras.Model:
"""Return default post-processing rewrite options."""
+ @abstractmethod
+ def check_optimization(self, model: keras.Model, **kwargs: dict) -> bool:
+ """Check if the optimization has produced the correct result."""
-class ClusteringRewrite(Rewrite):
- """Graph clustering rewrite logic to be used by RewritingOptimizer."""
- strip_pruning_wrapper = staticmethod(tfmot.clustering.keras.strip_clustering)
+class GenericRewrite(Rewrite):
+ """Graph rewrite logic for fully-connected rewrite."""
def quantize(self, model: keras.Model) -> keras.Model:
- """Return a quantized model."""
- return model
-
- def post_process(self, model: keras.Model) -> keras.Model:
- """Return the clustering stripped model."""
- return self.strip_pruning_wrapper(model)
+ """Return a quantized model if required."""
+ return tfmot.quantization.keras.quantize_model(model)
def training_callbacks(self) -> list:
"""Return default rewrite callbacks."""
return []
-
-class QATRewrite(Rewrite):
- """Logic for rewrites requiring quantization-aware training."""
-
- def pruning_preserved_quantization(
- self,
- model: keras.Model,
- ) -> keras.Model:
- """Apply pruning-preserved quantization training to a given model."""
- model = tfmot.quantization.keras.quantize_annotate_model(model)
- model = tfmot.quantization.keras.quantize_apply(
- model,
- tfmot.experimental.combine.Default8BitPrunePreserveQuantizeScheme(),
- )
-
+ def post_process(self, model: keras.Model) -> keras.Model:
+ """Return default post-processing rewrite options."""
return model
+ def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool:
+ """Not needed here."""
+ return True
-class FullyConnectedRewrite(Rewrite):
- """Graph rewrite logic for fully-connected rewrite."""
-
- def quantize(self, model: keras.Model) -> keras.Model:
- """Return a quantized model if required."""
- model = tfmot.quantization.keras.quantize_model(model)
- return model
- def training_callbacks(self) -> list:
- """Return default rewrite callbacks."""
- return []
+class QuantizeAwareTrainingRewrite(Rewrite, ABC):
+ """Abstract class for rewrites that perform QAT."""
- def post_process(self, model: keras.Model) -> keras.Model:
- """Return default post-processing rewrite options."""
+ @abstractmethod
+ def preserved_quantize(self, model: keras.Model) -> keras.Model:
+ """Apply optimization-aware quantization to a given model."""
return model
-class Sparsity24Rewrite(QATRewrite):
+class Sparsity24Rewrite(QuantizeAwareTrainingRewrite):
"""Graph rewrite logic for fully-connected-sparsity24 rewrite."""
pruning_callback = tfmot.sparsity.keras.UpdatePruningStep
@@ -137,6 +119,74 @@ class Sparsity24Rewrite(QATRewrite):
"""Pruning-specific post-processing rewrite options."""
return self.strip_pruning_wrapper(model)
+ def preserved_quantize(
+ self,
+ model: keras.Model,
+ ) -> keras.Model:
+ """Apply pruning-preserved quantization training to a given model."""
+ model = tfmot.quantization.keras.quantize_annotate_model(model)
+ model = tfmot.quantization.keras.quantize_apply(
+ model,
+ tfmot.experimental.combine.Default8BitPrunePreserveQuantizeScheme(),
+ )
+
+ return model
+
+ def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool:
+ """Not needed here."""
+ return True
+
+
+class ClusteringRewrite(QuantizeAwareTrainingRewrite):
+ """Graph clustering rewrite logic to be used by RewritingOptimizer."""
+
+ _strip_clustering_wrapper = staticmethod(tfmot.clustering.keras.strip_clustering)
+
+ def preserved_quantize(self, model: keras.Model) -> keras.Model:
+ """Apply clustering-preserved quantization to a given model."""
+ quant_aware_model = tfmot.quantization.keras.quantize_annotate_model(model)
+ cqat_model = tfmot.quantization.keras.quantize_apply(
+ quant_aware_model,
+ tfmot.experimental.combine.Default8BitClusterPreserveQuantizeScheme(),
+ )
+ return cqat_model
+
+ def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool:
+ """Check if clustering has produced the correct result."""
+ number_of_clusters = kwargs.get("number_of_clusters")
+ if not number_of_clusters:
+ raise ValueError(
+ """
+ Expected check_preserved_quantize to have argument number_of_clusters.
+ """
+ )
+
+ for layer in model.layers:
+ for weight in layer.weights:
+ if "kernel" in weight.name:
+ if "kernel_min" in weight.name or "kernel_max" in weight.name:
+ continue
+ number_of_found_clusters = len(np.unique(weight))
+ if number_of_found_clusters != number_of_clusters:
+ logger.warning(
+ "\nWARNING: Expected %d cluster(s), found %d "
+ "cluster(s) in layer %s for weight %s \n",
+ number_of_clusters,
+ number_of_found_clusters,
+ layer.name,
+ weight.name,
+ )
+ return False
+ return True
+
+ def training_callbacks(self) -> list:
+ """Return default rewrite callbacks."""
+ return []
+
+ def post_process(self, model: keras.Model) -> keras.Model:
+ """Return the clustering stripped model."""
+ return self._strip_clustering_wrapper(model)
+
class RewriteRegistry(Registry[Rewrite]):
"""Registry rewrite functions."""
@@ -176,7 +226,7 @@ class RewritingOptimizer(Optimizer):
registry = RewriteRegistry(
[
- FullyConnectedRewrite("fully-connected", fc_rewrite),
+ GenericRewrite("fully-connected", fc_rewrite),
Sparsity24Rewrite("fully-connected-sparsity24", fc_rewrite_sparsity24),
ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite),
]
@@ -200,7 +250,7 @@ class RewritingOptimizer(Optimizer):
rewrite = RewritingOptimizer.registry.items[
self.optimizer_configuration.optimization_target
]
- is_qat = isinstance(rewrite, QATRewrite)
+
use_unmodified_model = True
tflite_model = self.model.model_path
tfrecord = str(self.optimizer_configuration.dataset)
@@ -218,9 +268,9 @@ class RewritingOptimizer(Optimizer):
output_model=str(tmp_output),
input_tfrec=str(tfrecord),
rewrite=rewrite,
+ is_qat=isinstance(rewrite, QuantizeAwareTrainingRewrite),
input_tensors=[self.optimizer_configuration.layers_to_optimize[0]],
output_tensors=[self.optimizer_configuration.layers_to_optimize[1]],
- is_qat=is_qat,
train_params=self.optimizer_configuration.train_params,
)
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index 4b9821c..88efa23 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -73,15 +73,15 @@ class TrainingParameters:
checkpoint_at: list | None = None
-def train(
+def train( # pylint: disable=too-many-arguments
source_model: str,
unmodified_model: Any,
output_model: str,
input_tfrec: str,
rewrite: Callable,
+ is_qat: bool,
input_tensors: list,
output_tensors: list,
- is_qat: bool,
train_params: TrainingParameters = TrainingParameters(),
) -> Any:
"""Extract and train a model, and return the results."""
@@ -383,15 +383,15 @@ def train_in_dir(
)
input_shape = teacher.shape_from_name[input_name][1:]
+
output_shape = teacher.shape_from_name[output_name][1:]
- model = rewrite(input_shape, output_shape)
optimizer = keras.optimizers.Nadam(learning_rate=train_params.learning_rate)
loss_fn = keras.losses.MeanSquaredError()
- if model_is_quantized:
- model = rewrite.quantize(model) # type: ignore[attr-defined]
- model = model_compile(model, optimizer, loss_fn)
+ model = create_model(
+ rewrite, input_shape, output_shape, optimizer, loss_fn, model_is_quantized
+ )
logger.info(model.summary())
@@ -432,16 +432,14 @@ def train_in_dir(
callbacks = []
callbacks.extend(rewrite.training_callbacks()) # type: ignore[attr-defined]
-
- output_filenames = [] # type: list[str]
+ output_filenames: list = []
checkpoints = (train_params.checkpoint_at if train_params.checkpoint_at else []) + [
train_params.steps
]
-
model, output_filenames = model_fit(
model,
train_params,
- checkpoints.copy(),
+ checkpoints,
optimizer,
dataset,
callbacks,
@@ -452,22 +450,35 @@ def train_in_dir(
output_name,
model_is_quantized,
output_filenames,
+ input_shape,
+ output_shape,
+ loss_fn,
+ post_process=True,
)
+ # Placeholder for now, will be parametrized later (MLIA-1114)
+ # rewrite.check_optimization( # type: ignore[attr-defined]
+ # model, number_of_clusters=32
+ # )
if model_is_quantized and is_qat:
- model = rewrite.pruning_preserved_quantization( # type: ignore[attr-defined]
- model,
- )
+ model = rewrite.preserved_quantize(model) # type: ignore[attr-defined]
+ checkpoints = (
+ train_params.checkpoint_at if train_params.checkpoint_at else []
+ ) + [train_params.steps]
+ output_filenames = []
+
+ if len(rewrite.training_callbacks()) > 0 and set( # type: ignore[attr-defined]
+ rewrite.training_callbacks() # type: ignore[attr-defined]
+ ).issubset(callbacks):
+ callbacks.pop(-1)
+
optimizer = keras.optimizers.Nadam(learning_rate=train_params.learning_rate)
model = model_compile(model, optimizer, loss_fn)
- callbacks.pop(-1)
- output_filenames = []
-
model, output_filenames = model_fit(
model,
train_params,
- checkpoints.copy(),
+ checkpoints,
optimizer,
dataset,
callbacks,
@@ -478,22 +489,50 @@ def train_in_dir(
output_name,
model_is_quantized,
output_filenames,
+ input_shape,
+ output_shape,
+ loss_fn,
)
+ # Placeholder for now, will be parametrized later (MLIA-1114)
+ # rewrite.check_optimization( # type: ignore[attr-defined]
+ # model, number_of_clusters=32
+ # )
teacher.close()
return output_filenames
def model_compile(
- model: tf.keras.Model, optimizer: tf.keras.optimizers, loss_fn: tf.keras.losses
-) -> tf.keras.Model:
+ model: keras.Model,
+ optimizer: keras.optimizers.Nadam,
+ loss_fn: keras.losses.Loss,
+) -> keras.Model:
"""Compiles a tflite model."""
model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"])
return model
-def model_fit(
- model: tf.keras.Model,
+def create_model( # pylint: disable=too-many-arguments
+ rewrite: Callable,
+ input_shape: int,
+ output_shape: int,
+ optimizer: Callable,
+ loss_fn: Callable,
+ model_is_quantized: bool,
+ model_to_load_from: keras.model | None = None,
+) -> keras.Model:
+ """Create a model, optionally from another."""
+ model = rewrite(input_shape, output_shape)
+ if model_is_quantized:
+ model = rewrite.quantize(model) # type: ignore[attr-defined]
+ model = model_compile(model, optimizer=optimizer, loss_fn=loss_fn)
+ if model_to_load_from:
+ model.set_weights(model_to_load_from.get_weights())
+ return model
+
+
+def model_fit( # pylint: disable=too-many-arguments
+ model: keras.Model,
train_params: TrainingParameters,
checkpoints: list,
optimizer: tf.optimizers.Nadam,
@@ -506,8 +545,12 @@ def model_fit(
output_name: str,
model_is_quantized: bool,
output_filenames: list,
-) -> tuple[tf.keras.Model, list]:
- """Train the model."""
+ input_shape: int,
+ output_shape: int,
+ loss_fn: Callable,
+ post_process: bool = False,
+) -> keras.Model:
+ """Train a tflite model."""
steps_so_far = 0
while steps_so_far < train_params.steps:
steps_to_train = checkpoints.pop(0) - steps_so_far
@@ -534,18 +577,34 @@ def model_fit(
checkpoint_filename = (
filename_dir + "/" + filename + (f"_@{steps_so_far}") + ext
)
+ # If post processing we are stripping the clustering/pruning layers below
+ # Thus copy the model before saving, so training can continue
+ if post_process:
+ model_to_save = create_model(
+ rewrite,
+ input_shape,
+ output_shape,
+ optimizer,
+ loss_fn,
+ model_is_quantized,
+ model_to_load_from=model,
+ )
+ else:
+ model_to_save = model
else:
checkpoint_filename = str(output_filename)
-
- if steps_so_far == train_params.steps:
- model = rewrite.post_process(model) # type: ignore[attr-defined]
+ model_to_save = model
with log_action(
f"{steps_so_far}/{train_params.steps}: Saved as {checkpoint_filename}"
):
+ if post_process:
+ model_to_save = rewrite.post_process( # type: ignore[attr-defined]
+ model_to_save
+ )
save_as_tflite(
- model,
- str(checkpoint_filename),
+ model_to_save,
+ checkpoint_filename,
input_name,
replace.shape_from_name[input_name],
output_name,
@@ -553,7 +612,8 @@ def model_fit(
model_is_quantized,
)
output_filenames.append(checkpoint_filename)
- return model, output_filenames
+
+ return model_to_save, output_filenames
def save_as_tflite(
diff --git a/src/mlia/nn/rewrite/library/fc_clustering_layer.py b/src/mlia/nn/rewrite/library/fc_clustering_layer.py
index 72931c0..7cc383e 100644
--- a/src/mlia/nn/rewrite/library/fc_clustering_layer.py
+++ b/src/mlia/nn/rewrite/library/fc_clustering_layer.py
@@ -9,7 +9,7 @@ from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
def get_keras_model_clus(input_shape: Any, output_shape: Any) -> keras.Model:
"""Generate TensorFlow Lite model for clustering rewrite."""
- clustering_params = {
+ rewrite_params = {
"number_of_clusters": 32,
"cluster_centroids_init": tfmot.clustering.keras.CentroidInitialization.LINEAR,
}
@@ -21,6 +21,6 @@ def get_keras_model_clus(input_shape: Any, output_shape: Any) -> keras.Model:
keras.layers.Dense(units=output_shape),
]
),
- **clustering_params
+ **rewrite_params
)
return model
diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py
index ef4df6a..e502842 100644
--- a/tests/test_nn_rewrite_core_rewrite.py
+++ b/tests/test_nn_rewrite_core_rewrite.py
@@ -10,13 +10,14 @@ from typing import cast
from unittest.mock import MagicMock
import pytest
+import tensorflow_model_optimization as tfmot
from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
from tensorflow_model_optimization.python.core.clustering.keras.cluster_wrapper import ( # pylint: disable=no-name-in-module
ClusterWeights,
)
from mlia.nn.rewrite.core.rewrite import ClusteringRewrite
-from mlia.nn.rewrite.core.rewrite import FullyConnectedRewrite
+from mlia.nn.rewrite.core.rewrite import GenericRewrite
from mlia.nn.rewrite.core.rewrite import Rewrite
from mlia.nn.rewrite.core.rewrite import RewriteCallable
from mlia.nn.rewrite.core.rewrite import RewriteConfiguration
@@ -25,17 +26,48 @@ from mlia.nn.rewrite.core.rewrite import RewritingOptimizer
from mlia.nn.rewrite.core.rewrite import Sparsity24Rewrite
from mlia.nn.rewrite.core.rewrite import TrainingParameters
from mlia.nn.rewrite.core.train import train_in_dir
+from mlia.nn.rewrite.library.fc_clustering_layer import (
+ get_keras_model_clus as fc_clustering_rewrite,
+)
from mlia.nn.tensorflow.config import TFLiteModel
from tests.utils.rewrite import MockTrainingParameters
+class TestRewrite(Rewrite):
+ """Test rewrite class."""
+
+ def quantize(self, model: keras.Model) -> keras.Model:
+ """Return a quantized model if required."""
+ return tfmot.quantization.keras.quantize_model(model)
+
+ def preserved_quantize(self, model: keras.Model) -> keras.Model:
+ """Not needed."""
+ return model
+
+ def training_callbacks(self) -> list:
+ """Return default rewrite callbacks."""
+ return []
+
+ def post_process(self, model: keras.Model) -> keras.Model:
+ """Return default post-processing rewrite options."""
+ return model
+
+ def check_optimization(self, model: keras.Model, **kwargs: dict) -> bool:
+ """Not needed here."""
+ return True
+
+
+def mock_rewrite_function(*_: Any) -> Any:
+ """Mock function to test autoloading of rewrite functions."""
+
+
def test_rewrite() -> None:
"""Test a derived Rewrite class."""
def bad_rewrite_func() -> Any:
raise NotImplementedError()
- rewrite = Sparsity24Rewrite(
+ rewrite = TestRewrite(
"BAD_REWRITE", rewrite_fn=cast(RewriteCallable, bad_rewrite_func)
)
with pytest.raises(RuntimeError):
@@ -45,7 +77,7 @@ def test_rewrite() -> None:
@pytest.mark.parametrize(
"rewrite_name, callbacks_length, instance",
[
- ("fully-connected", 0, Rewrite),
+ ("fully-connected", 0, GenericRewrite),
("fully-connected-clustering", 0, ClusteringRewrite),
("fully-connected-sparsity24", 1, Sparsity24Rewrite),
],
@@ -72,8 +104,8 @@ def test_rewrite_selection(
def test_rewrite_configuration(
test_tflite_model_fp32: Path, rewrite_name: str, expected_error: Any
) -> None:
- """Test get_rewrite function only supports rewrite types
- fully-connected, fully-connected-clustering and fully-connected-sparsity24."""
+ """Test get_rewrite function only supports rewrite type fully-connected,
+ fully-connected-clustering and fully-connected-sparsity24."""
with expected_error:
config_obj = RewriteConfiguration(
rewrite_name,
@@ -88,28 +120,61 @@ def test_rewrite_configuration(
assert isinstance(rewriter_obj, RewritingOptimizer)
+def test_rewrite_fully_connected_clustering() -> None:
+ """Check that model has the set number of clusters"""
+
+ rewrite = ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite)
+ model = rewrite(input_shape=(28, 28), output_shape=10)
+ model = rewrite.post_process(model)
+ assert rewrite.check_optimization(model, number_of_clusters=32)
+
+
+def test_rewrite_fully_connected_clustering_error_handling() -> None:
+ """Check that model has the set number of clusters
+ and that when quantized the number of clusters
+ remain."""
+
+ rewrite = ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite)
+ model = rewrite(input_shape=(28, 28), output_shape=10)
+ with pytest.raises(
+ ValueError,
+ match=(
+ r"Expected check_preserved_quantize to have argument number_of_clusters"
+ ),
+ ):
+ rewrite.check_optimization(model, bad_arg_name=25)
+
+
@pytest.mark.parametrize(
- "rewrite_type, expected_layers",
+ "rewrite_type, expected_layers, quant",
[
- ["fully-connected", [keras.layers.Reshape, keras.layers.Dense]],
- ["fully-connected-clustering", [ClusterWeights, ClusterWeights]],
+ ["fully-connected", [keras.layers.Reshape, keras.layers.Dense], False],
+ ["fully-connected-clustering", [ClusterWeights, ClusterWeights], False],
+ ["fully-connected-clustering", [ClusterWeights, ClusterWeights], True],
],
)
-def test_rewriting_optimizer(
+def test_rewriting_optimizer( # pylint: disable=too-many-locals
test_tflite_model_fp32: Path,
test_tfrecord_fp32: Path,
+ test_tflite_model: Path,
+ test_tfrecord: Path,
rewrite_type: str,
expected_layers: list[object],
+ quant: bool,
) -> None:
"""Test fc_layer rewrite process with rewrite type fully-connected."""
+
+ tfrecord = test_tfrecord if quant else test_tfrecord_fp32
+ tflite_model = test_tflite_model if quant else test_tflite_model_fp32
+
config_obj = RewriteConfiguration(
rewrite_type,
["sequential/flatten/Reshape", "StatefulPartitionedCall:0"],
- test_tfrecord_fp32,
+ tfrecord,
train_params=MockTrainingParameters(),
)
- test_obj = RewritingOptimizer(test_tflite_model_fp32, config_obj)
+ test_obj = RewritingOptimizer(tflite_model, config_obj)
rewrite_function = RewritingOptimizer.registry.items[
test_obj.optimizer_configuration.optimization_target
]
@@ -132,8 +197,8 @@ def test_register_rewrite_function() -> None:
"""Test adding rewrite functions and verify they are reported via the registry."""
registry = RewriteRegistry()
- rewrite1 = FullyConnectedRewrite("r1", cast(RewriteCallable, lambda: 1))
- rewrite2 = Sparsity24Rewrite("r2", cast(RewriteCallable, lambda: 2))
+ rewrite1 = TestRewrite("r1", cast(RewriteCallable, lambda: 1))
+ rewrite2 = TestRewrite("r2", cast(RewriteCallable, lambda: 2))
registry.register_rewrite(rewrite1)
registry.register_rewrite(rewrite2)
diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py
index 371c79f..94c99ff 100644
--- a/tests/test_nn_rewrite_core_train.py
+++ b/tests/test_nn_rewrite_core_train.py
@@ -14,15 +14,13 @@ import pytest
import tensorflow as tf
from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
-from mlia.nn.rewrite.core.rewrite import FullyConnectedRewrite
-from mlia.nn.rewrite.core.rewrite import QATRewrite
from mlia.nn.rewrite.core.train import augment_fn_twins
from mlia.nn.rewrite.core.train import AUGMENTATION_PRESETS
from mlia.nn.rewrite.core.train import LearningRateSchedule
from mlia.nn.rewrite.core.train import mixup
from mlia.nn.rewrite.core.train import train
from mlia.nn.rewrite.core.train import TrainingParameters
-from mlia.nn.rewrite.library.fc_layer import get_keras_model as fc_rewrite
+from tests.test_nn_rewrite_core_rewrite import TestRewrite
from tests.utils.rewrite import MockTrainingParameters
@@ -56,20 +54,16 @@ def check_train(
"""Test the train() function."""
with TemporaryDirectory() as tmp_dir:
output_file = Path(tmp_dir, "out.tflite")
- mock_rewrite = FullyConnectedRewrite(
- name="replace",
- rewrite_fn=fc_rewrite,
- )
- is_qat = isinstance(mock_rewrite, QATRewrite)
+ mock_rewrite = TestRewrite("replace", replace_fully_connected_with_conv)
result = train(
source_model=str(tflite_model),
unmodified_model=str(tflite_model) if use_unmodified_model else None,
output_model=str(output_file),
input_tfrec=str(tfrecord),
rewrite=mock_rewrite,
+ is_qat=False,
input_tensors=["sequential/flatten/Reshape"],
output_tensors=["StatefulPartitionedCall:0"],
- is_qat=is_qat,
train_params=train_params,
)