aboutsummaryrefslogtreecommitdiff
path: root/verif/tests/test_tosa_verifier.py
blob: a29f983357b319fb85cfc86ae5cdd8e41c0896bf (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""Tests for the python interface to the verifier library."""
# Copyright (c) 2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
from pathlib import Path

import numpy as np
import pytest
from checker.verifier import VerifierError
from checker.verifier import VerifierLibrary

# NOTE: These tests are marked as POST COMMIT
# To run them, please build the reference_model in a local "build" directory
# (as per the README) and run them using: pytest -m "postcommit"

# The verifier shared library is looked up in "build/reference_model" under
# the directory two levels above the directory that contains this file
VERIFIER_LIB = "libtosa_reference_verify_lib.so"
REF_MODEL_BUILD_PATH = Path(__file__).resolve().parents[2].joinpath(
    "build", "reference_model"
)
VERIFIER_LIB_PATH = REF_MODEL_BUILD_PATH.joinpath(VERIFIER_LIB)


@pytest.mark.postcommit
def test_verifier_lib_built():
    """Check that the verifier library file exists at VERIFIER_LIB_PATH."""
    lib_is_present = VERIFIER_LIB_PATH.is_file()
    assert lib_is_present


@pytest.mark.postcommit
def test_checker_verifier_load_fail():
    """Expect VerifierError when the library path does not exist."""
    missing_path = Path("/place-that-does-not-exist")
    with pytest.raises(VerifierError) as raised:
        VerifierLibrary(missing_path)
    # The error text is only checked by its leading phrase
    message = str(raised.value)
    assert message.startswith("Could not find verify library")


@pytest.mark.postcommit
def test_checker_verifier_load():
    """Expect a truthy VerifierLibrary when built from VERIFIER_LIB_PATH."""
    assert VerifierLibrary(VERIFIER_LIB_PATH)


# Compliance configuration passed to verify_data by the dot product tests;
# it describes a single tensor entry named "output1"
JSON_COMPLIANCE_DOT_PRODUCT = {
    "version": "0.1",
    "tensors": {
        "output1": {
            "mode": "DOT_PRODUCT",
            "data_type": "FP32",
            "dot_product_info": {
                "ks": 1000,
                "s": 0,
            },
        },
    },
}


@pytest.mark.postcommit
def test_checker_verifier_dot_product_check():
    """Expect verify_data to return a truthy result when all arrays are zero."""
    verifier = VerifierLibrary(VERIFIER_LIB_PATH)
    assert verifier

    # All three arrays share one shape; only the first is single precision
    shape = (10, 10, 10)
    imp_data = np.zeros(shape, dtype=np.float32)
    ref_data = np.zeros(shape, dtype=np.float64)
    bnd_data = np.zeros(shape, dtype=np.float64)

    assert verifier.verify_data(
        "output1", JSON_COMPLIANCE_DOT_PRODUCT, imp_data, ref_data, bnd_data
    )


@pytest.mark.postcommit
def test_checker_verifier_dot_product_check_fail():
    """Expect verify_data to return a falsy result for zeros against ones."""
    verifier = VerifierLibrary(VERIFIER_LIB_PATH)
    assert verifier

    # Same shape throughout; the second array is all ones so it differs from
    # the all-zero first array
    shape = (10, 10, 10)
    imp_data = np.zeros(shape, dtype=np.float32)
    ref_data = np.ones(shape, dtype=np.float64)
    bnd_data = np.zeros(shape, dtype=np.float64)

    outcome = verifier.verify_data(
        "output1", JSON_COMPLIANCE_DOT_PRODUCT, imp_data, ref_data, bnd_data
    )
    assert not outcome