diff options
Diffstat (limited to 'reference_model/src/verify/verify_abs_error.cc')
-rw-r--r-- | reference_model/src/verify/verify_abs_error.cc | 13 |
1 files changed, 10 insertions, 3 deletions
diff --git a/reference_model/src/verify/verify_abs_error.cc b/reference_model/src/verify/verify_abs_error.cc index b43da08..5aaa0ad 100644 --- a/reference_model/src/verify/verify_abs_error.cc +++ b/reference_model/src/verify/verify_abs_error.cc @@ -18,6 +18,7 @@ #include <type_traits> #include <utility> +#include "half.hpp" #include "verifiers.h" namespace TosaReference @@ -25,14 +26,15 @@ namespace TosaReference namespace { -bool validateData(const double* ref, const double* bnd, const float* imp, const std::vector<int32_t>& shape) +template <typename OutDtype> +bool validateData(const double* ref, const double* bnd, const OutDtype* imp, const std::vector<int32_t>& shape) { const size_t T = static_cast<size_t>(numElements(shape)); TOSA_REF_REQUIRE(T > 0, "[AE] Invalid shape for reference tensor"); for (size_t i = 0; i < T; ++i) { - double errBound = std::abs(ref[i]) * exp2(-AccPrecision<float>::normal_frac) * bnd[i]; + double errBound = std::abs(ref[i]) * exp2(-AccPrecision<OutDtype>::normal_frac) * bnd[i]; bool valid = tosaCheckFloatBound(imp[i], ref[i], errBound); if (!valid) { @@ -60,7 +62,12 @@ bool verifyAbsError(const CTensor* ref, const CTensor* refBnd, const CTensor* im switch (imp->data_type) { case tosa_datatype_fp32_t: { - const float* impData = reinterpret_cast<const float*>(imp->data); + const auto* impData = reinterpret_cast<const float*>(imp->data); + TOSA_REF_REQUIRE(impData != nullptr, "[AE] Missing data for implementation"); + return validateData(refData, refBndData, impData, refShape); + } + case tosa_datatype_fp16_t: { + const auto* impData = reinterpret_cast<const half_float::half*>(imp->data); TOSA_REF_REQUIRE(impData != nullptr, "[AE] Missing data for implementation"); return validateData(refData, refBndData, impData, refShape); } |