diff options
Diffstat (limited to 'reference_model/src/verify/verify_abs_error.cc')
-rw-r--r-- | reference_model/src/verify/verify_abs_error.cc | 21 |
1 files changed, 15 insertions, 6 deletions
diff --git a/reference_model/src/verify/verify_abs_error.cc b/reference_model/src/verify/verify_abs_error.cc index 5aaa0ad..25ecae4 100644 --- a/reference_model/src/verify/verify_abs_error.cc +++ b/reference_model/src/verify/verify_abs_error.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2023, ARM Limited. +// Copyright (c) 2023-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -27,14 +27,23 @@ namespace TosaReference namespace { template <typename OutDtype> -bool validateData(const double* ref, const double* bnd, const OutDtype* imp, const std::vector<int32_t>& shape) +bool validateData(const double* ref, + const double* bnd, + const OutDtype* imp, + const std::vector<int32_t>& shape, + const AbsErrorVerifyInfo& cfg) { const size_t T = static_cast<size_t>(numElements(shape)); TOSA_REF_REQUIRE(T > 0, "[AE] Invalid shape for reference tensor"); for (size_t i = 0; i < T; ++i) { - double errBound = std::abs(ref[i]) * exp2(-AccPrecision<OutDtype>::normal_frac) * bnd[i]; + double valBound = std::abs(ref[i]) * bnd[i]; + if (cfg.lowerBound > 0) + { + valBound = std::max(cfg.lowerBound, valBound); + } + double errBound = exp2(-AccPrecision<OutDtype>::normal_frac) * valBound; bool valid = tosaCheckFloatBound(imp[i], ref[i], errBound); if (!valid) { @@ -46,7 +55,7 @@ bool validateData(const double* ref, const double* bnd, const OutDtype* imp, con return true; } } // namespace -bool verifyAbsError(const CTensor* ref, const CTensor* refBnd, const CTensor* imp) +bool verifyAbsError(const CTensor* ref, const CTensor* refBnd, const CTensor* imp, const AbsErrorVerifyInfo& aeInfo) { // Validate that tensors are provided TOSA_REF_REQUIRE(ref != nullptr, "[AE] Reference tensor is missing"); @@ -64,12 +73,12 @@ bool verifyAbsError(const CTensor* ref, const CTensor* refBnd, const CTensor* im case tosa_datatype_fp32_t: { const auto* impData = reinterpret_cast<const float*>(imp->data); TOSA_REF_REQUIRE(impData != nullptr, "[AE] Missing data for implementation"); - return validateData(refData, refBndData, impData, refShape); + return validateData(refData, refBndData, impData, refShape, aeInfo); } case tosa_datatype_fp16_t: { const auto* impData = reinterpret_cast<const half_float::half*>(imp->data); TOSA_REF_REQUIRE(impData != nullptr, "[AE] Missing data for implementation"); - return validateData(refData, refBndData, impData, refShape); + return validateData(refData, refBndData, impData, refShape, aeInfo); } default: WARNING("[Verifier][AE] Data-type not supported."); |