diff options
Diffstat (limited to 'reference_model/src/verify/verify_exact.cc')
-rw-r--r-- | reference_model/src/verify/verify_exact.cc | 20 |
1 files changed, 15 insertions, 5 deletions
diff --git a/reference_model/src/verify/verify_exact.cc b/reference_model/src/verify/verify_exact.cc index 36b4ec9..971df9c 100644 --- a/reference_model/src/verify/verify_exact.cc +++ b/reference_model/src/verify/verify_exact.cc @@ -13,12 +13,14 @@ // limitations under the License. #include "func_debug.h" +#include "half.hpp" #include "verifiers.h" #include <cmath> namespace { -bool exact_fp32(const double& referenceValue, const float& implementationValue) +template <typename OutDtype> +bool exact_fp(const double& referenceValue, const OutDtype& implementationValue) { return std::isnan(referenceValue) ? std::isnan(implementationValue) : (referenceValue == implementationValue); } @@ -38,16 +40,24 @@ bool verifyExact(const CTensor* referenceTensor, const CTensor* implementationTe numElements(std::vector<int32_t>(referenceTensor->shape, referenceTensor->shape + referenceTensor->num_dims)); TOSA_REF_REQUIRE(elementCount > 0, "[E] Invalid shape for reference tensor"); + TOSA_REF_REQUIRE(referenceTensor->data_type == tosa_datatype_fp64_t, "[E] Reference tensor is not fp64"); + const auto* refData = reinterpret_cast<const double*>(referenceTensor->data); + TOSA_REF_REQUIRE(refData != nullptr, "[E] Missing data for reference"); + switch (implementationTensor->data_type) { case tosa_datatype_fp32_t: { - TOSA_REF_REQUIRE(referenceTensor->data_type == tosa_datatype_fp64_t, "[E] Reference tensor is not fp64"); - const auto* refData = reinterpret_cast<const double*>(referenceTensor->data); - TOSA_REF_REQUIRE(refData != nullptr, "[E] Missing data for reference"); const auto* impData = reinterpret_cast<const float*>(implementationTensor->data); TOSA_REF_REQUIRE(impData != nullptr, "[E] Missing data for implementation"); auto result = std::equal(refData, std::next(refData, elementCount), impData, - std::next(impData, elementCount), exact_fp32); + std::next(impData, elementCount), exact_fp<float>); + return result; + } + case tosa_datatype_fp16_t: { + const auto* impData = reinterpret_cast<const half_float::half*>(implementationTensor->data); + TOSA_REF_REQUIRE(impData != nullptr, "[E] Missing data for implementation"); + auto result = std::equal(refData, std::next(refData, elementCount), impData, + std::next(impData, elementCount), exact_fp<half_float::half>); return result; } default: |