diff options
Diffstat (limited to 'tests/test_nn_rewrite_core_rewrite.py')
-rw-r--r-- | tests/test_nn_rewrite_core_rewrite.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py index d4aac56..487784d 100644 --- a/tests/test_nn_rewrite_core_rewrite.py +++ b/tests/test_nn_rewrite_core_rewrite.py @@ -17,7 +17,7 @@ from mlia.nn.rewrite.core.rewrite import RewriteConfiguration from mlia.nn.rewrite.core.rewrite import RewriteRegistry from mlia.nn.rewrite.core.rewrite import RewritingOptimizer from mlia.nn.tensorflow.config import TFLiteModel -from tests.utils.rewrite import TestTrainingParameters +from tests.utils.rewrite import MockTrainingParameters def mock_rewrite_function(*_: Any) -> Any: @@ -69,7 +69,7 @@ def test_rewriting_optimizer( "fully_connected", ["sequential/flatten/Reshape", "StatefulPartitionedCall:0"], test_tfrecord_fp32, - train_params=TestTrainingParameters(), + train_params=MockTrainingParameters(), ) test_obj = RewritingOptimizer(test_tflite_model_fp32, config_obj) |