diff options
Diffstat (limited to 'reference_model/src/verify/verify_ulp.cc')
-rw-r--r-- | reference_model/src/verify/verify_ulp.cc | 60 |
1 files changed, 16 insertions, 44 deletions
diff --git a/reference_model/src/verify/verify_ulp.cc b/reference_model/src/verify/verify_ulp.cc index 13bf0a9..8bae6e6 100644 --- a/reference_model/src/verify/verify_ulp.cc +++ b/reference_model/src/verify/verify_ulp.cc @@ -27,24 +27,27 @@ namespace TosaReference namespace { template <typename OutType> -bool tosaCheckULP(OutType testValue, double referenceValue, double ulpNum) +double calcErrorBound(double referenceValue, double boundsValue, const void* cfgPtr) { - double errorBound = 0.0; + const auto cfg = reinterpret_cast<const UlpVerifyInfo*>(cfgPtr); + unused(boundsValue); + + double errBound = 0.0; if (std::isfinite(referenceValue) && std::abs(referenceValue) != 0.0) { // Find the exponent of the reference value. - int32_t referenceExponent = ilog2(std::abs(referenceValue)); + int32_t refExponent = ilog2(std::abs(referenceValue)); // Work out the values magnitude - by raising 2 to the power of the // exponent and taking the normalized minimum for denormal values - const double referencePower2 = std::max(exp2(referenceExponent), AccPrecision<OutType>::normal_min); + const double refPower2 = std::max(exp2(refExponent), AccPrecision<OutType>::normal_min); // Get the value of changing the last bit - by shifting the least significant bit to this magnitude // i.e. the ULP. - double ulpValue = referencePower2 * exp2(-AccPrecision<OutType>::normal_frac); + double ulpValue = refPower2 * exp2(-AccPrecision<OutType>::normal_frac); - errorBound = ulpValue * ulpNum; + errBound = ulpValue * cfg->ulp; } - return tosaCheckFloatBound(testValue, referenceValue, errorBound); + return errBound; } } // namespace @@ -54,56 +57,25 @@ bool verifyULP(const CTensor* referenceTensor, const CTensor* implementationTens TOSA_REF_REQUIRE(referenceTensor != nullptr, "[ULP] Reference tensor is missing"); TOSA_REF_REQUIRE(implementationTensor != nullptr, "[ULP] Implementation tensor is missing"); - // Get number of elements const std::vector<int32_t> refShape(referenceTensor->shape, referenceTensor->shape + referenceTensor->num_dims); - const auto elementCount = numElements(refShape); - TOSA_REF_REQUIRE(elementCount > 0, "[ULP] Invalid shape for reference tensor"); - const double ulp = ulpInfo.ulp; const auto* refData = reinterpret_cast<const double*>(referenceTensor->data); TOSA_REF_REQUIRE(refData != nullptr, "[ULP] Missing data for reference"); - const auto* refDataEnd = std::next(refData, elementCount); + + const std::string modeStr = "ULP"; + switch (implementationTensor->data_type) { case tosa_datatype_fp32_t: { const auto* impData = reinterpret_cast<const float*>(implementationTensor->data); TOSA_REF_REQUIRE(impData != nullptr, "[ULP] Missing data for implementation"); - // Use mismatch to get the location of the first unequal value - auto pair = std::mismatch(refData, refDataEnd, impData, std::next(impData, elementCount), - [ulp](const auto& referenceValue, const auto& implementationValue) { - return tosaCheckULP(implementationValue, referenceValue, ulp); - }); - if (std::get<0>(pair) == refDataEnd) - { - // No mismatch found - return true; - } - else - { - auto pos = indexToPosition(std::get<0>(pair) - refData, refShape); - WARNING("[Verfier][ULP] Location %s", positionToString(pos).c_str()); - return false; - } + return validateData(refData, nullptr, impData, refShape, modeStr, &ulpInfo, &calcErrorBound<float>); } case tosa_datatype_fp16_t: { const auto* impData = reinterpret_cast<const half_float::half*>(implementationTensor->data); TOSA_REF_REQUIRE(impData != nullptr, "[ULP] Missing data for implementation"); - // Use mismatch to get the location of the first unequal value - auto pair = std::mismatch(refData, refDataEnd, impData, std::next(impData, elementCount), - [ulp](const auto& referenceValue, const auto& implementationValue) { - return tosaCheckULP(implementationValue, referenceValue, ulp); - }); - if (std::get<0>(pair) == refDataEnd) - { - // No mismatch found - return true; - } - else - { - auto pos = indexToPosition(std::get<0>(pair) - refData, refShape); - WARNING("[Verfier][ULP] Location %s", positionToString(pos).c_str()); - return false; - } + return validateData(refData, nullptr, impData, refShape, modeStr, &ulpInfo, + &calcErrorBound<half_float::half>); } default: WARNING("[Verifier][ULP] Data-type not supported."); |