diff options
Diffstat (limited to 'tests/test_nn_rewrite_core_rewrite.py')
-rw-r--r-- | tests/test_nn_rewrite_core_rewrite.py | 6 |
1 files changed, 5 insertions, 1 deletions
diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py index b98971e..2542db2 100644 --- a/tests/test_nn_rewrite_core_rewrite.py +++ b/tests/test_nn_rewrite_core_rewrite.py @@ -12,6 +12,7 @@ import pytest from mlia.nn.rewrite.core.rewrite import RewriteConfiguration from mlia.nn.rewrite.core.rewrite import Rewriter from mlia.nn.tensorflow.config import TFLiteModel +from tests.utils.rewrite import TestTrainingParameters @pytest.mark.parametrize( @@ -32,12 +33,14 @@ def test_rewrite_configuration( None, ) + assert config_obj.optimization_target in str(config_obj) + rewriter_obj = Rewriter(test_tflite_model_fp32, config_obj) assert rewriter_obj.optimizer_configuration.optimization_target == rewrite_name assert isinstance(rewriter_obj, Rewriter) -def test_rewriter( +def test_rewriting_optimizer( test_tflite_model_fp32: Path, test_tfrecord_fp32: Path, ) -> None: @@ -46,6 +49,7 @@ def test_rewriter( "fully_connected", ["sequential/flatten/Reshape", "StatefulPartitionedCall:0"], test_tfrecord_fp32, + train_params=TestTrainingParameters(), ) test_obj = Rewriter(test_tflite_model_fp32, config_obj) |