aboutsummaryrefslogtreecommitdiff
path: root/tests/test_nn_rewrite_core_rewrite.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_nn_rewrite_core_rewrite.py')
-rw-r--r--tests/test_nn_rewrite_core_rewrite.py55
1 files changed, 55 insertions, 0 deletions
diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py
new file mode 100644
index 0000000..b98971e
--- /dev/null
+++ b/tests/test_nn_rewrite_core_rewrite.py
@@ -0,0 +1,55 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module mlia.nn.rewrite.core.rewrite."""
+from __future__ import annotations
+
+from contextlib import ExitStack as does_not_raise
+from pathlib import Path
+from typing import Any
+
+import pytest
+
+from mlia.nn.rewrite.core.rewrite import RewriteConfiguration
+from mlia.nn.rewrite.core.rewrite import Rewriter
+from mlia.nn.tensorflow.config import TFLiteModel
+
+
+@pytest.mark.parametrize(
+ "rewrite_name, expected_error",
+ [
+ ("fully_connected", does_not_raise()),
+ ("random", does_not_raise()),
+ ],
+)
+def test_rewrite_configuration(
+ test_tflite_model_fp32: Path, rewrite_name: str, expected_error: Any
+) -> None:
+ """Test get_rewrite function only supports rewrite type fully_connected."""
+ with expected_error:
+ config_obj = RewriteConfiguration(
+ rewrite_name,
+ ["sample_node_start", "sample_node_end"],
+ None,
+ )
+
+ rewriter_obj = Rewriter(test_tflite_model_fp32, config_obj)
+ assert rewriter_obj.optimizer_configuration.optimization_target == rewrite_name
+ assert isinstance(rewriter_obj, Rewriter)
+
+
+def test_rewriter(
+ test_tflite_model_fp32: Path,
+ test_tfrecord_fp32: Path,
+) -> None:
+ """Test fc_layer rewrite process with rewrite type fully_connected."""
+ config_obj = RewriteConfiguration(
+ "fully_connected",
+ ["sequential/flatten/Reshape", "StatefulPartitionedCall:0"],
+ test_tfrecord_fp32,
+ )
+
+ test_obj = Rewriter(test_tflite_model_fp32, config_obj)
+ test_obj.apply_optimization()
+ trained_model = test_obj.get_model()
+
+ assert isinstance(trained_model, TFLiteModel)