aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorBenjamin Klimczak <benjamin.klimczak@arm.com>2023-07-19 16:35:57 +0100
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2023-10-11 16:06:17 +0100
commit3cd84481fa25e64c29e57396d4bf32d7a3ca490a (patch)
treead81fb520a965bd3a3c7c983833b7cd48f9b8dea /tests
parentf3e6597dd50ec70f043d692b773f2d9fd31519ae (diff)
downloadmlia-3cd84481fa25e64c29e57396d4bf32d7a3ca490a.tar.gz
Bug-fixes and re-factoring for the rewrite module
- Fix input shape of rewrite replacement: During and after training of the replacement model for a rewrite the Keras model is converted and saved in TensorFlow Lite format. If the input shape does not match the teacher model exactly, e.g. if the batch size is undefined, the TFLiteConverter adds extra operators during conversion. - Fix rewritten model output - Save the model output with the rewritten operator in the output dir - Log MAE and NRMSE of the rewrite - Remove 'verbose' flag from rewrite module and rely on the logging mechanism to control verbose output. - Re-factor utility classes for rewrites - Merge the two TFLiteModel classes - Move functionality to load/save TensorFlow Lite flatbuffers to nn/tensorflow/tflite_graph - Fix issue with unknown shape in datasets After upgrading to TensorFlow 2.12 the unknown shape of the TFRecordDataset is causing problems when training the replacement models for rewrites. By explicitly setting the right shape of the tensors we can work around the issue. - Adapt default parameters for rewrites. The training steps especially had to be increased significantly to be effective. Resolves: MLIA-895, MLIA-907, MLIA-946, MLIA-979 Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com> Change-Id: I887ad165aed0f2c6e5a0041f64cec5e6c5ab5c5c
Diffstat (limited to 'tests')
-rw-r--r--tests/conftest.py39
-rw-r--r--tests/test_nn_rewrite_core_graph_edit_join.py6
-rw-r--r--tests/test_nn_rewrite_core_graph_edit_record.py4
-rw-r--r--tests/test_nn_rewrite_core_rewrite.py6
-rw-r--r--tests/test_nn_rewrite_core_train.py76
-rw-r--r--tests/test_nn_rewrite_core_utils.py33
-rw-r--r--tests/test_nn_rewrite_core_utils_numpy_tfrecord.py25
-rw-r--r--tests/test_nn_tensorflow_config.py40
-rw-r--r--tests/test_nn_tensorflow_tflite_graph.py30
-rw-r--r--tests/utils/rewrite.py18
10 files changed, 189 insertions, 88 deletions
diff --git a/tests/conftest.py b/tests/conftest.py
index c42b8cb..bb2423f 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -5,6 +5,7 @@ import shutil
from pathlib import Path
from typing import Callable
from typing import Generator
+from unittest.mock import MagicMock
import numpy as np
import pytest
@@ -17,6 +18,7 @@ from mlia.nn.tensorflow.utils import convert_to_tflite
from mlia.nn.tensorflow.utils import save_keras_model
from mlia.nn.tensorflow.utils import save_tflite_model
from mlia.target.ethos_u.config import EthosUConfiguration
+from tests.utils.rewrite import TestTrainingParameters
@pytest.fixture(scope="session", name="test_resources_path")
@@ -168,16 +170,12 @@ def _write_tfrecord(
writer.write({input_name: data_generator()})
-@pytest.fixture(scope="session", name="test_tfrecord")
-def fixture_test_tfrecord(
- tmp_path_factory: pytest.TempPathFactory,
+def create_tfrecord(
+ tmp_path_factory: pytest.TempPathFactory, random_data: Callable
) -> Generator[Path, None, None]:
"""Create a tfrecord with random data matching fixture 'test_tflite_model'."""
tmp_path = tmp_path_factory.mktemp("tfrecords")
- tfrecord_file = tmp_path / "test_int8.tfrecord"
-
- def random_data() -> np.ndarray:
- return np.random.randint(low=-127, high=128, size=(1, 28, 28, 1), dtype=np.int8)
+ tfrecord_file = tmp_path / "test.tfrecord"
_write_tfrecord(tfrecord_file, random_data)
@@ -186,19 +184,36 @@ def fixture_test_tfrecord(
shutil.rmtree(tmp_path)
+@pytest.fixture(scope="session", name="test_tfrecord")
+def fixture_test_tfrecord(
+ tmp_path_factory: pytest.TempPathFactory,
+) -> Generator[Path, None, None]:
+ """Create a tfrecord with random data matching fixture 'test_tflite_model'."""
+
+ def random_data() -> np.ndarray:
+ return np.random.randint(low=-127, high=128, size=(1, 28, 28, 1), dtype=np.int8)
+
+ yield from create_tfrecord(tmp_path_factory, random_data)
+
+
@pytest.fixture(scope="session", name="test_tfrecord_fp32")
def fixture_test_tfrecord_fp32(
tmp_path_factory: pytest.TempPathFactory,
) -> Generator[Path, None, None]:
"""Create tfrecord with random data matching fixture 'test_tflite_model_fp32'."""
- tmp_path = tmp_path_factory.mktemp("tfrecords")
- tfrecord_file = tmp_path / "test_fp32.tfrecord"
def random_data() -> np.ndarray:
return np.random.rand(1, 28, 28, 1).astype(np.float32)
- _write_tfrecord(tfrecord_file, random_data)
+ yield from create_tfrecord(tmp_path_factory, random_data)
- yield tfrecord_file
- shutil.rmtree(tmp_path)
+@pytest.fixture(scope="session", autouse=True)
+def set_training_steps() -> Generator[None, None, None]:
+ """Speed up tests by using TestTrainingParameters."""
+ with pytest.MonkeyPatch.context() as monkeypatch:
+ monkeypatch.setattr(
+ "mlia.nn.select._get_rewrite_train_params",
+ MagicMock(return_value=TestTrainingParameters()),
+ )
+ yield
diff --git a/tests/test_nn_rewrite_core_graph_edit_join.py b/tests/test_nn_rewrite_core_graph_edit_join.py
index cb3e4e2..0cb121e 100644
--- a/tests/test_nn_rewrite_core_graph_edit_join.py
+++ b/tests/test_nn_rewrite_core_graph_edit_join.py
@@ -10,7 +10,7 @@ import pytest
from mlia.nn.rewrite.core.graph_edit.cut import cut_model
from mlia.nn.rewrite.core.graph_edit.join import append_relabel
from mlia.nn.rewrite.core.graph_edit.join import join_models
-from mlia.nn.rewrite.core.utils.utils import load
+from mlia.nn.tensorflow.tflite_graph import load_fb
from tests.utils.rewrite import models_are_equal
@@ -49,8 +49,8 @@ def test_join_model(test_tflite_model: Path, tmp_path: Path) -> None:
)
assert joined_file.is_file()
- orig_model = load(str(test_tflite_model))
- joined_model = load(str(joined_file))
+ orig_model = load_fb(str(test_tflite_model))
+ joined_model = load_fb(str(joined_file))
assert models_are_equal(orig_model, joined_model)
diff --git a/tests/test_nn_rewrite_core_graph_edit_record.py b/tests/test_nn_rewrite_core_graph_edit_record.py
index cd728af..41b9c50 100644
--- a/tests/test_nn_rewrite_core_graph_edit_record.py
+++ b/tests/test_nn_rewrite_core_graph_edit_record.py
@@ -3,15 +3,13 @@
"""Tests for module mlia.nn.rewrite.graph_edit.record."""
from pathlib import Path
-import pytest
import tensorflow as tf
from mlia.nn.rewrite.core.graph_edit.record import record_model
from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
-@pytest.mark.parametrize("batch_size", (None, 1, 2))
-def test_record_model(
+def check_record_model(
test_tflite_model: Path,
tmp_path: Path,
test_tfrecord: Path,
diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py
index b98971e..2542db2 100644
--- a/tests/test_nn_rewrite_core_rewrite.py
+++ b/tests/test_nn_rewrite_core_rewrite.py
@@ -12,6 +12,7 @@ import pytest
from mlia.nn.rewrite.core.rewrite import RewriteConfiguration
from mlia.nn.rewrite.core.rewrite import Rewriter
from mlia.nn.tensorflow.config import TFLiteModel
+from tests.utils.rewrite import TestTrainingParameters
@pytest.mark.parametrize(
@@ -32,12 +33,14 @@ def test_rewrite_configuration(
None,
)
+ assert config_obj.optimization_target in str(config_obj)
+
rewriter_obj = Rewriter(test_tflite_model_fp32, config_obj)
assert rewriter_obj.optimizer_configuration.optimization_target == rewrite_name
assert isinstance(rewriter_obj, Rewriter)
-def test_rewriter(
+def test_rewriting_optimizer(
test_tflite_model_fp32: Path,
test_tfrecord_fp32: Path,
) -> None:
@@ -46,6 +49,7 @@ def test_rewriter(
"fully_connected",
["sequential/flatten/Reshape", "StatefulPartitionedCall:0"],
test_tfrecord_fp32,
+ train_params=TestTrainingParameters(),
)
test_obj = Rewriter(test_tflite_model_fp32, config_obj)
diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py
index 3c2ef3e..4493671 100644
--- a/tests/test_nn_rewrite_core_train.py
+++ b/tests/test_nn_rewrite_core_train.py
@@ -4,6 +4,7 @@
# pylint: disable=too-many-arguments
from __future__ import annotations
+from contextlib import ExitStack as does_not_raise
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any
@@ -12,10 +13,13 @@ import numpy as np
import pytest
import tensorflow as tf
-from mlia.nn.rewrite.core.train import augmentation_presets
+from mlia.nn.rewrite.core.train import augment_fn_twins
+from mlia.nn.rewrite.core.train import AUGMENTATION_PRESETS
from mlia.nn.rewrite.core.train import LearningRateSchedule
from mlia.nn.rewrite.core.train import mixup
from mlia.nn.rewrite.core.train import train
+from mlia.nn.rewrite.core.train import TrainingParameters
+from tests.utils.rewrite import TestTrainingParameters
def replace_fully_connected_with_conv(
@@ -41,15 +45,8 @@ def replace_fully_connected_with_conv(
def check_train(
tflite_model: Path,
tfrecord: Path,
- batch_size: int = 1,
- verbose: bool = False,
- show_progress: bool = False,
- augmentation_preset: tuple[float | None, float | None] = augmentation_presets[
- "none"
- ],
- lr_schedule: LearningRateSchedule = "cosine",
+ train_params: TrainingParameters = TestTrainingParameters(),
use_unmodified_model: bool = False,
- num_procs: int = 1,
) -> None:
"""Test the train() function."""
with TemporaryDirectory() as tmp_dir:
@@ -62,14 +59,7 @@ def check_train(
replace_fn=replace_fully_connected_with_conv,
input_tensors=["sequential/flatten/Reshape"],
output_tensors=["StatefulPartitionedCall:0"],
- augment=augmentation_preset,
- steps=32,
- learning_rate=1e-3,
- batch_size=batch_size,
- verbose=verbose,
- show_progress=show_progress,
- learning_rate_schedule=lr_schedule,
- num_procs=num_procs,
+ train_params=train_params,
)
assert len(result) == 2
assert all(res >= 0.0 for res in result), f"Results out of bound: {result}"
@@ -79,7 +69,6 @@ def check_train(
@pytest.mark.parametrize(
(
"batch_size",
- "verbose",
"show_progress",
"augmentation_preset",
"lr_schedule",
@@ -87,14 +76,13 @@ def check_train(
"num_procs",
),
(
- (1, False, False, augmentation_presets["none"], "cosine", False, 2),
- (32, True, True, augmentation_presets["gaussian"], "late", True, 1),
- (2, False, False, augmentation_presets["mixup"], "constant", True, 0),
+ (1, False, AUGMENTATION_PRESETS["none"], "cosine", False, 2),
+ (32, True, AUGMENTATION_PRESETS["gaussian"], "late", True, 1),
+ (2, False, AUGMENTATION_PRESETS["mixup"], "constant", True, 0),
(
1,
False,
- False,
- augmentation_presets["mix_gaussian_large"],
+ AUGMENTATION_PRESETS["mix_gaussian_large"],
"cosine",
False,
2,
@@ -105,7 +93,6 @@ def test_train(
test_tflite_model_fp32: Path,
test_tfrecord_fp32: Path,
batch_size: int,
- verbose: bool,
show_progress: bool,
augmentation_preset: tuple[float | None, float | None],
lr_schedule: LearningRateSchedule,
@@ -116,13 +103,14 @@ def test_train(
check_train(
tflite_model=test_tflite_model_fp32,
tfrecord=test_tfrecord_fp32,
- batch_size=batch_size,
- verbose=verbose,
- show_progress=show_progress,
- augmentation_preset=augmentation_preset,
- lr_schedule=lr_schedule,
+ train_params=TestTrainingParameters(
+ batch_size=batch_size,
+ show_progress=show_progress,
+ augmentations=augmentation_preset,
+ learning_rate_schedule=lr_schedule,
+ num_procs=num_procs,
+ ),
use_unmodified_model=use_unmodified_model,
- num_procs=num_procs,
)
@@ -131,11 +119,13 @@ def test_train_invalid_schedule(
test_tfrecord_fp32: Path,
) -> None:
"""Test the train() function with an invalid schedule."""
- with pytest.raises(ValueError):
+ with pytest.raises(AssertionError):
check_train(
tflite_model=test_tflite_model_fp32,
tfrecord=test_tfrecord_fp32,
- lr_schedule="unknown_schedule", # type: ignore
+ train_params=TestTrainingParameters(
+ learning_rate_schedule="unknown_schedule",
+ ),
)
@@ -144,11 +134,13 @@ def test_train_invalid_augmentation(
test_tfrecord_fp32: Path,
) -> None:
"""Test the train() function with an invalid augmentation."""
- with pytest.raises(ValueError):
+ with pytest.raises(AssertionError):
check_train(
tflite_model=test_tflite_model_fp32,
tfrecord=test_tfrecord_fp32,
- augmentation_preset=(1.0, 2.0, 3.0), # type: ignore
+ train_params=TestTrainingParameters(
+ augmentations=(1.0, 2.0, 3.0),
+ ),
)
@@ -159,3 +151,19 @@ def test_mixup() -> None:
assert src.shape == dst.shape
assert np.all(dst >= 0.0)
assert np.all(dst <= 3.0)
+
+
+@pytest.mark.parametrize(
+ "augmentations, expected_error",
+ [
+ (AUGMENTATION_PRESETS["none"], does_not_raise()),
+ (AUGMENTATION_PRESETS["mix_gaussian_large"], does_not_raise()),
+ ((None,) * 3, pytest.raises(AssertionError)),
+ ],
+)
+def test_augment_fn_twins(augmentations: tuple, expected_error: Any) -> None:
+ """Test function augment_fn()."""
+ dataset = tf.data.Dataset.from_tensor_slices({"a": [1, 2, 3], "b": [4, 5, 6]})
+ with expected_error:
+ fn_twins = augment_fn_twins(dataset, augmentations) # type: ignore
+ assert len(fn_twins) == 2
diff --git a/tests/test_nn_rewrite_core_utils.py b/tests/test_nn_rewrite_core_utils.py
deleted file mode 100644
index d806a7b..0000000
--- a/tests/test_nn_rewrite_core_utils.py
+++ /dev/null
@@ -1,33 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""Tests for module mlia.nn.rewrite.utils."""
-from pathlib import Path
-
-import pytest
-import tensorflow as tf
-from tensorflow.lite.python.schema_py_generated import ModelT
-
-from mlia.nn.rewrite.core.utils.utils import load
-from mlia.nn.rewrite.core.utils.utils import save
-from tests.utils.rewrite import models_are_equal
-
-
-def test_load_save(test_tflite_model: Path, tmp_path: Path) -> None:
- """Test the load/save functions for TensorFlow Lite models."""
- with pytest.raises(FileNotFoundError):
- load("THIS_IS_NOT_A_REAL_FILE")
-
- model = load(test_tflite_model)
- assert isinstance(model, ModelT)
- assert model.subgraphs
-
- output_file = tmp_path / "test.tflite"
- assert not output_file.is_file()
- save(model, output_file)
- assert output_file.is_file()
-
- model_copy = load(str(output_file))
- assert models_are_equal(model, model_copy)
-
- # Double check that the TensorFlow Lite Interpreter can still load the file.
- tf.lite.Interpreter(model_path=str(output_file))
diff --git a/tests/test_nn_rewrite_core_utils_numpy_tfrecord.py b/tests/test_nn_rewrite_core_utils_numpy_tfrecord.py
index 7fc8048..d030350 100644
--- a/tests/test_nn_rewrite_core_utils_numpy_tfrecord.py
+++ b/tests/test_nn_rewrite_core_utils_numpy_tfrecord.py
@@ -5,6 +5,10 @@ from __future__ import annotations
from pathlib import Path
+import pytest
+import tensorflow as tf
+
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import make_decode_fn
from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count
from mlia.nn.rewrite.core.utils.numpy_tfrecord import sample_tfrec
@@ -16,3 +20,24 @@ def test_sample_tfrec(test_tfrecord: Path, tmp_path: Path) -> None:
sample_tfrec(input_file=str(test_tfrecord), k=1, output_file=str(output_file))
assert output_file.is_file()
assert numpytf_count(str(output_file)) == 1
+
+
+def test_make_decode_fn(test_tfrecord: Path) -> None:
+ """Test function make_decode_fn()."""
+ decode = make_decode_fn(str(test_tfrecord))
+ dataset = tf.data.TFRecordDataset(str(test_tfrecord))
+ features = decode(next(iter(dataset)))
+ assert isinstance(features, dict)
+ assert len(features) == 1
+ key, val = next(iter(features.items()))
+ assert isinstance(key, str)
+ assert isinstance(val, tf.Tensor)
+ assert val.dtype == tf.int8
+
+ with pytest.raises(FileNotFoundError):
+ make_decode_fn(str(test_tfrecord) + "_")
+
+
+def test_numpytf_count(test_tfrecord: Path) -> None:
+ """Test function numpytf_count()."""
+ assert numpytf_count(test_tfrecord) == 3
diff --git a/tests/test_nn_tensorflow_config.py b/tests/test_nn_tensorflow_config.py
index 656619d..48aec0a 100644
--- a/tests/test_nn_tensorflow_config.py
+++ b/tests/test_nn_tensorflow_config.py
@@ -4,13 +4,28 @@
from contextlib import ExitStack as does_not_raise
from pathlib import Path
from typing import Any
+from typing import Generator
+import numpy as np
import pytest
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
from mlia.nn.tensorflow.config import get_model
from mlia.nn.tensorflow.config import KerasModel
+from mlia.nn.tensorflow.config import ModelConfiguration
from mlia.nn.tensorflow.config import TFLiteModel
from mlia.nn.tensorflow.config import TfModel
+from tests.conftest import create_tfrecord
+
+
+def test_model_configuration(test_keras_model: Path) -> None:
+ """Test ModelConfiguration class."""
+ model = ModelConfiguration(model_path=test_keras_model)
+ assert test_keras_model.match(model.model_path)
+ with pytest.raises(NotImplementedError):
+ model.convert_to_keras("keras_model.h5")
+ with pytest.raises(NotImplementedError):
+ model.convert_to_tflite("model.tflite")
def test_convert_keras_to_tflite(tmp_path: Path, test_keras_model: Path) -> None:
@@ -38,7 +53,7 @@ def test_convert_tf_to_tflite(tmp_path: Path, test_tf_model: Path) -> None:
@pytest.mark.parametrize(
"model_path, expected_type, expected_error",
[
- ("test.tflite", TFLiteModel, does_not_raise()),
+ ("test.tflite", TFLiteModel, pytest.raises(ValueError)),
("test.h5", KerasModel, does_not_raise()),
("test.hdf5", KerasModel, does_not_raise()),
(
@@ -73,3 +88,26 @@ def test_get_model_dir(
"""Test TensorFlow Lite model type."""
model = get_model(str(test_models_path / model_path))
assert isinstance(model, expected_type)
+
+
+@pytest.fixture(scope="session", name="test_tfrecord_fp32_batch_3")
+def fixture_test_tfrecord_fp32_batch_3(
+ tmp_path_factory: pytest.TempPathFactory,
+) -> Generator[Path, None, None]:
+ """Create tfrecord (same as test_tfrecord_fp32) but with batch size 3."""
+
+ def random_data() -> np.ndarray:
+ return np.random.rand(3, 28, 28, 1).astype(np.float32)
+
+ yield from create_tfrecord(tmp_path_factory, random_data)
+
+
+def test_tflite_model_call(
+ test_tflite_model_fp32: Path, test_tfrecord_fp32_batch_3: Path
+) -> None:
+ """Test inference function of class TFLiteModel."""
+ model = TFLiteModel(test_tflite_model_fp32, batch_size=2)
+ data = numpytf_read(test_tfrecord_fp32_batch_3)
+ for named_input in data.as_numpy_iterator():
+ res = model(named_input)
+ assert res
diff --git a/tests/test_nn_tensorflow_tflite_graph.py b/tests/test_nn_tensorflow_tflite_graph.py
index cd1fad6..3512cdd 100644
--- a/tests/test_nn_tensorflow_tflite_graph.py
+++ b/tests/test_nn_tensorflow_tflite_graph.py
@@ -1,15 +1,22 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the tflite_graph module."""
import json
from pathlib import Path
+import pytest
+import tensorflow as tf
+from tensorflow.lite.python.schema_py_generated import ModelT
+
+from mlia.nn.tensorflow.tflite_graph import load_fb
from mlia.nn.tensorflow.tflite_graph import Op
from mlia.nn.tensorflow.tflite_graph import parse_subgraphs
+from mlia.nn.tensorflow.tflite_graph import save_fb
from mlia.nn.tensorflow.tflite_graph import TensorInfo
from mlia.nn.tensorflow.tflite_graph import TFL_ACTIVATION_FUNCTION
from mlia.nn.tensorflow.tflite_graph import TFL_OP
from mlia.nn.tensorflow.tflite_graph import TFL_TYPE
+from tests.utils.rewrite import models_are_equal
def test_tensor_info() -> None:
@@ -79,3 +86,24 @@ def test_parse_subgraphs(test_tflite_model: Path) -> None:
assert TFL_OP[oper.type] in TFL_OP
assert len(oper.inputs) > 0
assert len(oper.outputs) > 0
+
+
+def test_load_save(test_tflite_model: Path, tmp_path: Path) -> None:
+ """Test the load/save functions for TensorFlow Lite models."""
+ with pytest.raises(FileNotFoundError):
+ load_fb("THIS_IS_NOT_A_REAL_FILE")
+
+ model = load_fb(test_tflite_model)
+ assert isinstance(model, ModelT)
+ assert model.subgraphs
+
+ output_file = tmp_path / "test.tflite"
+ assert not output_file.is_file()
+ save_fb(model, output_file)
+ assert output_file.is_file()
+
+ model_copy = load_fb(str(output_file))
+ assert models_are_equal(model, model_copy)
+
+ # Double check that the TensorFlow Lite Interpreter can still load the file.
+ tf.lite.Interpreter(model_path=str(output_file))
diff --git a/tests/utils/rewrite.py b/tests/utils/rewrite.py
index 4264b4b..739bb11 100644
--- a/tests/utils/rewrite.py
+++ b/tests/utils/rewrite.py
@@ -3,8 +3,12 @@
"""Common test utils for the rewrite tests."""
from __future__ import annotations
+from typing import Any
+
from tensorflow.lite.python.schema_py_generated import ModelT
+from mlia.nn.rewrite.core.train import TrainingParameters
+
def models_are_equal(model1: ModelT, model2: ModelT) -> bool:
"""Check that the two models are equal."""
@@ -25,3 +29,17 @@ def models_are_equal(model1: ModelT, model2: ModelT) -> bool:
return False # Tensor from graph1 not found in other graph.")
return True
+
+
+class TestTrainingParameters(
+ TrainingParameters
+): # pylint: disable=too-few-public-methods
+ """
+ TrainingParameter class for rewrites with different default values.
+
+ To speed things up for the unit tests.
+ """
+
+ def __init__(self, *args: Any, steps: int = 32, **kwargs: Any) -> None:
+ """Initialize TrainingParameters with different defaults."""
+ super().__init__(*args, steps=steps, **kwargs) # type: ignore