aboutsummaryrefslogtreecommitdiff
path: root/tests/test_nn_rewrite_core_rewrite.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_nn_rewrite_core_rewrite.py')
-rw-r--r--tests/test_nn_rewrite_core_rewrite.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py
index d4aac56..487784d 100644
--- a/tests/test_nn_rewrite_core_rewrite.py
+++ b/tests/test_nn_rewrite_core_rewrite.py
@@ -17,7 +17,7 @@ from mlia.nn.rewrite.core.rewrite import RewriteConfiguration
from mlia.nn.rewrite.core.rewrite import RewriteRegistry
from mlia.nn.rewrite.core.rewrite import RewritingOptimizer
from mlia.nn.tensorflow.config import TFLiteModel
-from tests.utils.rewrite import TestTrainingParameters
+from tests.utils.rewrite import MockTrainingParameters
def mock_rewrite_function(*_: Any) -> Any:
@@ -69,7 +69,7 @@ def test_rewriting_optimizer(
"fully_connected",
["sequential/flatten/Reshape", "StatefulPartitionedCall:0"],
test_tfrecord_fp32,
- train_params=TestTrainingParameters(),
+ train_params=MockTrainingParameters(),
)
test_obj = RewritingOptimizer(test_tflite_model_fp32, config_obj)