aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/verify/verify_dot_product.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/verify/verify_dot_product.cc')
-rw-r--r--reference_model/src/verify/verify_dot_product.cc40
1 files changed, 26 insertions, 14 deletions
diff --git a/reference_model/src/verify/verify_dot_product.cc b/reference_model/src/verify/verify_dot_product.cc
index a036cba..ea50573 100644
--- a/reference_model/src/verify/verify_dot_product.cc
+++ b/reference_model/src/verify/verify_dot_product.cc
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, ARM Limited.
+// Copyright (c) 2023-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -16,6 +16,7 @@
#include "half.hpp"
#include "verifiers.h"
+#include <cfloat>
#include <cmath>
#include <numeric>
#include <optional>
@@ -43,7 +44,8 @@ std::optional<double> validateElement(size_t index, double ref, double bnd, AccT
is_valid = (ref == 0.0) && (imp == 0.0);
if (!is_valid)
{
- WARNING("[Verifier][DP] index %d - bound is zero, but ref (%g) or imp (%f) is not.", index, ref, imp);
+ WARNING("[Verifier][DP] index %d: bound is zero, but ref (%.*g) or imp (%.*g) is not.", index, DBL_DIG, ref,
+ FLT_DIG, imp);
}
err = 0.0;
}
@@ -57,7 +59,7 @@ std::optional<double> validateElement(size_t index, double ref, double bnd, AccT
is_valid = std::abs(err) <= KS;
if (!is_valid)
{
- WARNING("[Verifier][DP] index %d - out_err (%g) is not within KS (%d).", index, err, KS);
+ WARNING("[Verifier][DP] index %d: out_err (abs(%.*g)) is not within KS (%d).", index, DBL_DIG, err, KS);
}
}
@@ -66,8 +68,15 @@ std::optional<double> validateElement(size_t index, double ref, double bnd, AccT
// Generic data validation function
template <typename AccType>
-bool validateData(const double* ref, const double* bnd, const AccType* imp, size_t T, const DotProductVerifyInfo& cfg)
+bool validateData(const double* ref,
+ const double* bnd,
+ const AccType* imp,
+ const std::vector<int32_t>& shape,
+ const DotProductVerifyInfo& cfg)
{
+ const size_t T = static_cast<size_t>(numElements(shape));
+ TOSA_REF_REQUIRE(T > 0, "[DP] Invalid shape for reference tensor");
+
const int32_t S = cfg.s;
// NOTE: KS in the compliance config MUST have already been updated to (KS + 1) if the bias
// tensor is non-zero
@@ -79,7 +88,12 @@ bool validateData(const double* ref, const double* bnd, const AccType* imp, size
for (size_t i = 0; i < T; ++i)
{
auto out_err = validateElement<AccType>(i, ref[i], bnd[i], imp[i], KS);
- TOSA_REF_REQUIRE(out_err, "[DP] Data required to be zero or error within range");
+ if (!out_err)
+ {
+ auto pos = indexToPosition(i, shape);
+ TOSA_REF_REQUIRE(out_err, "[DP] Location %s: Data required to be zero or error within range",
+ positionToString(pos).c_str());
+ }
out_err_sum += out_err.value();
out_err_sumsq += out_err.value() * out_err.value();
}
@@ -88,13 +102,13 @@ bool validateData(const double* ref, const double* bnd, const AccType* imp, size
{
const double max_bias = 2 * sqrt(KS * T);
// Check error bias magnitude for data sets S which are not positive biased
- TOSA_REF_REQUIRE(std::abs(out_err_sum) <= max_bias, "[DP] Bias magnitude (abs(%g)) is out of range (%g)",
- out_err_sum, max_bias);
+ TOSA_REF_REQUIRE(std::abs(out_err_sum) <= max_bias, "[DP] Bias magnitude (abs(%.*g)) is out of range (%.*g)",
+ DBL_DIG, out_err_sum, DBL_DIG, max_bias);
}
// Check error variance magnitude
const double max_error = 0.4 * KS * T;
- TOSA_REF_REQUIRE(out_err_sumsq <= max_error, "[DP] Error variance magnitude (%g) is out of range (%g)",
- out_err_sumsq, max_error);
+ TOSA_REF_REQUIRE(out_err_sumsq <= max_error, "[DP] Error variance magnitude (%.*g) is out of range (%.*g)", DBL_DIG,
+ out_err_sumsq, DBL_DIG, max_error);
return true;
}
} // namespace
@@ -106,9 +120,7 @@ bool verifyDotProduct(const CTensor* ref, const CTensor* refBnd, const CTensor*
TOSA_REF_REQUIRE(refBnd != nullptr, "[DP] Reference bounds tensor is missing");
TOSA_REF_REQUIRE(imp != nullptr, "[DP] Implementation tensor is missing");
- // Get number of dot-product elements
- const int64_t T = numElements(std::vector<int32_t>(ref->shape, ref->shape + ref->num_dims));
- TOSA_REF_REQUIRE(T > 0, "[DP] Invalid shape for reference tensor");
+ const std::vector<int32_t> refShape(ref->shape, ref->shape + ref->num_dims);
const double* refData = reinterpret_cast<const double*>(ref->data);
const double* refBndData = reinterpret_cast<const double*>(refBnd->data);
@@ -119,13 +131,13 @@ bool verifyDotProduct(const CTensor* ref, const CTensor* refBnd, const CTensor*
case tosa_datatype_fp32_t: {
const float* impData = reinterpret_cast<const float*>(imp->data);
TOSA_REF_REQUIRE(impData != nullptr, "[DP] Missing data for implementation");
- return validateData(refData, refBndData, impData, static_cast<size_t>(T), dpInfo);
+ return validateData(refData, refBndData, impData, refShape, dpInfo);
break;
}
case tosa_datatype_fp16_t: {
const half_float::half* impData = reinterpret_cast<const half_float::half*>(imp->data);
TOSA_REF_REQUIRE(impData != nullptr, "[DP] Missing data for implementation");
- return validateData(refData, refBndData, impData, static_cast<size_t>(T), dpInfo);
+ return validateData(refData, refBndData, impData, refShape, dpInfo);
break;
}
default: {