aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/verify.cc
blob: 18abf0be4bba0a87e16131abb644a67cfe778ddb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
// Copyright (c) 2023, ARM Limited.
//
//    Licensed under the Apache License, Version 2.0 (the "License");
//    you may not use this file except in compliance with the License.
//    You may obtain a copy of the License at
//
//         http://www.apache.org/licenses/LICENSE-2.0
//
//    Unless required by applicable law or agreed to in writing, software
//    distributed under the License is distributed on an "AS IS" BASIS,
//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//    See the License for the specific language governing permissions and
//    limitations under the License.
//===----------------------------------------------------------------------===//
//
// Verification functionality as per TOSA Specification
// Output Verification : Section 1.8.2
//
//===----------------------------------------------------------------------===//

#include "verify.h"

#include <half.hpp>

#include <cmath>
#include <numeric>
#include <optional>
#include <type_traits>

// Early-return helper for the bool-returning validators below: if COND evaluates to
// false, the enclosing function immediately returns false.
// The do { } while (0) wrapper makes `REQUIRE(x);` a single statement, so it composes
// safely with unbraced if/else (a bare `if { }` expansion turns a following `else`
// into a syntax error).
#define REQUIRE(COND)                                                                                                  \
    do                                                                                                                 \
    {                                                                                                                  \
        if (!(COND))                                                                                                   \
        {                                                                                                              \
            return false;                                                                                              \
        }                                                                                                              \
    } while (0)

namespace
{
// Accumulator precision
template <typename T>
struct AccPrecision;
template <>
struct AccPrecision<float>
{
    static constexpr double precision = (double)(1 << 24);
};
template <>
struct AccPrecision<half_float::half>
{
    static constexpr double precision = (double)(1 << 11);
};

// Generic element validation function
template <typename AccType, typename std::enable_if_t<std::is_floating_point_v<AccType>, int> = 0>
std::optional<double> validate_element(double ref, double bnd, AccType imp, size_t KS)
{
    double err    = 0.0;
    bool is_valid = true;

    if (bnd == 0.0)
    {
        is_valid = (ref == 0.0) && (imp == 0.0);
        err      = 0.0;
    }
    else
    {    // bnd > 0.0
        const double imp_fp64      = static_cast<double>(imp);
        const double acc_prec_fp64 = AccPrecision<AccType>::precision;
        err                        = (imp_fp64 - ref) * acc_prec_fp64 / bnd;
        is_valid                   = std::abs(err) <= KS;
    }

    return is_valid ? std::optional(err) : std::nullopt;
}

// Generic data validation function
template <typename AccType, typename std::enable_if_t<std::is_floating_point_v<AccType>, int> = 0>
bool validate_data(const double* ref, const double* bnd, const AccType* imp, size_t T, size_t KS, int32_t S)
{
    double out_err_sum   = 0.0;
    double out_err_sumsq = 0.0;

    for (size_t i = 0; i < T; ++i)
    {
        auto out_err = validate_element<AccType>(ref[i], bnd[i], imp[i], KS);
        REQUIRE(out_err);
        out_err_sum += out_err.value();
        out_err_sumsq += out_err.value() * out_err.value();
    }

    return tosa_validate_output_error(out_err_sum, out_err_sumsq, T, KS, S);
}

// Convert std::optional to CheckResult
CheckResult from_optional(const std::optional<double>& res)
{
    if (res)
        return { true, *res };
    else
        return { false, 0.0 };
}
}    // namespace

extern "C"
{

    // C entry point: validate one element of an fp32-accumulated result.
    // Returns { true, normalised error } when within tolerance, { false, 0.0 } otherwise.
    CheckResult tosa_validate_element_accfp32(double ref, double bnd, float imp, size_t KS)
    {
        auto err = validate_element<float>(ref, bnd, imp, KS);
        return from_optional(err);
    }

    // Check the aggregate error statistics of T validated elements.
    //
    // err_sum    : sum of the per-element normalised errors
    // err_sum_sq : sum of their squares
    // S          : data set identifier; the bias check is skipped for S == 1 and S == 2
    //
    // Requires |err_sum| <= 2 * sqrt(KS * T) (unless S is 1 or 2) and
    // err_sum_sq <= 0.4 * KS * T.
    bool tosa_validate_output_error(double err_sum, double err_sum_sq, size_t T, size_t KS, int S)
    {
        if (S != 1 && S != 2)
        {
            // Check error bias magnitude for data sets S which are not positive biased.
            // KS * T is formed in double: the size_t product can wrap around (e.g. with
            // a 32-bit size_t), which would make this limit far too small.
            const double ks_times_t = static_cast<double>(KS) * static_cast<double>(T);
            REQUIRE(std::abs(err_sum) <= 2.0 * std::sqrt(ks_times_t));
        }
        // Check error variance magnitude (same left-to-right double evaluation as before)
        REQUIRE(err_sum_sq <= 0.4 * static_cast<double>(KS) * static_cast<double>(T));

        return true;
    }

    // C entry point: validate T fp32-accumulated output elements (see validate_data).
    bool tosa_validate_data_fp32(const double* ref, const double* bnd, const float* imp, size_t T, size_t KS, int S)
    {
        return validate_data<float>(ref, bnd, imp, T, KS, S);
    }

}    // extern "C"
#undef REQUIRE