diff options
Diffstat (limited to 'reference_model/src/verify/verify_dot_product.cc')
-rw-r--r-- | reference_model/src/verify/verify_dot_product.cc | 25 |
1 files changed, 16 insertions, 9 deletions
diff --git a/reference_model/src/verify/verify_dot_product.cc b/reference_model/src/verify/verify_dot_product.cc index 15de427..a036cba 100644 --- a/reference_model/src/verify/verify_dot_product.cc +++ b/reference_model/src/verify/verify_dot_product.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "func_debug.h" +#include "half.hpp" #include "verifiers.h" #include <cmath> @@ -25,13 +26,19 @@ namespace TosaReference namespace { // Generic element validation function -template <typename AccType, typename std::enable_if_t<std::is_floating_point_v<AccType>, int> = 0> +template <typename AccType> std::optional<double> validateElement(size_t index, double ref, double bnd, AccType imp, size_t KS) { double err = 0.0; bool is_valid = true; - if (bnd == 0.0) + if (std::isinf(static_cast<AccType>(bnd))) + { + // dot product can overflow and there is no accuracy limit + is_valid = true; + err = 0.0; + } + else if (bnd == 0.0) { is_valid = (ref == 0.0) && (imp == 0.0); if (!is_valid) @@ -40,12 +47,6 @@ std::optional<double> validateElement(size_t index, double ref, double bnd, AccT } err = 0.0; } - else if (std::isinf(static_cast<AccType>(bnd))) - { - // dot product can overflow and there is no accuracy limit - is_valid = true; - err = 0.0; - } else { // 0.0 < bnd < infinity @@ -64,7 +65,7 @@ std::optional<double> validateElement(size_t index, double ref, double bnd, AccT } // Generic data validation function -template <typename AccType, typename std::enable_if_t<std::is_floating_point_v<AccType>, int> = 0> +template <typename AccType> bool validateData(const double* ref, const double* bnd, const AccType* imp, size_t T, const DotProductVerifyInfo& cfg) { const int32_t S = cfg.s; @@ -121,6 +122,12 @@ bool verifyDotProduct(const CTensor* ref, const CTensor* refBnd, const CTensor* return validateData(refData, refBndData, impData, static_cast<size_t>(T), dpInfo); break; } + case tosa_datatype_fp16_t: { + const half_float::half* impData = reinterpret_cast<const half_float::half*>(imp->data); + TOSA_REF_REQUIRE(impData != nullptr, "[DP] Missing data for implementation"); + return validateData(refData, refBndData, impData, static_cast<size_t>(T), dpInfo); + break; + } default: { WARNING("[Verifier][DP] Data-type not supported."); break; |