aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/verify/verify_reduce_product.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/verify/verify_reduce_product.cc')
-rw-r--r--reference_model/src/verify/verify_reduce_product.cc33
1 files changed, 10 insertions, 23 deletions
diff --git a/reference_model/src/verify/verify_reduce_product.cc b/reference_model/src/verify/verify_reduce_product.cc
index 0e58892..a8aaa53 100644
--- a/reference_model/src/verify/verify_reduce_product.cc
+++ b/reference_model/src/verify/verify_reduce_product.cc
@@ -21,31 +21,15 @@
namespace TosaReference
{
-
namespace
{
-template <typename OutDtype>
-bool validateData(const double* ref,
- const OutDtype* imp,
- const std::vector<int32_t>& shape,
- const ReduceProductVerifyInfo& cfg)
+template <typename OutType>
+double calcErrorBound(double referenceValue, double boundsValue, const void* cfgPtr)
{
- const size_t T = static_cast<size_t>(numElements(shape));
- TOSA_REF_REQUIRE(T > 0, "[RP] Invalid shape for reference tensor");
+ const auto cfg = reinterpret_cast<const ReduceProductVerifyInfo*>(cfgPtr);
+ unused(boundsValue);
- for (size_t i = 0; i < T; ++i)
- {
- double errBound =
- std::abs(ref[i]) * (std::pow(1 + std::pow(2, -AccPrecision<OutDtype>::normal_frac - 1), cfg.n) - 1);
- bool valid = tosaCheckFloatBound(imp[i], ref[i], errBound);
- if (!valid)
- {
- auto pos = indexToPosition(i, shape);
- WARNING("[Verifier][RP] Location %s", positionToString(pos).c_str());
- return false;
- }
- }
- return true;
+ return std::abs(referenceValue) * (std::pow(1 + std::pow(2, -AccPrecision<OutType>::normal_frac - 1), cfg->n) - 1);
}
} // namespace
@@ -62,17 +46,20 @@ bool verifyReduceProduct(const CTensor* referenceTensor,
const double* refData = reinterpret_cast<const double*>(referenceTensor->data);
TOSA_REF_REQUIRE(refData != nullptr, "[RP] Missing data for reference");
+ const std::string modeStr = "RP";
+
switch (implementationTensor->data_type)
{
case tosa_datatype_fp32_t: {
const auto* impData = reinterpret_cast<const float*>(implementationTensor->data);
TOSA_REF_REQUIRE(impData != nullptr, "[RP] Missing data for implementation");
- return validateData(refData, impData, refShape, rpInfo);
+ return validateData(refData, nullptr, impData, refShape, modeStr, &rpInfo, &calcErrorBound<float>);
}
case tosa_datatype_fp16_t: {
const auto* impData = reinterpret_cast<const half_float::half*>(implementationTensor->data);
TOSA_REF_REQUIRE(impData != nullptr, "[RP] Missing data for implementation");
- return validateData(refData, impData, refShape, rpInfo);
+ return validateData(refData, nullptr, impData, refShape, modeStr, &rpInfo,
+ &calcErrorBound<half_float::half>);
}
default:
WARNING("[Verifier][RP] Data-type not supported.");