diff options
Diffstat (limited to 'reference_model/src/verify/verify_dot_product.cc')
-rw-r--r-- | reference_model/src/verify/verify_dot_product.cc | 40 |
1 files changed, 26 insertions, 14 deletions
diff --git a/reference_model/src/verify/verify_dot_product.cc b/reference_model/src/verify/verify_dot_product.cc index a036cba..ea50573 100644 --- a/reference_model/src/verify/verify_dot_product.cc +++ b/reference_model/src/verify/verify_dot_product.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2023, ARM Limited. +// Copyright (c) 2023-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ #include "half.hpp" #include "verifiers.h" +#include <cfloat> #include <cmath> #include <numeric> #include <optional> @@ -43,7 +44,8 @@ std::optional<double> validateElement(size_t index, double ref, double bnd, AccT is_valid = (ref == 0.0) && (imp == 0.0); if (!is_valid) { - WARNING("[Verifier][DP] index %d - bound is zero, but ref (%g) or imp (%f) is not.", index, ref, imp); + WARNING("[Verifier][DP] index %d: bound is zero, but ref (%.*g) or imp (%.*g) is not.", index, DBL_DIG, ref, + FLT_DIG, imp); } err = 0.0; } @@ -57,7 +59,7 @@ std::optional<double> validateElement(size_t index, double ref, double bnd, AccT is_valid = std::abs(err) <= KS; if (!is_valid) { - WARNING("[Verifier][DP] index %d - out_err (%g) is not within KS (%d).", index, err, KS); + WARNING("[Verifier][DP] index %d: out_err (abs(%.*g)) is not within KS (%d).", index, DBL_DIG, err, KS); } } @@ -66,8 +68,15 @@ std::optional<double> validateElement(size_t index, double ref, double bnd, AccT // Generic data validation function template <typename AccType> -bool validateData(const double* ref, const double* bnd, const AccType* imp, size_t T, const DotProductVerifyInfo& cfg) +bool validateData(const double* ref, + const double* bnd, + const AccType* imp, + const std::vector<int32_t>& shape, + const DotProductVerifyInfo& cfg) { + const size_t T = static_cast<size_t>(numElements(shape)); + TOSA_REF_REQUIRE(T > 0, "[DP] Invalid shape for reference tensor"); + const int32_t S = cfg.s; // NOTE: KS in the compliance config MUST have already been updated to (KS + 1) if the bias // tensor is non-zero @@ -79,7 +88,12 @@ bool validateData(const double* ref, const double* bnd, const AccType* imp, size for (size_t i = 0; i < T; ++i) { auto out_err = validateElement<AccType>(i, ref[i], bnd[i], imp[i], KS); - TOSA_REF_REQUIRE(out_err, "[DP] Data required to be zero or error within range"); + if (!out_err) + { + auto pos = indexToPosition(i, shape); + TOSA_REF_REQUIRE(out_err, "[DP] Location %s: Data required to be zero or error within range", + positionToString(pos).c_str()); + } out_err_sum += out_err.value(); out_err_sumsq += out_err.value() * out_err.value(); } @@ -88,13 +102,13 @@ bool validateData(const double* ref, const double* bnd, const AccType* imp, size { const double max_bias = 2 * sqrt(KS * T); // Check error bias magnitude for data sets S which are not positive biased - TOSA_REF_REQUIRE(std::abs(out_err_sum) <= max_bias, "[DP] Bias magnitude (abs(%g)) is out of range (%g)", - out_err_sum, max_bias); + TOSA_REF_REQUIRE(std::abs(out_err_sum) <= max_bias, "[DP] Bias magnitude (abs(%.*g)) is out of range (%.*g)", + DBL_DIG, out_err_sum, DBL_DIG, max_bias); } // Check error variance magnitude const double max_error = 0.4 * KS * T; - TOSA_REF_REQUIRE(out_err_sumsq <= max_error, "[DP] Error variance magnitude (%g) is out of range (%g)", - out_err_sumsq, max_error); + TOSA_REF_REQUIRE(out_err_sumsq <= max_error, "[DP] Error variance magnitude (%.*g) is out of range (%.*g)", DBL_DIG, + out_err_sumsq, DBL_DIG, max_error); return true; } } // namespace @@ -106,9 +120,7 @@ bool verifyDotProduct(const CTensor* ref, const CTensor* refBnd, const CTensor* TOSA_REF_REQUIRE(refBnd != nullptr, "[DP] Reference bounds tensor is missing"); TOSA_REF_REQUIRE(imp != nullptr, "[DP] Implementation tensor is missing"); - // Get number of dot-product elements - const int64_t T = numElements(std::vector<int32_t>(ref->shape, ref->shape + ref->num_dims)); - TOSA_REF_REQUIRE(T > 0, "[DP] Invalid shape for reference tensor"); + const std::vector<int32_t> refShape(ref->shape, ref->shape + ref->num_dims); const double* refData = reinterpret_cast<const double*>(ref->data); const double* refBndData = reinterpret_cast<const double*>(refBnd->data); @@ -119,13 +131,13 @@ bool verifyDotProduct(const CTensor* ref, const CTensor* refBnd, const CTensor* case tosa_datatype_fp32_t: { const float* impData = reinterpret_cast<const float*>(imp->data); TOSA_REF_REQUIRE(impData != nullptr, "[DP] Missing data for implementation"); - return validateData(refData, refBndData, impData, static_cast<size_t>(T), dpInfo); + return validateData(refData, refBndData, impData, refShape, dpInfo); break; } case tosa_datatype_fp16_t: { const half_float::half* impData = reinterpret_cast<const half_float::half*>(imp->data); TOSA_REF_REQUIRE(impData != nullptr, "[DP] Missing data for implementation"); - return validateData(refData, refBndData, impData, static_cast<size_t>(T), dpInfo); + return validateData(refData, refBndData, impData, refShape, dpInfo); break; } default: { |