diff options
author | Kaushik Varadharajan <kaushik.varadharajan@arm.com> | 2024-06-11 21:31:13 +0000 |
---|---|---|
committer | Kaushik Varadharajan <kaushik.varadharajan@arm.com> | 2024-06-11 21:32:45 +0000 |
commit | ca6fa5fbc4af9f486d87d36c7baeff924edd0534 (patch) | |
tree | 9f29e3144929d762598ab57584e2d5d26143c1c9 /python/pytests | |
parent | 76a2911f03922a5d2a8236bdd1d9a051cf72eccc (diff) | |
download | serialization_lib-ca6fa5fbc4af9f486d87d36c7baeff924edd0534.tar.gz |
Add pytest config and example test
Tests can be run with the `pytest` command from the repository's
root directory.
Signed-off-by: Kaushik Varadharajan <kaushik.varadharajan@arm.com>
Change-Id: Id1ead34da927d4455964cb211e4fc0c6294e4bdf
Diffstat (limited to 'python/pytests')
-rw-r--r-- | python/pytests/__init__.py | 0 | ||||
-rw-r--r-- | python/pytests/conftest.py | 36 | ||||
-rw-r--r-- | python/pytests/test_example.py | 95 |
3 files changed, 131 insertions, 0 deletions
diff --git a/python/pytests/__init__.py b/python/pytests/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/python/pytests/__init__.py diff --git a/python/pytests/conftest.py b/python/pytests/conftest.py new file mode 100644 index 0000000..b595a01 --- /dev/null +++ b/python/pytests/conftest.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2024, ARM Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pathlib +import shutil + + +def pytest_sessionstart(): + """Initializes temporary directory.""" + + base_dir = (pathlib.Path(__file__).parent / "../..").resolve() + tmp_dir = base_dir / "python/pytests/tmp" + + tmp_dir.mkdir(exist_ok=True) + + +def pytest_sessionfinish(): + """Cleaning up temporary files.""" + + base_dir = (pathlib.Path(__file__).parent / "../..").resolve() + tmp_dir = base_dir / "python/pytests/tmp" + + shutil.rmtree(tmp_dir) diff --git a/python/pytests/test_example.py b/python/pytests/test_example.py new file mode 100644 index 0000000..e03997b --- /dev/null +++ b/python/pytests/test_example.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2024, ARM Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import pathlib +import subprocess +import serializer.tosa_serializer as ts + + +def test_example(request): + """Testing that pytest and the Python serialization library work""" + + # Creating an example TOSA region + ser = ts.TosaSerializer("ser") + ser.currRegion.currBasicBlock.addTensor("t1", [3, 4, 5], ts.DType.FP16) + ser.currRegion.currBasicBlock.addTensor("t2", [2, 2], ts.DType.INT32) + ser.currRegion.currBasicBlock.addInput("t1") + ser.currRegion.currBasicBlock.addOutput("t2") + + attr = ts.TosaSerializerAttribute() + attr.ConvAttribute([1, 1], [2, 2], [3, 3], 4, 5, True, ts.DType.FP32) + ser.currRegion.currBasicBlock.addOperator( + ts.TosaOp.Op().CONV2D, ["t1"], ["t2"], attr + ) + + # Defining filepaths + testname = request.node.name + + base_dir = (pathlib.Path(__file__).parent / "../..").resolve() + tmp_dir = base_dir / "python/pytests/tmp" + tosa_file = tmp_dir / f"{testname}.tosa" + schema_file = base_dir / "schema/tosa.fbs" + flatc = base_dir / "third_party/flatbuffers/flatc" + + # Serializing to flatbuffer and writing to a temporary file + with open(tosa_file, "wb") as f: + f.write(ser.serialize()) + + # Using flatc to convert the flatbuffer to strict json + _ = subprocess.run( + [flatc, "--json", "--strict-json", "-o", tmp_dir, schema_file, "--", tosa_file], + check=True, + ) + + # Opening json file generated by previous command + json_file = tmp_dir / f"{testname}.json" + with open(json_file, encoding="utf-8") as f: + serialized = json.load(f) + + assert serialized["regions"] == [ + { + "name": "main", + "blocks": [ + { + "name": "main", + "inputs": ["t1"], + "outputs": ["t2"], + "operators": [ + { + "op": "CONV2D", + "attribute_type": "ConvAttribute", + "attribute": { + "pad": [1, 1], + "stride": [2, 2], + "dilation": [3, 3], + "input_zp": 4, + "weight_zp": 5, + "local_bound": True, + "acc_type": "FP32", + }, + "inputs": ["t1"], + "outputs": ["t2"], + } + ], + "tensors": [ + {"name": "t1", "shape": [3, 4, 5], "type": "FP16"}, + {"name": "t2", "shape": [2, 2], "type": "INT32"}, + ], + } + ], + } + ] |