aboutsummaryrefslogtreecommitdiff
path: root/tests/test_nn_rewrite_core_graph_edit_cut.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_nn_rewrite_core_graph_edit_cut.py')
-rw-r--r--tests/test_nn_rewrite_core_graph_edit_cut.py29
1 files changed, 29 insertions, 0 deletions
diff --git a/tests/test_nn_rewrite_core_graph_edit_cut.py b/tests/test_nn_rewrite_core_graph_edit_cut.py
new file mode 100644
index 0000000..914fdfd
--- /dev/null
+++ b/tests/test_nn_rewrite_core_graph_edit_cut.py
@@ -0,0 +1,29 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module mlia.nn.rewrite.graph_edit.cut."""
+from pathlib import Path
+
+import numpy as np
+import tensorflow as tf
+
+from mlia.nn.rewrite.core.graph_edit.cut import cut_model
+
+
+def test_cut_model(test_tflite_model: Path, tmp_path: Path) -> None:
+ """Test the function cut_model()."""
+ output_file = tmp_path / "out.tflite"
+ cut_model(
+ model_file=test_tflite_model,
+ input_names=["serving_default_input:0"],
+ output_names=["sequential/flatten/Reshape"],
+ subgraph_index=0,
+ output_file=output_file,
+ )
+ assert output_file.is_file()
+
+ interpreter = tf.lite.Interpreter(model_path=str(output_file))
+ output_details = interpreter.get_output_details()
+ assert len(output_details) == 1
+ out = output_details[0]
+ assert "Reshape" in out["name"]
+ assert np.prod(out["shape"]) == 1728