aboutsummaryrefslogtreecommitdiff
path: root/verif/tests/test_tosa_datagenerator.py
blob: 4f3d7fdad0f5921f2089fcbf78624a606440f4f3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
"""Tests for the python interface to the data generator library."""
# Copyright (c) 2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import copy
from pathlib import Path

import numpy as np
import pytest
from generator.datagenerator import GenerateError
from generator.datagenerator import GenerateLibrary

# NOTE: These tests are marked as POST COMMIT
# To run them, please build the reference_model in a local "build" directory
# (as per the README) and run them using: pytest -m "postcommit"

# Location of reference model binaries: <repo root>/build/reference_model,
# where the repo root is two directories above this test file's directory.
REF_MODEL_BUILD_PATH = Path(__file__).resolve().parents[2].joinpath(
    "build", "reference_model"
)
GENERATE_LIB = "libtosa_reference_generate_lib.so"
GENERATE_LIB_PATH = REF_MODEL_BUILD_PATH.joinpath(GENERATE_LIB)

# Directory holding this test file; generated .npy files are written here.
TEST_DIR = Path(__file__).parent


@pytest.mark.postcommit
def test_generate_lib_built():
    """Check that the generate shared library exists at GENERATE_LIB_PATH.

    Intended as the first test so a missing build is reported clearly.
    """
    library_present = GENERATE_LIB_PATH.is_file()
    assert library_present


@pytest.mark.postcommit
def test_checker_generate_load_fail():
    """Constructing GenerateLibrary from a missing path raises GenerateError."""
    bogus_location = Path("/place-that-does-not-exist")
    with pytest.raises(GenerateError) as raised:
        GenerateLibrary(bogus_location)
    error_text = str(raised.value)
    assert error_text.startswith("Could not find generate library")


@pytest.mark.postcommit
def test_checker_generate_load():
    """Loading the built generate library yields a truthy GenerateLibrary."""
    library = GenerateLibrary(GENERATE_LIB_PATH)
    assert library


# Test configuration for the DOT_PRODUCT data generator: two FP32 MATMUL
# inputs, "input-0" of shape [3, 5, 4] and "input-1" of shape [3, 4, 6]
# (dot_product_info ks=4). The tests read the top-level ifm_* lists and the
# per-tensor "shape" entries under meta/data_gen/tensors.
JSON_DATAGEN_DOT_PRODUCT = dict(
    tosa_file="test.json",
    ifm_name=["input-0", "input-1"],
    ifm_file=["input-0.npy", "input-1.npy"],
    ofm_name=["result-0"],
    ofm_file=["result-0.npy"],
    meta=dict(
        data_gen=dict(
            version="0.1",
            # Tensor names contain "-", so this level must use literal keys.
            tensors={
                "input-0": dict(
                    generator="DOT_PRODUCT",
                    data_type="FP32",
                    input_type="VARIABLE",
                    shape=[3, 5, 4],
                    input_pos=0,
                    op="MATMUL",
                    dot_product_info=dict(s=0, ks=4, acc_type="FP32"),
                ),
                "input-1": dict(
                    generator="DOT_PRODUCT",
                    data_type="FP32",
                    input_type="VARIABLE",
                    shape=[3, 4, 6],
                    input_pos=1,
                    op="MATMUL",
                    dot_product_info=dict(s=0, ks=4, acc_type="FP32"),
                ),
            },
        ),
    ),
)


@pytest.mark.postcommit
def test_generate_dot_product_check():
    """Write the DOT_PRODUCT input tensors to TEST_DIR and validate them.

    Each file listed in ``ifm_file`` must exist, load with numpy, have the
    shape configured for the matching ``ifm_name`` tensor and dtype float32.
    The generated files are always removed afterwards, even on failure, so
    stale .npy files cannot leak into the source tree or into later tests.
    """
    glib = GenerateLibrary(GENERATE_LIB_PATH)
    assert glib

    json_config = JSON_DATAGEN_DOT_PRODUCT
    glib.set_config(json_config)

    files = [TEST_DIR / f for f in json_config["ifm_file"]]
    try:
        glib.write_numpy_files(TEST_DIR)

        # Test the files exist and are the expected numpy files
        for file, n in zip(files, json_config["ifm_name"]):
            assert file.is_file()
            arr = np.load(file)
            assert arr.shape == tuple(
                json_config["meta"]["data_gen"]["tensors"][n]["shape"]
            )
            assert arr.dtype == np.float32
    finally:
        # Cleanup must not depend on every assertion passing: a leftover
        # file would make the "files not created" check in the
        # fail-names test fail spuriously on subsequent runs.
        for file in files:
            if file.is_file():
                file.unlink()


@pytest.mark.postcommit
def test_generate_dot_product_check_fail_names():
    """Mismatched ``ifm_name`` entries make write_numpy_files fail cleanly.

    The config's input names are replaced with names that have no entry in
    ``meta/data_gen/tensors``. Writing must raise GenerateError, with one
    "Failed to create data" line per bad name, and must not create any of the
    ``ifm_file`` outputs.
    """
    glib = GenerateLibrary(GENERATE_LIB_PATH)
    assert glib

    # Fix up the JSON to have the wrong names. A deep copy keeps the nested
    # dicts independent of the shared module-level config used by other tests
    # (a shallow .copy() would still share "meta").
    json_config = copy.deepcopy(JSON_DATAGEN_DOT_PRODUCT)
    json_config["ifm_name"] = ["not-input0", "not-input1"]
    glib.set_config(json_config)

    files = [TEST_DIR / f for f in json_config["ifm_file"]]
    # Remove any stale outputs (e.g. from an earlier interrupted run) so the
    # "not created" check below reflects only this call.
    for file in files:
        if file.is_file():
            file.unlink()

    try:
        with pytest.raises(GenerateError) as excinfo:
            glib.write_numpy_files(TEST_DIR)
        info = str(excinfo.value).split("\n")
        # Fail with an assertion, not an IndexError, if lines are missing.
        assert len(info) >= len(json_config["ifm_name"])
        for i, n in enumerate(json_config["ifm_name"]):
            assert info[i].startswith(f"ERROR: Failed to create data for tensor {n}")

        for file in files:
            assert not file.is_file()
    finally:
        # Never leave wrongly-created files behind in the source tree.
        for file in files:
            if file.is_file():
                file.unlink()


@pytest.mark.postcommit
def test_generate_tensor_data_check():
    """Fetch each DOT_PRODUCT input tensor in memory and check shape/dtype.

    Uses ``get_tensor_data`` with the ``data_gen`` section of the config
    instead of writing numpy files.
    """
    glib = GenerateLibrary(GENERATE_LIB_PATH)
    assert glib

    data_gen_config = JSON_DATAGEN_DOT_PRODUCT["meta"]["data_gen"]

    for tensor_name in JSON_DATAGEN_DOT_PRODUCT["ifm_name"]:
        tensor = glib.get_tensor_data(tensor_name, data_gen_config)

        wanted_shape = tuple(data_gen_config["tensors"][tensor_name]["shape"])
        assert tensor.shape == wanted_shape
        assert tensor.dtype == np.float32