diff options
Diffstat (limited to 'reference_model/src/verify/verify_abs_error.cc')
-rw-r--r-- | reference_model/src/verify/verify_abs_error.cc | 61 |
1 files changed, 26 insertions, 35 deletions
diff --git a/reference_model/src/verify/verify_abs_error.cc b/reference_model/src/verify/verify_abs_error.cc index 5005dcf..a7b7bc2 100644 --- a/reference_model/src/verify/verify_abs_error.cc +++ b/reference_model/src/verify/verify_abs_error.cc @@ -26,59 +26,50 @@ namespace TosaReference namespace { -template <typename OutDtype> -bool validateData(const double* ref, - const double* bnd, - const OutDtype* imp, - const std::vector<int32_t>& shape, - const AbsErrorVerifyInfo& cfg) +template <typename OutType> +double calcErrorBound(double referenceValue, double boundsValue, const void* cfgPtr) { - const size_t T = static_cast<size_t>(numElements(shape)); - TOSA_REF_REQUIRE(T > 0, "[AE] Invalid shape for reference tensor"); + const auto cfg = reinterpret_cast<const AbsErrorVerifyInfo*>(cfgPtr); - for (size_t i = 0; i < T; ++i) + double valBound = std::abs(referenceValue) * boundsValue; + if (cfg->lowerBound > 0) { - double valBound = std::abs(ref[i]) * bnd[i]; - if (cfg.lowerBound > 0) - { - valBound = std::max(cfg.lowerBound, valBound); - } - double errBound = exp2(-AccPrecision<OutDtype>::normal_frac) * valBound; - bool valid = tosaCheckFloatBound(imp[i], ref[i], errBound); - if (!valid) - { - auto pos = indexToPosition(i, shape); - WARNING("[Verifier][AE] Location %s", positionToString(pos).c_str()); - return false; - } + valBound = std::max(cfg->lowerBound, valBound); } - return true; + return exp2(-AccPrecision<OutType>::normal_frac) * valBound; } } // namespace -bool verifyAbsError(const CTensor* ref, const CTensor* refBnd, const CTensor* imp, const AbsErrorVerifyInfo& aeInfo) + +bool verifyAbsError(const CTensor* referenceTensor, + const CTensor* boundsTensor, + const CTensor* implementationTensor, + const AbsErrorVerifyInfo& aeInfo) { // Validate that tensors are provided - TOSA_REF_REQUIRE(ref != nullptr, "[AE] Reference tensor is missing"); - TOSA_REF_REQUIRE(refBnd != nullptr, "[AE] Reference bounds tensor is missing"); - TOSA_REF_REQUIRE(imp != nullptr, "[AE] Implementation tensor is missing"); + TOSA_REF_REQUIRE(referenceTensor != nullptr, "[AE] Reference tensor is missing"); + TOSA_REF_REQUIRE(boundsTensor != nullptr, "[AE] Reference bounds tensor is missing"); + TOSA_REF_REQUIRE(implementationTensor != nullptr, "[AE] Implementation tensor is missing"); - const std::vector<int32_t> refShape(ref->shape, ref->shape + ref->num_dims); + const std::vector<int32_t> refShape(referenceTensor->shape, referenceTensor->shape + referenceTensor->num_dims); - const double* refData = reinterpret_cast<const double*>(ref->data); - const double* refBndData = reinterpret_cast<const double*>(refBnd->data); + const double* refData = reinterpret_cast<const double*>(referenceTensor->data); + const double* refBndData = reinterpret_cast<const double*>(boundsTensor->data); TOSA_REF_REQUIRE(refData != nullptr && refBndData != nullptr, "[AE] Missing data for reference or bounds tensors"); - switch (imp->data_type) + const std::string modeStr = "AE"; + + switch (implementationTensor->data_type) { case tosa_datatype_fp32_t: { - const auto* impData = reinterpret_cast<const float*>(imp->data); + const auto* impData = reinterpret_cast<const float*>(implementationTensor->data); TOSA_REF_REQUIRE(impData != nullptr, "[AE] Missing data for implementation"); - return validateData(refData, refBndData, impData, refShape, aeInfo); + return validateData(refData, refBndData, impData, refShape, modeStr, &aeInfo, &calcErrorBound<float>); } case tosa_datatype_fp16_t: { - const auto* impData = reinterpret_cast<const half_float::half*>(imp->data); + const auto* impData = reinterpret_cast<const half_float::half*>(implementationTensor->data); TOSA_REF_REQUIRE(impData != nullptr, "[AE] Missing data for implementation"); - return validateData(refData, refBndData, impData, refShape, aeInfo); + return validateData(refData, refBndData, impData, refShape, modeStr, &aeInfo, + &calcErrorBound<half_float::half>); } default: WARNING("[Verifier][AE] Data-type not supported."); |