diff options
Diffstat (limited to 'tests/utils/rewrite.py')
-rw-r--r-- | tests/utils/rewrite.py | 18 |
1 files changed, 18 insertions, 0 deletions
diff --git a/tests/utils/rewrite.py b/tests/utils/rewrite.py index 4264b4b..739bb11 100644 --- a/tests/utils/rewrite.py +++ b/tests/utils/rewrite.py @@ -3,8 +3,12 @@ """Common test utils for the rewrite tests.""" from __future__ import annotations +from typing import Any + from tensorflow.lite.python.schema_py_generated import ModelT +from mlia.nn.rewrite.core.train import TrainingParameters + def models_are_equal(model1: ModelT, model2: ModelT) -> bool: """Check that the two models are equal.""" @@ -25,3 +29,17 @@ def models_are_equal(model1: ModelT, model2: ModelT) -> bool: return False # Tensor from graph1 not found in other graph.") return True + + +class TestTrainingParameters( + TrainingParameters +): # pylint: disable=too-few-public-methods + """ + TrainingParameter class for rewrites with different default values. + + To speed things up for the unit tests. + """ + + def __init__(self, *args: Any, steps: int = 32, **kwargs: Any) -> None: + """Initialize TrainingParameters with different defaults.""" + super().__init__(*args, steps=steps, **kwargs) # type: ignore |