aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/verify/verify_utils.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/verify/verify_utils.cc')
-rw-r--r--reference_model/src/verify/verify_utils.cc91
1 files changed, 84 insertions, 7 deletions
diff --git a/reference_model/src/verify/verify_utils.cc b/reference_model/src/verify/verify_utils.cc
index 4ae245b..57cb50a 100644
--- a/reference_model/src/verify/verify_utils.cc
+++ b/reference_model/src/verify/verify_utils.cc
@@ -20,6 +20,7 @@
#include <algorithm>
#include <cfloat>
#include <map>
+#include <string>
namespace tosa
{
@@ -232,16 +233,22 @@ static_assert(std::numeric_limits<double>::is_iec559,
"verification is invalid");
template <typename OutType>
-bool tosaCheckFloatBound(OutType testValue, double referenceValue, double errorBound)
+bool tosaCheckFloatBound(
+ OutType testValue, double referenceValue, double errorBound, double& resultDifference, std::string& resultWarning)
{
// Both must be NaNs to be correct
if (std::isnan(referenceValue) || std::isnan(testValue))
{
if (std::isnan(referenceValue) && std::isnan(testValue))
{
+ resultDifference = 0.0;
return true;
}
- WARNING("[Verifier][Bound] Non-matching NaN values - ref (%g) versus test (%g).", referenceValue, testValue);
+ char buff[200];
+ snprintf(buff, 200, "Non-matching NaN values - ref (%g) versus test (%g).", referenceValue,
+ static_cast<double>(testValue));
+ resultWarning.assign(buff);
+ resultDifference = std::numeric_limits<double>::quiet_NaN();
return false;
}
@@ -307,16 +314,86 @@ bool tosaCheckFloatBound(OutType testValue, double referenceValue, double errorB
// And finally... Do the comparison.
double testValue64 = static_cast<double>(testValue);
bool withinBound = testValue64 >= referenceMin && testValue64 <= referenceMax;
+ resultDifference = testValue64 - referenceValue;
if (!withinBound)
{
- WARNING("[Verifier][Bound] value %.*g is not in error bound %.*g range (%.*g <= ref %.*g <= %.*g).", DBL_DIG,
- testValue64, DBL_DIG, errorBound, DBL_DIG, referenceMin, DBL_DIG, referenceValue, DBL_DIG,
- referenceMax);
+ char buff[300];
+ snprintf(buff, 300,
+ "value %.*g has a difference of %.*g compared to an error bound of +/- %.*g (range: %.*g <= ref %.*g "
+ "<= %.*g).",
+ DBL_DIG, testValue64, DBL_DIG, resultDifference, DBL_DIG, errorBound, DBL_DIG, referenceMin, DBL_DIG,
+ referenceValue, DBL_DIG, referenceMax);
+ resultWarning.assign(buff);
}
return withinBound;
}
+template <typename OutType>
+bool validateData(const double* referenceData,
+ const double* boundsData,
+ const OutType* implementationData,
+ const std::vector<int32_t>& shape,
+ const std::string& modeStr,
+ const void* cfgPtr,
+ double (*calcErrorBound)(double referenceValue, double boundsValue, const void* cfgPtr))
+{
+ const size_t T = static_cast<size_t>(numElements(shape));
+ TOSA_REF_REQUIRE(T > 0, "Invalid shape for reference tensor");
+ TOSA_REF_REQUIRE(referenceData != nullptr, "Missing data for reference tensor");
+ TOSA_REF_REQUIRE(implementationData != nullptr, "Missing data for implementation tensor");
+ // NOTE: Bounds data tensor is allowed to be null as it may not be needed
+ TOSA_REF_REQUIRE(cfgPtr != nullptr, "Missing config for validation");
+ TOSA_REF_REQUIRE(calcErrorBound != nullptr, "Missing error bound function validation");
+
+ std::string warning, worstWarning;
+ double difference, worstDifference = 0.0;
+ size_t worstPosition;
+ bool compliant = true;
+
+ for (size_t i = 0; i < T; ++i)
+ {
+ double boundVal = (boundsData == nullptr) ? 0.0 : boundsData[i];
+ double errBound = calcErrorBound(referenceData[i], boundVal, cfgPtr);
+ bool valid = tosaCheckFloatBound(implementationData[i], referenceData[i], errBound, difference, warning);
+ if (!valid)
+ {
+ compliant = false;
+ if (std::isnan(difference) || std::abs(difference) > std::abs(worstDifference))
+ {
+ worstPosition = i;
+ worstDifference = difference;
+ worstWarning.assign(warning);
+ if (std::isnan(difference))
+ {
+ // Worst case is difference in NaN
+ break;
+ }
+ }
+ }
+ }
+ if (!compliant)
+ {
+ auto pos = indexToPosition(worstPosition, shape);
+ WARNING("[Verifier][%s] Largest deviance at location %s: %s", modeStr.c_str(), positionToString(pos).c_str(),
+ worstWarning.c_str());
+ }
+ return compliant;
+}
+
// Instantiate the needed check functions
-template bool tosaCheckFloatBound(float testValue, double referenceValue, double errorBound);
-template bool tosaCheckFloatBound(half_float::half testValue, double referenceValue, double errorBound);
+template bool validateData(const double* referenceData,
+ const double* boundsData,
+ const float* implementationData,
+ const std::vector<int32_t>& shape,
+ const std::string& modeStr,
+ const void* cfgPtr,
+ double (*calcErrorBound)(double referenceValue, double boundsValue, const void* cfgPtr));
+template bool validateData(const double* referenceData,
+ const double* boundsData,
+ const half_float::half* implementationData,
+ const std::vector<int32_t>& shape,
+ const std::string& modeStr,
+ const void* cfgPtr,
+ double (*calcErrorBound)(double referenceValue, double boundsValue, const void* cfgPtr));
+
} // namespace TosaReference