diff options
Diffstat (limited to 'tests/test_nn_rewrite_core_rewrite.py')
-rw-r--r-- | tests/test_nn_rewrite_core_rewrite.py | 55 |
1 files changed, 55 insertions, 0 deletions
diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py new file mode 100644 index 0000000..b98971e --- /dev/null +++ b/tests/test_nn_rewrite_core_rewrite.py @@ -0,0 +1,55 @@ +# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for module mlia.nn.rewrite.core.rewrite.""" +from __future__ import annotations + +from contextlib import ExitStack as does_not_raise +from pathlib import Path +from typing import Any + +import pytest + +from mlia.nn.rewrite.core.rewrite import RewriteConfiguration +from mlia.nn.rewrite.core.rewrite import Rewriter +from mlia.nn.tensorflow.config import TFLiteModel + + +@pytest.mark.parametrize( + "rewrite_name, expected_error", + [ + ("fully_connected", does_not_raise()), + ("random", does_not_raise()), + ], +) +def test_rewrite_configuration( + test_tflite_model_fp32: Path, rewrite_name: str, expected_error: Any +) -> None: + """Test get_rewrite function only supports rewrite type fully_connected.""" + with expected_error: + config_obj = RewriteConfiguration( + rewrite_name, + ["sample_node_start", "sample_node_end"], + None, + ) + + rewriter_obj = Rewriter(test_tflite_model_fp32, config_obj) + assert rewriter_obj.optimizer_configuration.optimization_target == rewrite_name + assert isinstance(rewriter_obj, Rewriter) + + +def test_rewriter( + test_tflite_model_fp32: Path, + test_tfrecord_fp32: Path, +) -> None: + """Test fc_layer rewrite process with rewrite type fully_connected.""" + config_obj = RewriteConfiguration( + "fully_connected", + ["sequential/flatten/Reshape", "StatefulPartitionedCall:0"], + test_tfrecord_fp32, + ) + + test_obj = Rewriter(test_tflite_model_fp32, config_obj) + test_obj.apply_optimization() + trained_model = test_obj.get_model() + + assert isinstance(trained_model, TFLiteModel) |