aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/verify/verify_dot_product.cc
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2023-10-18 17:22:21 +0100
committerEric Kunze <eric.kunze@arm.com>2023-11-02 23:22:09 +0000
commitd1a08ce27ef8d0f6cf77e1b864610aade06edc5c (patch)
tree777992f45d240361f898b1d21902c2a46c58235f /reference_model/src/verify/verify_dot_product.cc
parentb0b9e33c3500bd8dc9b12ef012d4234b1245247a (diff)
downloadreference_model-d1a08ce27ef8d0f6cf77e1b864610aade06edc5c.tar.gz
Compliance mode testing for CONV2D
Added CONV2D data generation. Updated verify dot product check to latest specification. Updated test generator and python datagenerator library to create const files during test generation. Add support for compliance test sets to conformance test_select. Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: I5be3b761a1e3ef259c058e493877cd5a89d5778b
Diffstat (limited to 'reference_model/src/verify/verify_dot_product.cc')
-rw-r--r--reference_model/src/verify/verify_dot_product.cc52
1 files changed, 27 insertions, 25 deletions
diff --git a/reference_model/src/verify/verify_dot_product.cc b/reference_model/src/verify/verify_dot_product.cc
index 2a1d273..233c072 100644
--- a/reference_model/src/verify/verify_dot_product.cc
+++ b/reference_model/src/verify/verify_dot_product.cc
@@ -14,6 +14,7 @@
#include "func_debug.h"
#include "verifiers.h"
+#include "verify_utils.h"
#include <cmath>
#include <numeric>
@@ -24,22 +25,9 @@ namespace TosaReference
{
namespace
{
-
-// Accumulator precision
-template <typename T>
-struct AccPrecision;
-#define two_m42 1.0 / (double)(((int64_t)1) << 42) // 2^-42
-template <>
-struct AccPrecision<float>
-{
- static constexpr double precision = (double)(1 << 24);
- static constexpr double min_normal = two_m42 * two_m42 * two_m42; // 2^-126
-};
-#undef two_m42
-
// Generic element validation function
template <typename AccType, typename std::enable_if_t<std::is_floating_point_v<AccType>, int> = 0>
-std::optional<double> validateElement(double ref, double bnd, AccType imp, size_t KS)
+std::optional<double> validateElement(size_t index, double ref, double bnd, AccType imp, size_t KS)
{
double err = 0.0;
bool is_valid = true;
@@ -47,7 +35,11 @@ std::optional<double> validateElement(double ref, double bnd, AccType imp, size_
if (bnd == 0.0)
{
is_valid = (ref == 0.0) && (imp == 0.0);
- err = 0.0;
+ if (!is_valid)
+ {
+ WARNING("[Verifier][DP] index %d - bound is zero, but ref (%g) or imp (%f) is not.", index, ref, imp);
+ }
+ err = 0.0;
}
else if (std::isinf(static_cast<AccType>(bnd)))
{
@@ -58,11 +50,15 @@ std::optional<double> validateElement(double ref, double bnd, AccType imp, size_
else
{
// 0.0 < bnd < infinity
- const double bnd_norm = std::max(bnd, AccPrecision<AccType>::min_normal);
- const double imp_fp64 = static_cast<double>(imp);
- const double acc_prec_fp64 = AccPrecision<AccType>::precision;
- err = (imp_fp64 - ref) * acc_prec_fp64 / bnd_norm;
- is_valid = std::abs(err) <= KS;
+ const double out_err_bnd =
+ std::max(bnd * exp2(-1 - AccPrecision<AccType>::normal_frac), AccPrecision<AccType>::normal_min);
+ const double imp_fp64 = static_cast<double>(imp);
+ err = (imp_fp64 - ref) / out_err_bnd;
+ is_valid = std::abs(err) <= KS;
+ if (!is_valid)
+ {
+ WARNING("[Verifier][DP] index %d - out_err (%g) is not within KS (%d).", index, err, KS);
+ }
}
return is_valid ? std::optional(err) : std::nullopt;
@@ -73,7 +69,8 @@ template <typename AccType, typename std::enable_if_t<std::is_floating_point_v<A
bool validateData(const double* ref, const double* bnd, const AccType* imp, size_t T, const DotProductVerifyInfo& cfg)
{
const int32_t S = cfg.s;
- // TODO - needed for other ops - (max_value(bias_abs) > 0) ? (KS + 1) : KS
+ // NOTE: KS in the compliance config MUST have already been updated to (KS + 1) if the bias
+ // tensor is non-zero
const int32_t KS = cfg.ks;
double out_err_sum = 0.0;
@@ -81,7 +78,7 @@ bool validateData(const double* ref, const double* bnd, const AccType* imp, size
for (size_t i = 0; i < T; ++i)
{
- auto out_err = validateElement<AccType>(ref[i], bnd[i], imp[i], KS);
+ auto out_err = validateElement<AccType>(i, ref[i], bnd[i], imp[i], KS);
TOSA_REF_REQUIRE(out_err, "[DP] Data required to be zero or error within range");
out_err_sum += out_err.value();
out_err_sumsq += out_err.value() * out_err.value();
@@ -89,11 +86,16 @@ bool validateData(const double* ref, const double* bnd, const AccType* imp, size
if (S >= 3 && S <= 5)
{
+ const double max_bias = 2 * sqrt(KS * T);
+ out_err_sum = std::abs(out_err_sum);
// Check error bias magnitude for data sets S which are not positive biased
- TOSA_REF_REQUIRE(std::abs(out_err_sum) <= 2 * sqrt(KS * T), "[DP] Bias magnitude is out of range");
+ TOSA_REF_REQUIRE(out_err_sum <= max_bias, "[DP] Bias magnitude (%g) is out of range (%g)", out_err_sum,
+ max_bias);
}
// Check error variance magnitude
- TOSA_REF_REQUIRE(out_err_sumsq <= 0.4 * KS * T, "[DP] Error variance magnitude is out of range");
+ const double max_error = 0.4 * KS * T;
+ TOSA_REF_REQUIRE(out_err_sumsq <= max_error, "[DP] Error variance magnitude (%g) is out of range (%g)",
+ out_err_sumsq, max_error);
return true;
}
} // namespace
@@ -107,7 +109,7 @@ bool verifyDotProduct(const CTensor* ref, const CTensor* refBnd, const CTensor*
// Get number of dot-product elements
const int64_t T = numElements(std::vector<int32_t>(ref->shape, ref->shape + ref->num_dims));
- TOSA_REF_REQUIRE(T > 0, "invalid shape for reference tensor");
+ TOSA_REF_REQUIRE(T > 0, "[DP] Invalid shape for reference tensor");
const double* refData = reinterpret_cast<const double*>(ref->data);
const double* refBndData = reinterpret_cast<const double*>(refBnd->data);