aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/verify/verify_dot_product.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/verify/verify_dot_product.cc')
-rw-r--r--reference_model/src/verify/verify_dot_product.cc25
1 files changed, 16 insertions, 9 deletions
diff --git a/reference_model/src/verify/verify_dot_product.cc b/reference_model/src/verify/verify_dot_product.cc
index 15de427..a036cba 100644
--- a/reference_model/src/verify/verify_dot_product.cc
+++ b/reference_model/src/verify/verify_dot_product.cc
@@ -13,6 +13,7 @@
// limitations under the License.
#include "func_debug.h"
+#include "half.hpp"
#include "verifiers.h"
#include <cmath>
@@ -25,13 +26,19 @@ namespace TosaReference
namespace
{
// Generic element validation function
-template <typename AccType, typename std::enable_if_t<std::is_floating_point_v<AccType>, int> = 0>
+template <typename AccType>
std::optional<double> validateElement(size_t index, double ref, double bnd, AccType imp, size_t KS)
{
double err = 0.0;
bool is_valid = true;
- if (bnd == 0.0)
+ if (std::isinf(static_cast<AccType>(bnd)))
+ {
+ // dot product can overflow and there is no accuracy limit
+ is_valid = true;
+ err = 0.0;
+ }
+ else if (bnd == 0.0)
{
is_valid = (ref == 0.0) && (imp == 0.0);
if (!is_valid)
@@ -40,12 +47,6 @@ std::optional<double> validateElement(size_t index, double ref, double bnd, AccT
}
err = 0.0;
}
- else if (std::isinf(static_cast<AccType>(bnd)))
- {
- // dot product can overflow and there is no accuracy limit
- is_valid = true;
- err = 0.0;
- }
else
{
// 0.0 < bnd < infinity
@@ -64,7 +65,7 @@ std::optional<double> validateElement(size_t index, double ref, double bnd, AccT
}
// Generic data validation function
-template <typename AccType, typename std::enable_if_t<std::is_floating_point_v<AccType>, int> = 0>
+template <typename AccType>
bool validateData(const double* ref, const double* bnd, const AccType* imp, size_t T, const DotProductVerifyInfo& cfg)
{
const int32_t S = cfg.s;
@@ -121,6 +122,12 @@ bool verifyDotProduct(const CTensor* ref, const CTensor* refBnd, const CTensor*
return validateData(refData, refBndData, impData, static_cast<size_t>(T), dpInfo);
break;
}
+ case tosa_datatype_fp16_t: {
+ const half_float::half* impData = reinterpret_cast<const half_float::half*>(imp->data);
+ TOSA_REF_REQUIRE(impData != nullptr, "[DP] Missing data for implementation");
+ return validateData(refData, refBndData, impData, static_cast<size_t>(T), dpInfo);
+ break;
+ }
default: {
WARNING("[Verifier][DP] Data-type not supported.");
break;