aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/verify/verify_dot_product.cc
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2024-01-18 16:57:28 +0000
committerJeremy Johnson <jeremy.johnson@arm.com>2024-02-07 10:57:40 +0000
commitc8330811352f753e36f2ee7be4c7d0e6002f21e7 (patch)
tree967eeb59876e7c6abea26ff2e892d5ff94134992 /reference_model/src/verify/verify_dot_product.cc
parent9847722e2b172b69fe9ae80a05c27ca5c8c36617 (diff)
downloadreference_model-c8330811352f753e36f2ee7be4c7d0e6002f21e7.tar.gz
Main Compliance: FFT2D support
Improve access to DOT_PRODUCT generator index and location for debugging. Enable multiple result files for compliance and improve output. Fix up precise and abs modes for FFT2D in ref model to produce correct results and bounds using abs weights. Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: Ide0c9f9f80397e5f1e07ca30a1036d6014b5784d
Diffstat (limited to 'reference_model/src/verify/verify_dot_product.cc')
-rw-r--r--reference_model/src/verify/verify_dot_product.cc40
1 files changed, 26 insertions, 14 deletions
diff --git a/reference_model/src/verify/verify_dot_product.cc b/reference_model/src/verify/verify_dot_product.cc
index a036cba..ea50573 100644
--- a/reference_model/src/verify/verify_dot_product.cc
+++ b/reference_model/src/verify/verify_dot_product.cc
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, ARM Limited.
+// Copyright (c) 2023-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -16,6 +16,7 @@
#include "half.hpp"
#include "verifiers.h"
+#include <cfloat>
#include <cmath>
#include <numeric>
#include <optional>
@@ -43,7 +44,8 @@ std::optional<double> validateElement(size_t index, double ref, double bnd, AccT
is_valid = (ref == 0.0) && (imp == 0.0);
if (!is_valid)
{
- WARNING("[Verifier][DP] index %d - bound is zero, but ref (%g) or imp (%f) is not.", index, ref, imp);
+ WARNING("[Verifier][DP] index %d: bound is zero, but ref (%.*g) or imp (%.*g) is not.", index, DBL_DIG, ref,
+ FLT_DIG, imp);
}
err = 0.0;
}
@@ -57,7 +59,7 @@ std::optional<double> validateElement(size_t index, double ref, double bnd, AccT
is_valid = std::abs(err) <= KS;
if (!is_valid)
{
- WARNING("[Verifier][DP] index %d - out_err (%g) is not within KS (%d).", index, err, KS);
+ WARNING("[Verifier][DP] index %d: out_err (abs(%.*g)) is not within KS (%d).", index, DBL_DIG, err, KS);
}
}
@@ -66,8 +68,15 @@ std::optional<double> validateElement(size_t index, double ref, double bnd, AccT
// Generic data validation function
template <typename AccType>
-bool validateData(const double* ref, const double* bnd, const AccType* imp, size_t T, const DotProductVerifyInfo& cfg)
+bool validateData(const double* ref,
+ const double* bnd,
+ const AccType* imp,
+ const std::vector<int32_t>& shape,
+ const DotProductVerifyInfo& cfg)
{
+ const size_t T = static_cast<size_t>(numElements(shape));
+ TOSA_REF_REQUIRE(T > 0, "[DP] Invalid shape for reference tensor");
+
const int32_t S = cfg.s;
// NOTE: KS in the compliance config MUST have already been updated to (KS + 1) if the bias
// tensor is non-zero
@@ -79,7 +88,12 @@ bool validateData(const double* ref, const double* bnd, const AccType* imp, size
for (size_t i = 0; i < T; ++i)
{
auto out_err = validateElement<AccType>(i, ref[i], bnd[i], imp[i], KS);
- TOSA_REF_REQUIRE(out_err, "[DP] Data required to be zero or error within range");
+ if (!out_err)
+ {
+ auto pos = indexToPosition(i, shape);
+ TOSA_REF_REQUIRE(out_err, "[DP] Location %s: Data required to be zero or error within range",
+ positionToString(pos).c_str());
+ }
out_err_sum += out_err.value();
out_err_sumsq += out_err.value() * out_err.value();
}
@@ -88,13 +102,13 @@ bool validateData(const double* ref, const double* bnd, const AccType* imp, size
{
const double max_bias = 2 * sqrt(KS * T);
// Check error bias magnitude for data sets S which are not positive biased
- TOSA_REF_REQUIRE(std::abs(out_err_sum) <= max_bias, "[DP] Bias magnitude (abs(%g)) is out of range (%g)",
- out_err_sum, max_bias);
+ TOSA_REF_REQUIRE(std::abs(out_err_sum) <= max_bias, "[DP] Bias magnitude (abs(%.*g)) is out of range (%.*g)",
+ DBL_DIG, out_err_sum, DBL_DIG, max_bias);
}
// Check error variance magnitude
const double max_error = 0.4 * KS * T;
- TOSA_REF_REQUIRE(out_err_sumsq <= max_error, "[DP] Error variance magnitude (%g) is out of range (%g)",
- out_err_sumsq, max_error);
+ TOSA_REF_REQUIRE(out_err_sumsq <= max_error, "[DP] Error variance magnitude (%.*g) is out of range (%.*g)", DBL_DIG,
+ out_err_sumsq, DBL_DIG, max_error);
return true;
}
} // namespace
@@ -106,9 +120,7 @@ bool verifyDotProduct(const CTensor* ref, const CTensor* refBnd, const CTensor*
TOSA_REF_REQUIRE(refBnd != nullptr, "[DP] Reference bounds tensor is missing");
TOSA_REF_REQUIRE(imp != nullptr, "[DP] Implementation tensor is missing");
- // Get number of dot-product elements
- const int64_t T = numElements(std::vector<int32_t>(ref->shape, ref->shape + ref->num_dims));
- TOSA_REF_REQUIRE(T > 0, "[DP] Invalid shape for reference tensor");
+ const std::vector<int32_t> refShape(ref->shape, ref->shape + ref->num_dims);
const double* refData = reinterpret_cast<const double*>(ref->data);
const double* refBndData = reinterpret_cast<const double*>(refBnd->data);
@@ -119,13 +131,13 @@ bool verifyDotProduct(const CTensor* ref, const CTensor* refBnd, const CTensor*
case tosa_datatype_fp32_t: {
const float* impData = reinterpret_cast<const float*>(imp->data);
TOSA_REF_REQUIRE(impData != nullptr, "[DP] Missing data for implementation");
- return validateData(refData, refBndData, impData, static_cast<size_t>(T), dpInfo);
+ return validateData(refData, refBndData, impData, refShape, dpInfo);
break;
}
case tosa_datatype_fp16_t: {
const half_float::half* impData = reinterpret_cast<const half_float::half*>(imp->data);
TOSA_REF_REQUIRE(impData != nullptr, "[DP] Missing data for implementation");
- return validateData(refData, refBndData, impData, static_cast<size_t>(T), dpInfo);
+ return validateData(refData, refBndData, impData, refShape, dpInfo);
break;
}
default: {