aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/verify
diff options
context:
space:
mode:
authorevacha01 <evan.chandler@arm.com>2024-02-07 11:21:55 +0000
committerevacha01 <evan.chandler@arm.com>2024-03-07 12:06:38 +0000
commit9c96eefbaca6c85be79529bce7ff04fd7dfe3a0d (patch)
tree55647ee0216800b621bd0b27277c6f895929ef3d /reference_model/src/verify
parent6e1e2bc06bff785e87577f24064bbc846300f8fd (diff)
downloadreference_model-9c96eefbaca6c85be79529bce7ff04fd7dfe3a0d.tar.gz
FULL data gen mode for FP16
Signed-off-by: evacha01 <evan.chandler@arm.com> Change-Id: I81bb322132daf25328a40342edc62d8e1db9edd6
Diffstat (limited to 'reference_model/src/verify')
-rw-r--r--reference_model/src/verify/verify_abs_error.cc13
-rw-r--r--reference_model/src/verify/verify_utils.cc25
2 files changed, 26 insertions, 12 deletions
diff --git a/reference_model/src/verify/verify_abs_error.cc b/reference_model/src/verify/verify_abs_error.cc
index 125045e..64f86a3 100644
--- a/reference_model/src/verify/verify_abs_error.cc
+++ b/reference_model/src/verify/verify_abs_error.cc
@@ -30,12 +30,17 @@ double calcErrorBound(double referenceValue, double boundsValue, const void* cfg
{
const auto cfg = reinterpret_cast<const AbsErrorVerifyInfo*>(cfgPtr);
- double valBound = std::abs(referenceValue) * boundsValue;
- if (cfg->lowerBound > 0)
+ double errorBound = 0.0;
+ if (std::isfinite(referenceValue) && std::abs(referenceValue) != 0.0)
{
- valBound = std::max(cfg->lowerBound, valBound);
+ double valBound = std::abs(referenceValue) * boundsValue;
+ if (cfg->lowerBound > 0)
+ {
+ valBound = std::max(cfg->lowerBound, valBound);
+ }
+ errorBound = exp2(-AccPrecision<OutType>::normal_frac / cfg->normalDivisor) * valBound;
}
- return exp2(-AccPrecision<OutType>::normal_frac / cfg->normalDivisor) * valBound;
+ return errorBound;
}
} // namespace
diff --git a/reference_model/src/verify/verify_utils.cc b/reference_model/src/verify/verify_utils.cc
index 50a98e5..d4657b3 100644
--- a/reference_model/src/verify/verify_utils.cc
+++ b/reference_model/src/verify/verify_utils.cc
@@ -356,21 +356,23 @@ bool validateData(const double* referenceData,
TOSA_REF_REQUIRE(calcErrorBound != nullptr, "Missing error bound function validation");
std::string warning, worstWarning;
- double difference, worstDifference = 0.0;
- size_t worstPosition;
- bool compliant = true;
+ double worstDifference = 0.0;
+ // Set to invalid index
+ size_t worstIndex = T;
+ bool compliant = true;
for (size_t i = 0; i < T; ++i)
{
- double boundVal = (boundsData == nullptr) ? 0.0 : boundsData[i];
- double errBound = calcErrorBound(referenceData[i], boundVal, cfgPtr);
- bool valid = tosaCheckFloatBound(implementationData[i], referenceData[i], errBound, difference, warning);
+ double difference = 0.0;
+ double boundVal = (boundsData == nullptr) ? 0.0 : boundsData[i];
+ double errBound = calcErrorBound(referenceData[i], boundVal, cfgPtr);
+ bool valid = tosaCheckFloatBound(implementationData[i], referenceData[i], errBound, difference, warning);
if (!valid)
{
compliant = false;
if (std::isnan(difference) || std::abs(difference) > std::abs(worstDifference))
{
- worstPosition = i;
+ worstIndex = i;
worstDifference = difference;
worstWarning.assign(warning);
if (std::isnan(difference))
@@ -379,11 +381,18 @@ bool validateData(const double* referenceData,
break;
}
}
+ else if (std::abs(difference) == 0.0)
+ {
+ auto pos = indexToPosition(i, shape);
+ WARNING("[Verifier][%s] Invalid error bound, no difference found. Location: %s", modeStr.c_str(),
+ positionToString(pos).c_str());
+ return false;
+ }
}
}
if (!compliant)
{
- auto pos = indexToPosition(worstPosition, shape);
+ auto pos = indexToPosition(worstIndex, shape);
WARNING("[Verifier][%s] Largest deviance at location %s: %s", modeStr.c_str(), positionToString(pos).c_str(),
worstWarning.c_str());
}