aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/verify/verify_dot_product.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/verify/verify_dot_product.cc')
-rw-r--r--reference_model/src/verify/verify_dot_product.cc41
1 files changed, 22 insertions, 19 deletions
diff --git a/reference_model/src/verify/verify_dot_product.cc b/reference_model/src/verify/verify_dot_product.cc
index ea50573..3f82c1e 100644
--- a/reference_model/src/verify/verify_dot_product.cc
+++ b/reference_model/src/verify/verify_dot_product.cc
@@ -66,13 +66,13 @@ std::optional<double> validateElement(size_t index, double ref, double bnd, AccT
return is_valid ? std::optional(err) : std::nullopt;
}
-// Generic data validation function
+// Dot Product data validation function
template <typename AccType>
-bool validateData(const double* ref,
- const double* bnd,
- const AccType* imp,
- const std::vector<int32_t>& shape,
- const DotProductVerifyInfo& cfg)
+bool validateDataDP(const double* referenceData,
+ const double* boundsData,
+ const AccType* implementationData,
+ const std::vector<int32_t>& shape,
+ const DotProductVerifyInfo& cfg)
{
const size_t T = static_cast<size_t>(numElements(shape));
TOSA_REF_REQUIRE(T > 0, "[DP] Invalid shape for reference tensor");
@@ -87,7 +87,7 @@ bool validateData(const double* ref,
for (size_t i = 0; i < T; ++i)
{
- auto out_err = validateElement<AccType>(i, ref[i], bnd[i], imp[i], KS);
+ auto out_err = validateElement<AccType>(i, referenceData[i], boundsData[i], implementationData[i], KS);
if (!out_err)
{
auto pos = indexToPosition(i, shape);
@@ -113,31 +113,34 @@ bool validateData(const double* ref,
}
} // namespace
-bool verifyDotProduct(const CTensor* ref, const CTensor* refBnd, const CTensor* imp, const DotProductVerifyInfo& dpInfo)
+bool verifyDotProduct(const CTensor* referenceTensor,
+ const CTensor* boundsTensor,
+ const CTensor* implementationTensor,
+ const DotProductVerifyInfo& dpInfo)
{
// Validate that tensors are provided
- TOSA_REF_REQUIRE(ref != nullptr, "[DP] Reference tensor is missing");
- TOSA_REF_REQUIRE(refBnd != nullptr, "[DP] Reference bounds tensor is missing");
- TOSA_REF_REQUIRE(imp != nullptr, "[DP] Implementation tensor is missing");
+ TOSA_REF_REQUIRE(referenceTensor != nullptr, "[DP] Reference tensor is missing");
+ TOSA_REF_REQUIRE(boundsTensor != nullptr, "[DP] Reference bounds tensor is missing");
+ TOSA_REF_REQUIRE(implementationTensor != nullptr, "[DP] Implementation tensor is missing");
- const std::vector<int32_t> refShape(ref->shape, ref->shape + ref->num_dims);
+ const std::vector<int32_t> refShape(referenceTensor->shape, referenceTensor->shape + referenceTensor->num_dims);
- const double* refData = reinterpret_cast<const double*>(ref->data);
- const double* refBndData = reinterpret_cast<const double*>(refBnd->data);
+ const double* refData = reinterpret_cast<const double*>(referenceTensor->data);
+ const double* refBndData = reinterpret_cast<const double*>(boundsTensor->data);
TOSA_REF_REQUIRE(refData != nullptr && refBndData != nullptr, "[DP] Missing data for reference or bounds tensors");
- switch (imp->data_type)
+ switch (implementationTensor->data_type)
{
case tosa_datatype_fp32_t: {
- const float* impData = reinterpret_cast<const float*>(imp->data);
+ const float* impData = reinterpret_cast<const float*>(implementationTensor->data);
TOSA_REF_REQUIRE(impData != nullptr, "[DP] Missing data for implementation");
- return validateData(refData, refBndData, impData, refShape, dpInfo);
+ return validateDataDP(refData, refBndData, impData, refShape, dpInfo);
break;
}
case tosa_datatype_fp16_t: {
- const half_float::half* impData = reinterpret_cast<const half_float::half*>(imp->data);
+ const half_float::half* impData = reinterpret_cast<const half_float::half*>(implementationTensor->data);
TOSA_REF_REQUIRE(impData != nullptr, "[DP] Missing data for implementation");
- return validateData(refData, refBndData, impData, refShape, dpInfo);
+ return validateDataDP(refData, refBndData, impData, refShape, dpInfo);
break;
}
default: {