aboutsummaryrefslogtreecommitdiff
path: root/tests/utils/rewrite.py
diff options
context:
space:
mode:
authorBenjamin Klimczak <benjamin.klimczak@arm.com>2023-03-20 18:07:54 +0000
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2023-10-11 15:42:55 +0100
commit62768232c5fe4ed6b87136c336b65e13d030e9d4 (patch)
tree847c36a2f7e092982bc1d7a66d0bf601447c8d20 /tests/utils/rewrite.py
parent446c379c92e15ad8f24ed0db853dd0fc9c271151 (diff)
downloadmlia-62768232c5fe4ed6b87136c336b65e13d030e9d4.tar.gz
MLIA-843 Add unit tests for module mlia.nn.rewrite
Note: The unit tests mostly call the main functions from the respective modules only. Change-Id: Ib2ce5c53d0c3eb222b8b8be42fba33ac8e007574 Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com>
Diffstat (limited to 'tests/utils/rewrite.py')
-rw-r--r--tests/utils/rewrite.py27
1 files changed, 27 insertions, 0 deletions
diff --git a/tests/utils/rewrite.py b/tests/utils/rewrite.py
new file mode 100644
index 0000000..4264b4b
--- /dev/null
+++ b/tests/utils/rewrite.py
@@ -0,0 +1,27 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Common test utils for the rewrite tests."""
+from __future__ import annotations
+
+from tensorflow.lite.python.schema_py_generated import ModelT
+
+
+def models_are_equal(model1: ModelT, model2: ModelT) -> bool:
+ """Check that the two models are equal."""
+ if len(model1.subgraphs) != len(model2.subgraphs):
+ return False
+
+ for graph1, graph2 in zip(model1.subgraphs, model2.subgraphs):
+ if graph1.name != graph2.name or len(graph1.tensors) != len(graph2.tensors):
+ return False
+ for tensor1 in graph1.tensors:
+ for tensor2 in graph2.tensors:
+ if tensor1.name == tensor2.name:
+ if (
+ tensor1.shape == tensor2.shape
+ ).all() or tensor1.type == tensor2.type:
+ break
+ else:
+ return False # Tensor from graph1 not found in other graph.")
+
+ return True