aboutsummaryrefslogtreecommitdiff
path: root/python/pytests/test_example.py
blob: e03997b98ed754e41b921873eb87b3ead67d8fdb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#!/usr/bin/env python3

# Copyright (c) 2024, ARM Limited.
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.

import json
import pathlib
import subprocess
import serializer.tosa_serializer as ts


def test_example(request):
    """Testing that pytest and the Python serialization library work.

    Builds a small TOSA graph (a single CONV2D operator reading FP16 tensor
    ``t1`` and writing INT32 tensor ``t2``), serializes it to a flatbuffer
    file, converts that file to strict JSON with the repository's ``flatc``
    binary and the ``schema/tosa.fbs`` schema, and checks the decoded
    ``regions`` against the expected structure.

    Args:
        request: the pytest ``request`` fixture; ``request.node.name`` is used
            as the base name of the generated ``.tosa`` and ``.json`` files.

    Raises:
        subprocess.CalledProcessError: if ``flatc`` exits with a non-zero
            status (``check=True``).
    """

    # Creating an example TOSA region
    ser = ts.TosaSerializer("ser")
    ser.currRegion.currBasicBlock.addTensor("t1", [3, 4, 5], ts.DType.FP16)
    ser.currRegion.currBasicBlock.addTensor("t2", [2, 2], ts.DType.INT32)
    ser.currRegion.currBasicBlock.addInput("t1")
    ser.currRegion.currBasicBlock.addOutput("t2")

    # Positional ConvAttribute arguments, matching the expected JSON below:
    # pad, stride, dilation, input_zp, weight_zp, local_bound, acc_type.
    attr = ts.TosaSerializerAttribute()
    attr.ConvAttribute([1, 1], [2, 2], [3, 3], 4, 5, True, ts.DType.FP32)
    ser.currRegion.currBasicBlock.addOperator(
        ts.TosaOp.Op().CONV2D, ["t1"], ["t2"], attr
    )

    # Defining filepaths
    testname = request.node.name

    base_dir = (pathlib.Path(__file__).parent / "../..").resolve()
    tmp_dir = base_dir / "python/pytests/tmp"
    tosa_file = tmp_dir / f"{testname}.tosa"
    schema_file = base_dir / "schema/tosa.fbs"
    flatc = base_dir / "third_party/flatbuffers/flatc"

    # The scratch directory is not guaranteed to exist (e.g. on a fresh
    # checkout); create it so the write below cannot fail with
    # FileNotFoundError before the serializer is exercised.
    tmp_dir.mkdir(parents=True, exist_ok=True)

    # Serializing to flatbuffer and writing to a temporary file
    with open(tosa_file, "wb") as f:
        f.write(ser.serialize())

    # Using flatc to convert the flatbuffer to strict json; "--" separates the
    # schema from the binary file(s) to decode, and -o selects the directory
    # the resulting <basename>.json is written to.
    subprocess.run(
        [flatc, "--json", "--strict-json", "-o", tmp_dir, schema_file, "--", tosa_file],
        check=True,
    )

    # Opening json file generated by previous command
    json_file = tmp_dir / f"{testname}.json"
    with open(json_file, encoding="utf-8") as f:
        serialized = json.load(f)

    assert serialized["regions"] == [
        {
            "name": "main",
            "blocks": [
                {
                    "name": "main",
                    "inputs": ["t1"],
                    "outputs": ["t2"],
                    "operators": [
                        {
                            "op": "CONV2D",
                            "attribute_type": "ConvAttribute",
                            "attribute": {
                                "pad": [1, 1],
                                "stride": [2, 2],
                                "dilation": [3, 3],
                                "input_zp": 4,
                                "weight_zp": 5,
                                "local_bound": True,
                                "acc_type": "FP32",
                            },
                            "inputs": ["t1"],
                            "outputs": ["t2"],
                        }
                    ],
                    "tensors": [
                        {"name": "t1", "shape": [3, 4, 5], "type": "FP16"},
                        {"name": "t2", "shape": [2, 2], "type": "INT32"},
                    ],
                }
            ],
        }
    ]