aboutsummaryrefslogtreecommitdiff
path: root/verif/tests/test_tosa_datagenerator.py
blob: ba0235c1f1afd937bba6f93c0915aed8802e403a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
"""Tests for the python interface to the data generator library."""
# Copyright (c) 2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
from pathlib import Path

import numpy as np
import pytest
from generator.datagenerator import GenerateError
from generator.datagenerator import GenerateLibrary

# NOTE: These tests are marked as POST COMMIT
# To run them, please build the reference_model in a local "build" directory
# (as per the README) and run them using: pytest -m "postcommit"

# Location of reference model binaries: "build/reference_model" under the
# directory three levels above this (resolved) file
_BUILD_PARENT = Path(__file__).resolve().parents[2]
GENERATE_LIB = "libtosa_reference_generate_lib.so"
REF_MODEL_BUILD_PATH = _BUILD_PARENT.joinpath("build", "reference_model")
GENERATE_LIB_PATH = REF_MODEL_BUILD_PATH.joinpath(GENERATE_LIB)

# Directory containing this test module; generated data files are written here
TEST_DIR = Path(__file__).parent


@pytest.mark.postcommit
def test_generate_lib_built():
    """Check up front that the generate library file exists at the build path."""
    library_present = GENERATE_LIB_PATH.is_file()
    assert library_present


@pytest.mark.postcommit
def test_checker_generate_load_fail():
    """Creating a GenerateLibrary from a missing path raises GenerateError."""
    missing_library = Path("/place-that-does-not-exist")

    with pytest.raises(GenerateError) as excinfo:
        GenerateLibrary(missing_library)

    message = str(excinfo.value)
    assert message.startswith("Could not find generate library")


@pytest.mark.postcommit
def test_checker_generate_load():
    """Creating a GenerateLibrary from the built library path gives a truthy object."""
    library = GenerateLibrary(GENERATE_LIB_PATH)
    assert library


def _matmul_dot_product_tensor(shape, input_pos):
    """Return a fresh data_gen entry for one FP32 MATMUL DOT_PRODUCT input."""
    return {
        "generator": "DOT_PRODUCT",
        "data_type": "FP32",
        "input_type": "VARIABLE",
        "shape": shape,
        "input_pos": input_pos,
        "op": "MATMUL",
        "dot_product_info": {"s": 0, "ks": 4, "acc_type": "FP32"},
    }


# Test description with data generator meta-data for the two inputs of a
# MATMUL; used as the configuration for the dot product tests below
JSON_DATAGEN_DOT_PRODUCT = {
    "tosa_file": "test.json",
    "ifm_name": ["input-0", "input-1"],
    "ifm_file": ["input-0.npy", "input-1.npy"],
    "ofm_name": ["result-0"],
    "ofm_file": ["result-0.npy"],
    "meta": {
        "data_gen": {
            "version": "0.1",
            "tensors": {
                "input-0": _matmul_dot_product_tensor([3, 5, 4], 0),
                "input-1": _matmul_dot_product_tensor([3, 4, 6], 1),
            },
        }
    },
}


@pytest.mark.postcommit
def test_generate_dot_product_check():
    """Generate the dot product inputs and check the written numpy files.

    Each input file named in the configuration must exist after generation,
    load as a float32 array and have the shape given for its tensor.
    """
    glib = GenerateLibrary(GENERATE_LIB_PATH)
    assert glib

    json_config = JSON_DATAGEN_DOT_PRODUCT
    glib.set_config(json_config)

    tensors = json_config["meta"]["data_gen"]["tensors"]
    data_files = [TEST_DIR / f for f in json_config["ifm_file"]]

    try:
        glib.write_numpy_files(TEST_DIR)

        # Test the files exist and are the expected numpy files
        for file, name in zip(data_files, json_config["ifm_name"]):
            assert file.is_file()
            arr = np.load(file)
            assert arr.shape == tuple(tensors[name]["shape"])
            assert arr.dtype == np.float32
    finally:
        # Always remove the generated files, even when a check above fails.
        # Left-over files would break tests that assert these files are
        # absent from TEST_DIR.
        for file in data_files:
            if file.is_file():
                file.unlink()


@pytest.mark.postcommit
def test_generate_dot_product_check_fail_names():
    """Generation fails with one error line per input when names do not match.

    The input names are replaced with ones that have no tensor entry in the
    data_gen meta-data, so writing the files must raise GenerateError and
    leave no input files behind.
    """
    glib = GenerateLibrary(GENERATE_LIB_PATH)
    assert glib

    # Same configuration, but with input names that are not in the meta-data
    bad_names = ["not-input0", "not-input1"]
    json_config = {**JSON_DATAGEN_DOT_PRODUCT, "ifm_name": bad_names}
    glib.set_config(json_config)

    with pytest.raises(GenerateError) as excinfo:
        glib.write_numpy_files(TEST_DIR)

    # Errors are expected one per line, in the same order as the input names
    error_lines = str(excinfo.value).split("\n")
    for line_no, name in enumerate(bad_names):
        expected_prefix = f"ERROR: Failed to create data for tensor {name}"
        assert error_lines[line_no].startswith(expected_prefix)

    for filename in json_config["ifm_file"]:
        assert not (TEST_DIR / filename).is_file()