From 08965d35f728d93d8b215753b4b270a422fe39c9 Mon Sep 17 00:00:00 2001 From: Jeremy Johnson Date: Mon, 19 Feb 2024 13:57:21 +0000 Subject: Verifier - change to output largest error deviance Add general validateData function used by ABS_ERROR, ULP, RELATIVE and REDUCE_PRODUCT to find and output largest deviance from the error bounds. Clean up naming inconsistencies bewteen verify modes. Signed-off-by: Jeremy Johnson Change-Id: Ib903faf36f784cacae91edab61d8e489461a727c --- reference_model/src/verify/verify_utils.cc | 91 +++++++++++++++++++++++++++--- 1 file changed, 84 insertions(+), 7 deletions(-) (limited to 'reference_model/src/verify/verify_utils.cc') diff --git a/reference_model/src/verify/verify_utils.cc b/reference_model/src/verify/verify_utils.cc index 4ae245b..57cb50a 100644 --- a/reference_model/src/verify/verify_utils.cc +++ b/reference_model/src/verify/verify_utils.cc @@ -20,6 +20,7 @@ #include #include #include +#include namespace tosa { @@ -232,16 +233,22 @@ static_assert(std::numeric_limits::is_iec559, "verification is invalid"); template -bool tosaCheckFloatBound(OutType testValue, double referenceValue, double errorBound) +bool tosaCheckFloatBound( + OutType testValue, double referenceValue, double errorBound, double& resultDifference, std::string& resultWarning) { // Both must be NaNs to be correct if (std::isnan(referenceValue) || std::isnan(testValue)) { if (std::isnan(referenceValue) && std::isnan(testValue)) { + resultDifference = 0.0; return true; } - WARNING("[Verifier][Bound] Non-matching NaN values - ref (%g) versus test (%g).", referenceValue, testValue); + char buff[200]; + snprintf(buff, 200, "Non-matching NaN values - ref (%g) versus test (%g).", referenceValue, + static_cast(testValue)); + resultWarning.assign(buff); + resultDifference = std::numeric_limits::quiet_NaN(); return false; } @@ -307,16 +314,86 @@ bool tosaCheckFloatBound(OutType testValue, double referenceValue, double errorB // And finally... Do the comparison. double testValue64 = static_cast(testValue); bool withinBound = testValue64 >= referenceMin && testValue64 <= referenceMax; + resultDifference = testValue64 - referenceValue; if (!withinBound) { - WARNING("[Verifier][Bound] value %.*g is not in error bound %.*g range (%.*g <= ref %.*g <= %.*g).", DBL_DIG, - testValue64, DBL_DIG, errorBound, DBL_DIG, referenceMin, DBL_DIG, referenceValue, DBL_DIG, - referenceMax); + char buff[300]; + snprintf(buff, 300, + "value %.*g has a difference of %.*g compared to an error bound of +/- %.*g (range: %.*g <= ref %.*g " + "<= %.*g).", + DBL_DIG, testValue64, DBL_DIG, resultDifference, DBL_DIG, errorBound, DBL_DIG, referenceMin, DBL_DIG, + referenceValue, DBL_DIG, referenceMax); + resultWarning.assign(buff); } return withinBound; } +template +bool validateData(const double* referenceData, + const double* boundsData, + const OutType* implementationData, + const std::vector& shape, + const std::string& modeStr, + const void* cfgPtr, + double (*calcErrorBound)(double referenceValue, double boundsValue, const void* cfgPtr)) +{ + const size_t T = static_cast(numElements(shape)); + TOSA_REF_REQUIRE(T > 0, "Invalid shape for reference tensor"); + TOSA_REF_REQUIRE(referenceData != nullptr, "Missing data for reference tensor"); + TOSA_REF_REQUIRE(implementationData != nullptr, "Missing data for implementation tensor"); + // NOTE: Bounds data tensor is allowed to be null as it may not be needed + TOSA_REF_REQUIRE(cfgPtr != nullptr, "Missing config for validation"); + TOSA_REF_REQUIRE(calcErrorBound != nullptr, "Missing error bound function validation"); + + std::string warning, worstWarning; + double difference, worstDifference = 0.0; + size_t worstPosition; + bool compliant = true; + + for (size_t i = 0; i < T; ++i) + { + double boundVal = (boundsData == nullptr) ? 0.0 : boundsData[i]; + double errBound = calcErrorBound(referenceData[i], boundVal, cfgPtr); + bool valid = tosaCheckFloatBound(implementationData[i], referenceData[i], errBound, difference, warning); + if (!valid) + { + compliant = false; + if (std::isnan(difference) || std::abs(difference) > std::abs(worstDifference)) + { + worstPosition = i; + worstDifference = difference; + worstWarning.assign(warning); + if (std::isnan(difference)) + { + // Worst case is difference in NaN + break; + } + } + } + } + if (!compliant) + { + auto pos = indexToPosition(worstPosition, shape); + WARNING("[Verifier][%s] Largest deviance at location %s: %s", modeStr.c_str(), positionToString(pos).c_str(), + worstWarning.c_str()); + } + return compliant; +} + // Instantiate the needed check functions -template bool tosaCheckFloatBound(float testValue, double referenceValue, double errorBound); -template bool tosaCheckFloatBound(half_float::half testValue, double referenceValue, double errorBound); +template bool validateData(const double* referenceData, + const double* boundsData, + const float* implementationData, + const std::vector& shape, + const std::string& modeStr, + const void* cfgPtr, + double (*calcErrorBound)(double referenceValue, double boundsValue, const void* cfgPtr)); +template bool validateData(const double* referenceData, + const double* boundsData, + const half_float::half* implementationData, + const std::vector& shape, + const std::string& modeStr, + const void* cfgPtr, + double (*calcErrorBound)(double referenceValue, double boundsValue, const void* cfgPtr)); + } // namespace TosaReference -- cgit v1.2.1