From 718f347a2d886381de19420b5b5b99db8f2b7338 Mon Sep 17 00:00:00 2001 From: Jeremy Johnson Date: Thu, 30 Nov 2023 14:18:19 +0000 Subject: Main Compliance FP16 support - generate and verify. FP16 support for all existing operators for compliance: * DOT_PRODUCT * ULP * EXACT * ABS_ERROR Signed-off-by: Jeremy Johnson Change-Id: I8d25448a793375b53880da3787d8f839767f02cf --- reference_model/src/verify/verify_abs_error.cc | 13 +++++++-- reference_model/src/verify/verify_dot_product.cc | 25 ++++++++++------ reference_model/src/verify/verify_exact.cc | 20 +++++++++---- reference_model/src/verify/verify_ulp.cc | 36 +++++++++++++++++++----- reference_model/src/verify/verify_utils.cc | 25 +++++++++------- reference_model/src/verify/verify_utils.h | 12 ++++++-- 6 files changed, 95 insertions(+), 36 deletions(-) (limited to 'reference_model/src/verify') diff --git a/reference_model/src/verify/verify_abs_error.cc b/reference_model/src/verify/verify_abs_error.cc index b43da08..5aaa0ad 100644 --- a/reference_model/src/verify/verify_abs_error.cc +++ b/reference_model/src/verify/verify_abs_error.cc @@ -18,6 +18,7 @@ #include #include +#include "half.hpp" #include "verifiers.h" namespace TosaReference @@ -25,14 +26,15 @@ namespace TosaReference namespace { -bool validateData(const double* ref, const double* bnd, const float* imp, const std::vector& shape) +template +bool validateData(const double* ref, const double* bnd, const OutDtype* imp, const std::vector& shape) { const size_t T = static_cast(numElements(shape)); TOSA_REF_REQUIRE(T > 0, "[AE] Invalid shape for reference tensor"); for (size_t i = 0; i < T; ++i) { - double errBound = std::abs(ref[i]) * exp2(-AccPrecision::normal_frac) * bnd[i]; + double errBound = std::abs(ref[i]) * exp2(-AccPrecision::normal_frac) * bnd[i]; bool valid = tosaCheckFloatBound(imp[i], ref[i], errBound); if (!valid) { @@ -60,7 +62,12 @@ bool verifyAbsError(const CTensor* ref, const CTensor* refBnd, const CTensor* im switch (imp->data_type) { case tosa_datatype_fp32_t: { - const float* impData = reinterpret_cast(imp->data); + const auto* impData = reinterpret_cast(imp->data); + TOSA_REF_REQUIRE(impData != nullptr, "[AE] Missing data for implementation"); + return validateData(refData, refBndData, impData, refShape); + } + case tosa_datatype_fp16_t: { + const auto* impData = reinterpret_cast(imp->data); TOSA_REF_REQUIRE(impData != nullptr, "[AE] Missing data for implementation"); return validateData(refData, refBndData, impData, refShape); } diff --git a/reference_model/src/verify/verify_dot_product.cc b/reference_model/src/verify/verify_dot_product.cc index 15de427..a036cba 100644 --- a/reference_model/src/verify/verify_dot_product.cc +++ b/reference_model/src/verify/verify_dot_product.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "func_debug.h" +#include "half.hpp" #include "verifiers.h" #include @@ -25,13 +26,19 @@ namespace TosaReference namespace { // Generic element validation function -template , int> = 0> +template std::optional validateElement(size_t index, double ref, double bnd, AccType imp, size_t KS) { double err = 0.0; bool is_valid = true; - if (bnd == 0.0) + if (std::isinf(static_cast(bnd))) + { + // dot product can overflow and there is no accuracy limit + is_valid = true; + err = 0.0; + } + else if (bnd == 0.0) { is_valid = (ref == 0.0) && (imp == 0.0); if (!is_valid) @@ -40,12 +47,6 @@ std::optional validateElement(size_t index, double ref, double bnd, AccT } err = 0.0; } - else if (std::isinf(static_cast(bnd))) - { - // dot product can overflow and there is no accuracy limit - is_valid = true; - err = 0.0; - } else { // 0.0 < bnd < infinity @@ -64,7 +65,7 @@ std::optional validateElement(size_t index, double ref, double bnd, AccT } // Generic data validation function -template , int> = 0> +template bool validateData(const double* ref, const double* bnd, const AccType* imp, size_t T, const DotProductVerifyInfo& cfg) { const int32_t S = cfg.s; @@ -121,6 +122,12 @@ bool verifyDotProduct(const CTensor* ref, const CTensor* refBnd, const CTensor* return validateData(refData, refBndData, impData, static_cast(T), dpInfo); break; } + case tosa_datatype_fp16_t: { + const half_float::half* impData = reinterpret_cast(imp->data); + TOSA_REF_REQUIRE(impData != nullptr, "[DP] Missing data for implementation"); + return validateData(refData, refBndData, impData, static_cast(T), dpInfo); + break; + } default: { WARNING("[Verifier][DP] Data-type not supported."); break; diff --git a/reference_model/src/verify/verify_exact.cc b/reference_model/src/verify/verify_exact.cc index 36b4ec9..971df9c 100644 --- a/reference_model/src/verify/verify_exact.cc +++ b/reference_model/src/verify/verify_exact.cc @@ -13,12 +13,14 @@ // limitations under the License. #include "func_debug.h" +#include "half.hpp" #include "verifiers.h" #include namespace { -bool exact_fp32(const double& referenceValue, const float& implementationValue) +template +bool exact_fp(const double& referenceValue, const OutDtype& implementationValue) { return std::isnan(referenceValue) ? std::isnan(implementationValue) : (referenceValue == implementationValue); } @@ -38,16 +40,24 @@ bool verifyExact(const CTensor* referenceTensor, const CTensor* implementationTe numElements(std::vector(referenceTensor->shape, referenceTensor->shape + referenceTensor->num_dims)); TOSA_REF_REQUIRE(elementCount > 0, "[E] Invalid shape for reference tensor"); + TOSA_REF_REQUIRE(referenceTensor->data_type == tosa_datatype_fp64_t, "[E] Reference tensor is not fp64"); + const auto* refData = reinterpret_cast(referenceTensor->data); + TOSA_REF_REQUIRE(refData != nullptr, "[E] Missing data for reference"); + switch (implementationTensor->data_type) { case tosa_datatype_fp32_t: { - TOSA_REF_REQUIRE(referenceTensor->data_type == tosa_datatype_fp64_t, "[E] Reference tensor is not fp64"); - const auto* refData = reinterpret_cast(referenceTensor->data); - TOSA_REF_REQUIRE(refData != nullptr, "[E] Missing data for reference"); const auto* impData = reinterpret_cast(implementationTensor->data); TOSA_REF_REQUIRE(impData != nullptr, "[E] Missing data for implementation"); auto result = std::equal(refData, std::next(refData, elementCount), impData, - std::next(impData, elementCount), exact_fp32); + std::next(impData, elementCount), exact_fp); + return result; + } + case tosa_datatype_fp16_t: { + const auto* impData = reinterpret_cast(implementationTensor->data); + TOSA_REF_REQUIRE(impData != nullptr, "[E] Missing data for implementation"); + auto result = std::equal(refData, std::next(refData, elementCount), impData, + std::next(impData, elementCount), exact_fp); return result; } default: diff --git a/reference_model/src/verify/verify_ulp.cc b/reference_model/src/verify/verify_ulp.cc index 6e78b96..1b38fe6 100644 --- a/reference_model/src/verify/verify_ulp.cc +++ b/reference_model/src/verify/verify_ulp.cc @@ -18,6 +18,7 @@ #include #include +#include "half.hpp" #include "verifiers.h" namespace TosaReference @@ -25,7 +26,8 @@ namespace TosaReference namespace { -bool tosaCheckULP(float testValue, double referenceValue, double ulpNum) +template +bool tosaCheckULP(OutType testValue, double referenceValue, double ulpNum) { double errorBound = 0.0; if (std::isfinite(referenceValue) && std::abs(referenceValue) != 0.0) @@ -35,10 +37,10 @@ bool tosaCheckULP(float testValue, double referenceValue, double ulpNum) // Work out the values magnitude - by raising 2 to the power of the // exponent and taking the normalized minimum for denormal values - const double referencePower2 = std::max(exp2(referenceExponent), AccPrecision::normal_min); + const double referencePower2 = std::max(exp2(referenceExponent), AccPrecision::normal_min); // Get the value of changing the last bit - by shifting the least significant bit to this magnitude // i.e. the ULP. - double ulpValue = referencePower2 * exp2(-AccPrecision::normal_frac); + double ulpValue = referencePower2 * exp2(-AccPrecision::normal_frac); errorBound = ulpValue * ulpNum; } @@ -57,15 +59,35 @@ bool verifyULP(const CTensor* referenceTensor, const CTensor* implementationTens const auto elementCount = numElements(refShape); TOSA_REF_REQUIRE(elementCount > 0, "[ULP] Invalid shape for reference tensor"); - const double ulp = ulpInfo.ulp; + const double ulp = ulpInfo.ulp; + const auto* refData = reinterpret_cast(referenceTensor->data); + TOSA_REF_REQUIRE(refData != nullptr, "[ULP] Missing data for reference"); + const auto* refDataEnd = std::next(refData, elementCount); switch (implementationTensor->data_type) { case tosa_datatype_fp32_t: { - const auto* refData = reinterpret_cast(referenceTensor->data); - TOSA_REF_REQUIRE(refData != nullptr, "[ULP] Missing data for reference"); const auto* impData = reinterpret_cast(implementationTensor->data); TOSA_REF_REQUIRE(impData != nullptr, "[ULP] Missing data for implementation"); - const auto* refDataEnd = std::next(refData, elementCount); + // Use mismatch to get the location of the first unequal value + auto pair = std::mismatch(refData, refDataEnd, impData, std::next(impData, elementCount), + [ulp](const auto& referenceValue, const auto& implementationValue) { + return tosaCheckULP(implementationValue, referenceValue, ulp); + }); + if (std::get<0>(pair) == refDataEnd) + { + // No mismatch found + return true; + } + else + { + auto pos = indexToPosition(std::get<0>(pair) - refData, refShape); + WARNING("[Verfier][ULP] Location %s", positionToString(pos).c_str()); + return false; + } + } + case tosa_datatype_fp16_t: { + const auto* impData = reinterpret_cast(implementationTensor->data); + TOSA_REF_REQUIRE(impData != nullptr, "[ULP] Missing data for implementation"); // Use mismatch to get the location of the first unequal value auto pair = std::mismatch(refData, refDataEnd, impData, std::next(impData, elementCount), [ulp](const auto& referenceValue, const auto& implementationValue) { diff --git a/reference_model/src/verify/verify_utils.cc b/reference_model/src/verify/verify_utils.cc index 9aa6ba2..3bdc99f 100644 --- a/reference_model/src/verify/verify_utils.cc +++ b/reference_model/src/verify/verify_utils.cc @@ -202,7 +202,8 @@ static_assert(std::numeric_limits::is_iec559, "TOSA Reference Model has not been built with standard IEEE 754 64-bit float support; Bounds based " "verification is invalid"); -bool tosaCheckFloatBound(float testValue, double referenceValue, double errorBound) +template +bool tosaCheckFloatBound(OutType testValue, double referenceValue, double errorBound) { // Both must be NaNs to be correct if (std::isnan(referenceValue) || std::isnan(testValue)) @@ -236,8 +237,8 @@ bool tosaCheckFloatBound(float testValue, double referenceValue, double errorBou { // We already canonicalized the input such that the reference value is positive // so no need to check again here. - referenceMin = std::numeric_limits::infinity(); - referenceMax = std::numeric_limits::infinity(); + referenceMin = std::numeric_limits::infinity(); + referenceMax = std::numeric_limits::infinity(); } else if (referenceValue == 0) { @@ -253,23 +254,23 @@ bool tosaCheckFloatBound(float testValue, double referenceValue, double errorBou referenceMin = referenceValue - errorBound; // Handle the overflow cases. - if (referenceMax > AccPrecision::normal_max) + if (referenceMax > AccPrecision::normal_max) { - referenceMax = std::numeric_limits::infinity(); + referenceMax = std::numeric_limits::infinity(); } - if (referenceMin > AccPrecision::normal_max) + if (referenceMin > AccPrecision::normal_max) { - referenceMin = std::numeric_limits::infinity(); + referenceMin = std::numeric_limits::infinity(); } // And the underflow cases. - if (referenceMax < AccPrecision::normal_min) + if (referenceMax < AccPrecision::normal_min) { - referenceMax = AccPrecision::normal_min; + referenceMax = AccPrecision::normal_min; } - if (referenceMin < AccPrecision::normal_min) + if (referenceMin < AccPrecision::normal_min) { referenceMin = 0.0; } @@ -286,4 +287,8 @@ bool tosaCheckFloatBound(float testValue, double referenceValue, double errorBou } return withinBound; } + +// Instantiate the needed check functions +template bool tosaCheckFloatBound(float testValue, double referenceValue, double errorBound); +template bool tosaCheckFloatBound(half_float::half testValue, double referenceValue, double errorBound); } // namespace TosaReference diff --git a/reference_model/src/verify/verify_utils.h b/reference_model/src/verify/verify_utils.h index a58950e..45daeac 100644 --- a/reference_model/src/verify/verify_utils.h +++ b/reference_model/src/verify/verify_utils.h @@ -17,6 +17,7 @@ #define VERIFY_UTILS_H_ #include "dtype.h" +#include "half.hpp" #include "types.h" #include @@ -135,10 +136,17 @@ struct AccPrecision static constexpr double normal_max = const_exp2(128) - const_exp2(127 - 23); static constexpr int32_t normal_frac = 23; }; +template <> +struct AccPrecision +{ + static constexpr double normal_min = const_exp2(-14); + static constexpr double normal_max = const_exp2(16) - const_exp2(15 - 10); + static constexpr int32_t normal_frac = 7; +}; /// \brief Error bounds check for ULP and ABS_ERROR modes -bool tosaCheckFloatBound(float testValue, double referenceValue, double errorBound); - +template +bool tosaCheckFloatBound(OutType testValue, double referenceValue, double errorBound); }; // namespace TosaReference #endif // VERIFY_UTILS_H_ -- cgit v1.2.1