aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/verify
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2023-11-30 14:18:19 +0000
committerJeremy Johnson <jeremy.johnson@arm.com>2023-12-04 10:02:15 +0000
commit718f347a2d886381de19420b5b5b99db8f2b7338 (patch)
tree87f6ab932029654b4e0704938dbe6ab7135da27d /reference_model/src/verify
parentfe79accba2c220036c7b5ea0aa27bff5ef74ec73 (diff)
downloadreference_model-718f347a2d886381de19420b5b5b99db8f2b7338.tar.gz
Main Compliance FP16 support - generate and verify.
FP16 support for all existing operators for compliance: * DOT_PRODUCT * ULP * EXACT * ABS_ERROR Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: I8d25448a793375b53880da3787d8f839767f02cf
Diffstat (limited to 'reference_model/src/verify')
-rw-r--r--reference_model/src/verify/verify_abs_error.cc13
-rw-r--r--reference_model/src/verify/verify_dot_product.cc25
-rw-r--r--reference_model/src/verify/verify_exact.cc20
-rw-r--r--reference_model/src/verify/verify_ulp.cc36
-rw-r--r--reference_model/src/verify/verify_utils.cc25
-rw-r--r--reference_model/src/verify/verify_utils.h12
6 files changed, 95 insertions, 36 deletions
diff --git a/reference_model/src/verify/verify_abs_error.cc b/reference_model/src/verify/verify_abs_error.cc
index b43da08..5aaa0ad 100644
--- a/reference_model/src/verify/verify_abs_error.cc
+++ b/reference_model/src/verify/verify_abs_error.cc
@@ -18,6 +18,7 @@
#include <type_traits>
#include <utility>
+#include "half.hpp"
#include "verifiers.h"
namespace TosaReference
@@ -25,14 +26,15 @@ namespace TosaReference
namespace
{
-bool validateData(const double* ref, const double* bnd, const float* imp, const std::vector<int32_t>& shape)
+template <typename OutDtype>
+bool validateData(const double* ref, const double* bnd, const OutDtype* imp, const std::vector<int32_t>& shape)
{
const size_t T = static_cast<size_t>(numElements(shape));
TOSA_REF_REQUIRE(T > 0, "[AE] Invalid shape for reference tensor");
for (size_t i = 0; i < T; ++i)
{
- double errBound = std::abs(ref[i]) * exp2(-AccPrecision<float>::normal_frac) * bnd[i];
+ double errBound = std::abs(ref[i]) * exp2(-AccPrecision<OutDtype>::normal_frac) * bnd[i];
bool valid = tosaCheckFloatBound(imp[i], ref[i], errBound);
if (!valid)
{
@@ -60,7 +62,12 @@ bool verifyAbsError(const CTensor* ref, const CTensor* refBnd, const CTensor* im
switch (imp->data_type)
{
case tosa_datatype_fp32_t: {
- const float* impData = reinterpret_cast<const float*>(imp->data);
+ const auto* impData = reinterpret_cast<const float*>(imp->data);
+ TOSA_REF_REQUIRE(impData != nullptr, "[AE] Missing data for implementation");
+ return validateData(refData, refBndData, impData, refShape);
+ }
+ case tosa_datatype_fp16_t: {
+ const auto* impData = reinterpret_cast<const half_float::half*>(imp->data);
TOSA_REF_REQUIRE(impData != nullptr, "[AE] Missing data for implementation");
return validateData(refData, refBndData, impData, refShape);
}
diff --git a/reference_model/src/verify/verify_dot_product.cc b/reference_model/src/verify/verify_dot_product.cc
index 15de427..a036cba 100644
--- a/reference_model/src/verify/verify_dot_product.cc
+++ b/reference_model/src/verify/verify_dot_product.cc
@@ -13,6 +13,7 @@
// limitations under the License.
#include "func_debug.h"
+#include "half.hpp"
#include "verifiers.h"
#include <cmath>
@@ -25,13 +26,19 @@ namespace TosaReference
namespace
{
// Generic element validation function
-template <typename AccType, typename std::enable_if_t<std::is_floating_point_v<AccType>, int> = 0>
+template <typename AccType>
std::optional<double> validateElement(size_t index, double ref, double bnd, AccType imp, size_t KS)
{
double err = 0.0;
bool is_valid = true;
- if (bnd == 0.0)
+ if (std::isinf(static_cast<AccType>(bnd)))
+ {
+ // dot product can overflow and there is no accuracy limit
+ is_valid = true;
+ err = 0.0;
+ }
+ else if (bnd == 0.0)
{
is_valid = (ref == 0.0) && (imp == 0.0);
if (!is_valid)
@@ -40,12 +47,6 @@ std::optional<double> validateElement(size_t index, double ref, double bnd, AccT
}
err = 0.0;
}
- else if (std::isinf(static_cast<AccType>(bnd)))
- {
- // dot product can overflow and there is no accuracy limit
- is_valid = true;
- err = 0.0;
- }
else
{
// 0.0 < bnd < infinity
@@ -64,7 +65,7 @@ std::optional<double> validateElement(size_t index, double ref, double bnd, AccT
}
// Generic data validation function
-template <typename AccType, typename std::enable_if_t<std::is_floating_point_v<AccType>, int> = 0>
+template <typename AccType>
bool validateData(const double* ref, const double* bnd, const AccType* imp, size_t T, const DotProductVerifyInfo& cfg)
{
const int32_t S = cfg.s;
@@ -121,6 +122,12 @@ bool verifyDotProduct(const CTensor* ref, const CTensor* refBnd, const CTensor*
return validateData(refData, refBndData, impData, static_cast<size_t>(T), dpInfo);
break;
}
+ case tosa_datatype_fp16_t: {
+ const half_float::half* impData = reinterpret_cast<const half_float::half*>(imp->data);
+ TOSA_REF_REQUIRE(impData != nullptr, "[DP] Missing data for implementation");
+ return validateData(refData, refBndData, impData, static_cast<size_t>(T), dpInfo);
+ break;
+ }
default: {
WARNING("[Verifier][DP] Data-type not supported.");
break;
diff --git a/reference_model/src/verify/verify_exact.cc b/reference_model/src/verify/verify_exact.cc
index 36b4ec9..971df9c 100644
--- a/reference_model/src/verify/verify_exact.cc
+++ b/reference_model/src/verify/verify_exact.cc
@@ -13,12 +13,14 @@
// limitations under the License.
#include "func_debug.h"
+#include "half.hpp"
#include "verifiers.h"
#include <cmath>
namespace
{
-bool exact_fp32(const double& referenceValue, const float& implementationValue)
+template <typename OutDtype>
+bool exact_fp(const double& referenceValue, const OutDtype& implementationValue)
{
return std::isnan(referenceValue) ? std::isnan(implementationValue) : (referenceValue == implementationValue);
}
@@ -38,16 +40,24 @@ bool verifyExact(const CTensor* referenceTensor, const CTensor* implementationTe
numElements(std::vector<int32_t>(referenceTensor->shape, referenceTensor->shape + referenceTensor->num_dims));
TOSA_REF_REQUIRE(elementCount > 0, "[E] Invalid shape for reference tensor");
+ TOSA_REF_REQUIRE(referenceTensor->data_type == tosa_datatype_fp64_t, "[E] Reference tensor is not fp64");
+ const auto* refData = reinterpret_cast<const double*>(referenceTensor->data);
+ TOSA_REF_REQUIRE(refData != nullptr, "[E] Missing data for reference");
+
switch (implementationTensor->data_type)
{
case tosa_datatype_fp32_t: {
- TOSA_REF_REQUIRE(referenceTensor->data_type == tosa_datatype_fp64_t, "[E] Reference tensor is not fp64");
- const auto* refData = reinterpret_cast<const double*>(referenceTensor->data);
- TOSA_REF_REQUIRE(refData != nullptr, "[E] Missing data for reference");
const auto* impData = reinterpret_cast<const float*>(implementationTensor->data);
TOSA_REF_REQUIRE(impData != nullptr, "[E] Missing data for implementation");
auto result = std::equal(refData, std::next(refData, elementCount), impData,
- std::next(impData, elementCount), exact_fp32);
+ std::next(impData, elementCount), exact_fp<float>);
+ return result;
+ }
+ case tosa_datatype_fp16_t: {
+ const auto* impData = reinterpret_cast<const half_float::half*>(implementationTensor->data);
+ TOSA_REF_REQUIRE(impData != nullptr, "[E] Missing data for implementation");
+ auto result = std::equal(refData, std::next(refData, elementCount), impData,
+ std::next(impData, elementCount), exact_fp<half_float::half>);
return result;
}
default:
diff --git a/reference_model/src/verify/verify_ulp.cc b/reference_model/src/verify/verify_ulp.cc
index 6e78b96..1b38fe6 100644
--- a/reference_model/src/verify/verify_ulp.cc
+++ b/reference_model/src/verify/verify_ulp.cc
@@ -18,6 +18,7 @@
#include <type_traits>
#include <utility>
+#include "half.hpp"
#include "verifiers.h"
namespace TosaReference
@@ -25,7 +26,8 @@ namespace TosaReference
namespace
{
-bool tosaCheckULP(float testValue, double referenceValue, double ulpNum)
+template <typename OutType>
+bool tosaCheckULP(OutType testValue, double referenceValue, double ulpNum)
{
double errorBound = 0.0;
if (std::isfinite(referenceValue) && std::abs(referenceValue) != 0.0)
@@ -35,10 +37,10 @@ bool tosaCheckULP(float testValue, double referenceValue, double ulpNum)
// Work out the values magnitude - by raising 2 to the power of the
// exponent and taking the normalized minimum for denormal values
- const double referencePower2 = std::max(exp2(referenceExponent), AccPrecision<float>::normal_min);
+ const double referencePower2 = std::max(exp2(referenceExponent), AccPrecision<OutType>::normal_min);
// Get the value of changing the last bit - by shifting the least significant bit to this magnitude
// i.e. the ULP.
- double ulpValue = referencePower2 * exp2(-AccPrecision<float>::normal_frac);
+ double ulpValue = referencePower2 * exp2(-AccPrecision<OutType>::normal_frac);
errorBound = ulpValue * ulpNum;
}
@@ -57,15 +59,35 @@ bool verifyULP(const CTensor* referenceTensor, const CTensor* implementationTens
const auto elementCount = numElements(refShape);
TOSA_REF_REQUIRE(elementCount > 0, "[ULP] Invalid shape for reference tensor");
- const double ulp = ulpInfo.ulp;
+ const double ulp = ulpInfo.ulp;
+ const auto* refData = reinterpret_cast<const double*>(referenceTensor->data);
+ TOSA_REF_REQUIRE(refData != nullptr, "[ULP] Missing data for reference");
+ const auto* refDataEnd = std::next(refData, elementCount);
switch (implementationTensor->data_type)
{
case tosa_datatype_fp32_t: {
- const auto* refData = reinterpret_cast<const double*>(referenceTensor->data);
- TOSA_REF_REQUIRE(refData != nullptr, "[ULP] Missing data for reference");
const auto* impData = reinterpret_cast<const float*>(implementationTensor->data);
TOSA_REF_REQUIRE(impData != nullptr, "[ULP] Missing data for implementation");
- const auto* refDataEnd = std::next(refData, elementCount);
+ // Use mismatch to get the location of the first unequal value
+ auto pair = std::mismatch(refData, refDataEnd, impData, std::next(impData, elementCount),
+ [ulp](const auto& referenceValue, const auto& implementationValue) {
+ return tosaCheckULP(implementationValue, referenceValue, ulp);
+ });
+ if (std::get<0>(pair) == refDataEnd)
+ {
+ // No mismatch found
+ return true;
+ }
+ else
+ {
+ auto pos = indexToPosition(std::get<0>(pair) - refData, refShape);
+ WARNING("[Verfier][ULP] Location %s", positionToString(pos).c_str());
+ return false;
+ }
+ }
+ case tosa_datatype_fp16_t: {
+ const auto* impData = reinterpret_cast<const half_float::half*>(implementationTensor->data);
+ TOSA_REF_REQUIRE(impData != nullptr, "[ULP] Missing data for implementation");
// Use mismatch to get the location of the first unequal value
auto pair = std::mismatch(refData, refDataEnd, impData, std::next(impData, elementCount),
[ulp](const auto& referenceValue, const auto& implementationValue) {
diff --git a/reference_model/src/verify/verify_utils.cc b/reference_model/src/verify/verify_utils.cc
index 9aa6ba2..3bdc99f 100644
--- a/reference_model/src/verify/verify_utils.cc
+++ b/reference_model/src/verify/verify_utils.cc
@@ -202,7 +202,8 @@ static_assert(std::numeric_limits<double>::is_iec559,
"TOSA Reference Model has not been built with standard IEEE 754 64-bit float support; Bounds based "
"verification is invalid");
-bool tosaCheckFloatBound(float testValue, double referenceValue, double errorBound)
+template <typename OutType>
+bool tosaCheckFloatBound(OutType testValue, double referenceValue, double errorBound)
{
// Both must be NaNs to be correct
if (std::isnan(referenceValue) || std::isnan(testValue))
@@ -236,8 +237,8 @@ bool tosaCheckFloatBound(float testValue, double referenceValue, double errorBou
{
// We already canonicalized the input such that the reference value is positive
// so no need to check again here.
- referenceMin = std::numeric_limits<float>::infinity();
- referenceMax = std::numeric_limits<float>::infinity();
+ referenceMin = std::numeric_limits<OutType>::infinity();
+ referenceMax = std::numeric_limits<OutType>::infinity();
}
else if (referenceValue == 0)
{
@@ -253,23 +254,23 @@ bool tosaCheckFloatBound(float testValue, double referenceValue, double errorBou
referenceMin = referenceValue - errorBound;
// Handle the overflow cases.
- if (referenceMax > AccPrecision<float>::normal_max)
+ if (referenceMax > AccPrecision<OutType>::normal_max)
{
- referenceMax = std::numeric_limits<float>::infinity();
+ referenceMax = std::numeric_limits<OutType>::infinity();
}
- if (referenceMin > AccPrecision<float>::normal_max)
+ if (referenceMin > AccPrecision<OutType>::normal_max)
{
- referenceMin = std::numeric_limits<float>::infinity();
+ referenceMin = std::numeric_limits<OutType>::infinity();
}
// And the underflow cases.
- if (referenceMax < AccPrecision<float>::normal_min)
+ if (referenceMax < AccPrecision<OutType>::normal_min)
{
- referenceMax = AccPrecision<float>::normal_min;
+ referenceMax = AccPrecision<OutType>::normal_min;
}
- if (referenceMin < AccPrecision<float>::normal_min)
+ if (referenceMin < AccPrecision<OutType>::normal_min)
{
referenceMin = 0.0;
}
@@ -286,4 +287,8 @@ bool tosaCheckFloatBound(float testValue, double referenceValue, double errorBou
}
return withinBound;
}
+
+// Instantiate the needed check functions
+template bool tosaCheckFloatBound(float testValue, double referenceValue, double errorBound);
+template bool tosaCheckFloatBound(half_float::half testValue, double referenceValue, double errorBound);
} // namespace TosaReference
diff --git a/reference_model/src/verify/verify_utils.h b/reference_model/src/verify/verify_utils.h
index a58950e..45daeac 100644
--- a/reference_model/src/verify/verify_utils.h
+++ b/reference_model/src/verify/verify_utils.h
@@ -17,6 +17,7 @@
#define VERIFY_UTILS_H_
#include "dtype.h"
+#include "half.hpp"
#include "types.h"
#include <cstdint>
@@ -135,10 +136,17 @@ struct AccPrecision<float>
static constexpr double normal_max = const_exp2(128) - const_exp2(127 - 23);
static constexpr int32_t normal_frac = 23;
};
+template <>
+struct AccPrecision<half_float::half>
+{
+ static constexpr double normal_min = const_exp2(-14);
+ static constexpr double normal_max = const_exp2(16) - const_exp2(15 - 10);
+ static constexpr int32_t normal_frac = 7;
+};
/// \brief Error bounds check for ULP and ABS_ERROR modes
-bool tosaCheckFloatBound(float testValue, double referenceValue, double errorBound);
-
+template <typename OutType>
+bool tosaCheckFloatBound(OutType testValue, double referenceValue, double errorBound);
}; // namespace TosaReference
#endif // VERIFY_UTILS_H_