aboutsummaryrefslogtreecommitdiff
path: root/tests/utils/rewrite.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/utils/rewrite.py')
-rw-r--r--tests/utils/rewrite.py18
1 files changed, 18 insertions, 0 deletions
diff --git a/tests/utils/rewrite.py b/tests/utils/rewrite.py
index 4264b4b..739bb11 100644
--- a/tests/utils/rewrite.py
+++ b/tests/utils/rewrite.py
@@ -3,8 +3,12 @@
"""Common test utils for the rewrite tests."""
from __future__ import annotations
+from typing import Any
+
from tensorflow.lite.python.schema_py_generated import ModelT
+from mlia.nn.rewrite.core.train import TrainingParameters
+
def models_are_equal(model1: ModelT, model2: ModelT) -> bool:
"""Check that the two models are equal."""
@@ -25,3 +29,17 @@ def models_are_equal(model1: ModelT, model2: ModelT) -> bool:
return False # Tensor from graph1 not found in other graph.")
return True
+
+
+class TestTrainingParameters(
+ TrainingParameters
+): # pylint: disable=too-few-public-methods
+ """
+ TrainingParameter class for rewrites with different default values.
+
+ To speed things up for the unit tests.
+ """
+
+ def __init__(self, *args: Any, steps: int = 32, **kwargs: Any) -> None:
+ """Initialize TrainingParameters with different defaults."""
+ super().__init__(*args, steps=steps, **kwargs) # type: ignore