From c8330811352f753e36f2ee7be4c7d0e6002f21e7 Mon Sep 17 00:00:00 2001 From: Jeremy Johnson Date: Thu, 18 Jan 2024 16:57:28 +0000 Subject: Main Compliance: FFT2D support Improve access to DOT_PRODUCT generator index and location for debugging. Enable multiple result files for compliance and improve output. Fix up precise and abs modes for FFT2D in ref model to produce correct results and bounds using abs weights. Signed-off-by: Jeremy Johnson Change-Id: Ide0c9f9f80397e5f1e07ca30a1036d6014b5784d --- reference_model/src/verify/verify_dot_product.cc | 40 +++++++++++++++--------- 1 file changed, 26 insertions(+), 14 deletions(-) (limited to 'reference_model/src/verify/verify_dot_product.cc') diff --git a/reference_model/src/verify/verify_dot_product.cc b/reference_model/src/verify/verify_dot_product.cc index a036cba..ea50573 100644 --- a/reference_model/src/verify/verify_dot_product.cc +++ b/reference_model/src/verify/verify_dot_product.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2023, ARM Limited. +// Copyright (c) 2023-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ #include "half.hpp" #include "verifiers.h" +#include #include #include #include @@ -43,7 +44,8 @@ std::optional validateElement(size_t index, double ref, double bnd, AccT is_valid = (ref == 0.0) && (imp == 0.0); if (!is_valid) { - WARNING("[Verifier][DP] index %d - bound is zero, but ref (%g) or imp (%f) is not.", index, ref, imp); + WARNING("[Verifier][DP] index %d: bound is zero, but ref (%.*g) or imp (%.*g) is not.", index, DBL_DIG, ref, + FLT_DIG, imp); } err = 0.0; } @@ -57,7 +59,7 @@ std::optional validateElement(size_t index, double ref, double bnd, AccT is_valid = std::abs(err) <= KS; if (!is_valid) { - WARNING("[Verifier][DP] index %d - out_err (%g) is not within KS (%d).", index, err, KS); + WARNING("[Verifier][DP] index %d: out_err (abs(%.*g)) is not within KS (%d).", index, DBL_DIG, err, KS); } } @@ -66,8 +68,15 @@ std::optional validateElement(size_t index, double ref, double bnd, AccT // Generic data validation function template -bool validateData(const double* ref, const double* bnd, const AccType* imp, size_t T, const DotProductVerifyInfo& cfg) +bool validateData(const double* ref, + const double* bnd, + const AccType* imp, + const std::vector& shape, + const DotProductVerifyInfo& cfg) { + const size_t T = static_cast(numElements(shape)); + TOSA_REF_REQUIRE(T > 0, "[DP] Invalid shape for reference tensor"); + const int32_t S = cfg.s; // NOTE: KS in the compliance config MUST have already been updated to (KS + 1) if the bias // tensor is non-zero @@ -79,7 +88,12 @@ bool validateData(const double* ref, const double* bnd, const AccType* imp, size for (size_t i = 0; i < T; ++i) { auto out_err = validateElement(i, ref[i], bnd[i], imp[i], KS); - TOSA_REF_REQUIRE(out_err, "[DP] Data required to be zero or error within range"); + if (!out_err) + { + auto pos = indexToPosition(i, shape); + TOSA_REF_REQUIRE(out_err, "[DP] Location %s: Data required to be zero or error within range", + positionToString(pos).c_str()); + } out_err_sum += out_err.value(); out_err_sumsq += out_err.value() * out_err.value(); } @@ -88,13 +102,13 @@ bool validateData(const double* ref, const double* bnd, const AccType* imp, size { const double max_bias = 2 * sqrt(KS * T); // Check error bias magnitude for data sets S which are not positive biased - TOSA_REF_REQUIRE(std::abs(out_err_sum) <= max_bias, "[DP] Bias magnitude (abs(%g)) is out of range (%g)", - out_err_sum, max_bias); + TOSA_REF_REQUIRE(std::abs(out_err_sum) <= max_bias, "[DP] Bias magnitude (abs(%.*g)) is out of range (%.*g)", + DBL_DIG, out_err_sum, DBL_DIG, max_bias); } // Check error variance magnitude const double max_error = 0.4 * KS * T; - TOSA_REF_REQUIRE(out_err_sumsq <= max_error, "[DP] Error variance magnitude (%g) is out of range (%g)", - out_err_sumsq, max_error); + TOSA_REF_REQUIRE(out_err_sumsq <= max_error, "[DP] Error variance magnitude (%.*g) is out of range (%.*g)", DBL_DIG, + out_err_sumsq, DBL_DIG, max_error); return true; } } // namespace @@ -106,9 +120,7 @@ bool verifyDotProduct(const CTensor* ref, const CTensor* refBnd, const CTensor* TOSA_REF_REQUIRE(refBnd != nullptr, "[DP] Reference bounds tensor is missing"); TOSA_REF_REQUIRE(imp != nullptr, "[DP] Implementation tensor is missing"); - // Get number of dot-product elements - const int64_t T = numElements(std::vector(ref->shape, ref->shape + ref->num_dims)); - TOSA_REF_REQUIRE(T > 0, "[DP] Invalid shape for reference tensor"); + const std::vector refShape(ref->shape, ref->shape + ref->num_dims); const double* refData = reinterpret_cast(ref->data); const double* refBndData = reinterpret_cast(refBnd->data); @@ -119,13 +131,13 @@ bool verifyDotProduct(const CTensor* ref, const CTensor* refBnd, const CTensor* case tosa_datatype_fp32_t: { const float* impData = reinterpret_cast(imp->data); TOSA_REF_REQUIRE(impData != nullptr, "[DP] Missing data for implementation"); - return validateData(refData, refBndData, impData, static_cast(T), dpInfo); + return validateData(refData, refBndData, impData, refShape, dpInfo); break; } case tosa_datatype_fp16_t: { const half_float::half* impData = reinterpret_cast(imp->data); TOSA_REF_REQUIRE(impData != nullptr, "[DP] Missing data for implementation"); - return validateData(refData, refBndData, impData, static_cast(T), dpInfo); + return validateData(refData, refBndData, impData, refShape, dpInfo); break; } default: { -- cgit v1.2.1