aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorKaushik Varadharajan <kaushik.varadharajan@arm.com>2024-06-11 21:31:13 +0000
committerKaushik Varadharajan <kaushik.varadharajan@arm.com>2024-06-11 21:32:45 +0000
commitca6fa5fbc4af9f486d87d36c7baeff924edd0534 (patch)
tree9f29e3144929d762598ab57584e2d5d26143c1c9 /python
parent76a2911f03922a5d2a8236bdd1d9a051cf72eccc (diff)
downloadserialization_lib-main.tar.gz
Add pytest config and example testHEADmain
Tests can be run with the `pytest` command from the repository's root directory. Signed-off-by: Kaushik Varadharajan <kaushik.varadharajan@arm.com> Change-Id: Id1ead34da927d4455964cb211e4fc0c6294e4bdf
Diffstat (limited to 'python')
-rw-r--r--python/pytests/__init__.py0
-rw-r--r--python/pytests/conftest.py36
-rw-r--r--python/pytests/test_example.py95
3 files changed, 131 insertions, 0 deletions
diff --git a/python/pytests/__init__.py b/python/pytests/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/python/pytests/__init__.py
diff --git a/python/pytests/conftest.py b/python/pytests/conftest.py
new file mode 100644
index 0000000..b595a01
--- /dev/null
+++ b/python/pytests/conftest.py
@@ -0,0 +1,36 @@
+#!/usr/bin/env python3
+
+# Copyright (c) 2024, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pathlib
+import shutil
+
+
+def pytest_sessionstart():
+ """Initializes temporary directory."""
+
+ base_dir = (pathlib.Path(__file__).parent / "../..").resolve()
+ tmp_dir = base_dir / "python/pytests/tmp"
+
+ tmp_dir.mkdir(exist_ok=True)
+
+
+def pytest_sessionfinish():
+ """Cleaning up temporary files."""
+
+ base_dir = (pathlib.Path(__file__).parent / "../..").resolve()
+ tmp_dir = base_dir / "python/pytests/tmp"
+
+ shutil.rmtree(tmp_dir)
diff --git a/python/pytests/test_example.py b/python/pytests/test_example.py
new file mode 100644
index 0000000..e03997b
--- /dev/null
+++ b/python/pytests/test_example.py
@@ -0,0 +1,95 @@
+#!/usr/bin/env python3
+
+# Copyright (c) 2024, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import pathlib
+import subprocess
+import serializer.tosa_serializer as ts
+
+
+def test_example(request):
+ """Testing that pytest and the Python serialization library work"""
+
+ # Creating an example TOSA region
+ ser = ts.TosaSerializer("ser")
+ ser.currRegion.currBasicBlock.addTensor("t1", [3, 4, 5], ts.DType.FP16)
+ ser.currRegion.currBasicBlock.addTensor("t2", [2, 2], ts.DType.INT32)
+ ser.currRegion.currBasicBlock.addInput("t1")
+ ser.currRegion.currBasicBlock.addOutput("t2")
+
+ attr = ts.TosaSerializerAttribute()
+ attr.ConvAttribute([1, 1], [2, 2], [3, 3], 4, 5, True, ts.DType.FP32)
+ ser.currRegion.currBasicBlock.addOperator(
+ ts.TosaOp.Op().CONV2D, ["t1"], ["t2"], attr
+ )
+
+ # Defining filepaths
+ testname = request.node.name
+
+ base_dir = (pathlib.Path(__file__).parent / "../..").resolve()
+ tmp_dir = base_dir / "python/pytests/tmp"
+ tosa_file = tmp_dir / f"{testname}.tosa"
+ schema_file = base_dir / "schema/tosa.fbs"
+ flatc = base_dir / "third_party/flatbuffers/flatc"
+
+ # Serializing to flatbuffer and writing to a temporary file
+ with open(tosa_file, "wb") as f:
+ f.write(ser.serialize())
+
+ # Using flatc to convert the flatbuffer to strict json
+ _ = subprocess.run(
+ [flatc, "--json", "--strict-json", "-o", tmp_dir, schema_file, "--", tosa_file],
+ check=True,
+ )
+
+ # Opening json file generated by previous command
+ json_file = tmp_dir / f"{testname}.json"
+ with open(json_file, encoding="utf-8") as f:
+ serialized = json.load(f)
+
+ assert serialized["regions"] == [
+ {
+ "name": "main",
+ "blocks": [
+ {
+ "name": "main",
+ "inputs": ["t1"],
+ "outputs": ["t2"],
+ "operators": [
+ {
+ "op": "CONV2D",
+ "attribute_type": "ConvAttribute",
+ "attribute": {
+ "pad": [1, 1],
+ "stride": [2, 2],
+ "dilation": [3, 3],
+ "input_zp": 4,
+ "weight_zp": 5,
+ "local_bound": True,
+ "acc_type": "FP32",
+ },
+ "inputs": ["t1"],
+ "outputs": ["t2"],
+ }
+ ],
+ "tensors": [
+ {"name": "t1", "shape": [3, 4, 5], "type": "FP16"},
+ {"name": "t2", "shape": [2, 2], "type": "INT32"},
+ ],
+ }
+ ],
+ }
+ ]