aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/verify/verify_exact.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/verify/verify_exact.cc')
-rw-r--r--reference_model/src/verify/verify_exact.cc20
1 files changed, 15 insertions, 5 deletions
diff --git a/reference_model/src/verify/verify_exact.cc b/reference_model/src/verify/verify_exact.cc
index 36b4ec9..971df9c 100644
--- a/reference_model/src/verify/verify_exact.cc
+++ b/reference_model/src/verify/verify_exact.cc
@@ -13,12 +13,14 @@
// limitations under the License.
#include "func_debug.h"
+#include "half.hpp"
#include "verifiers.h"
#include <cmath>
namespace
{
-bool exact_fp32(const double& referenceValue, const float& implementationValue)
+template <typename OutDtype>
+bool exact_fp(const double& referenceValue, const OutDtype& implementationValue)
{
return std::isnan(referenceValue) ? std::isnan(implementationValue) : (referenceValue == implementationValue);
}
@@ -38,16 +40,24 @@ bool verifyExact(const CTensor* referenceTensor, const CTensor* implementationTe
numElements(std::vector<int32_t>(referenceTensor->shape, referenceTensor->shape + referenceTensor->num_dims));
TOSA_REF_REQUIRE(elementCount > 0, "[E] Invalid shape for reference tensor");
+ TOSA_REF_REQUIRE(referenceTensor->data_type == tosa_datatype_fp64_t, "[E] Reference tensor is not fp64");
+ const auto* refData = reinterpret_cast<const double*>(referenceTensor->data);
+ TOSA_REF_REQUIRE(refData != nullptr, "[E] Missing data for reference");
+
switch (implementationTensor->data_type)
{
case tosa_datatype_fp32_t: {
- TOSA_REF_REQUIRE(referenceTensor->data_type == tosa_datatype_fp64_t, "[E] Reference tensor is not fp64");
- const auto* refData = reinterpret_cast<const double*>(referenceTensor->data);
- TOSA_REF_REQUIRE(refData != nullptr, "[E] Missing data for reference");
const auto* impData = reinterpret_cast<const float*>(implementationTensor->data);
TOSA_REF_REQUIRE(impData != nullptr, "[E] Missing data for implementation");
auto result = std::equal(refData, std::next(refData, elementCount), impData,
- std::next(impData, elementCount), exact_fp32);
+ std::next(impData, elementCount), exact_fp<float>);
+ return result;
+ }
+ case tosa_datatype_fp16_t: {
+ const auto* impData = reinterpret_cast<const half_float::half*>(implementationTensor->data);
+ TOSA_REF_REQUIRE(impData != nullptr, "[E] Missing data for implementation");
+ auto result = std::equal(refData, std::next(refData, elementCount), impData,
+ std::next(impData, elementCount), exact_fp<half_float::half>);
return result;
}
default: