aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/image.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/image.cc')
-rw-r--r--reference_model/src/ops/image.cc24
1 files changed, 14 insertions, 10 deletions
diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc
index 66efee0..a1a4474 100644
--- a/reference_model/src/ops/image.cc
+++ b/reference_model/src/ops/image.cc
@@ -15,6 +15,7 @@
#include "image.h"
#include "arith_util.h"
+#include "half.hpp"
#include <type_traits>
@@ -159,7 +160,8 @@ int OpResize<InDtype, OutDtype, resize_t>::eval()
resize_t dy;
resize_t dx;
- if (std::is_floating_point<resize_t>::value || (typeid(resize_t) == typeid(Eigen::bfloat16)))
+ if (std::is_floating_point<resize_t>::value || (typeid(resize_t) == typeid(Eigen::bfloat16)) ||
+ (typeid(resize_t) == typeid(half_float::half)))
{
dy = (resize_t)(fy - iy);
dx = (resize_t)(fx - ix);
@@ -190,14 +192,15 @@ int OpResize<InDtype, OutDtype, resize_t>::eval()
acc += (OutEigenType)v10 * dy * (1.0 - dx);
acc += (OutEigenType)v11 * dy * dx;
}
- else if ((typeid(resize_t) == typeid(Eigen::bfloat16)))
+ else if ((typeid(resize_t) == typeid(Eigen::bfloat16)) ||
+ (typeid(resize_t) == typeid(half_float::half)))
{
- Eigen::bfloat16 bf16_acc;
- bf16_acc = (Eigen::bfloat16)v00 * (Eigen::bfloat16)(1.0 - dy) * (Eigen::bfloat16)(1.0 - dx);
- bf16_acc += (Eigen::bfloat16)v01 * (Eigen::bfloat16)(1.0 - dy) * (Eigen::bfloat16)dx;
- bf16_acc += (Eigen::bfloat16)v10 * (Eigen::bfloat16)dy * (Eigen::bfloat16)(1.0 - dx);
- bf16_acc += (Eigen::bfloat16)v11 * (Eigen::bfloat16)dy * (Eigen::bfloat16)dx;
- acc = (float)bf16_acc;
+ resize_t f16_acc;
+ f16_acc = (resize_t)v00 * (resize_t)(1.0 - dy) * (resize_t)(1.0 - dx);
+ f16_acc += (resize_t)v01 * (resize_t)(1.0 - dy) * (resize_t)dx;
+ f16_acc += (resize_t)v10 * (resize_t)dy * (resize_t)(1.0 - dx);
+ f16_acc += (resize_t)v11 * (resize_t)dy * (resize_t)dx;
+ acc = (float)f16_acc;
}
else
{
@@ -210,7 +213,8 @@ int OpResize<InDtype, OutDtype, resize_t>::eval()
else
{
ASSERT_MSG(mode == ResizeMode_NEAREST, "OpResize: invalid mode");
- if (std::is_floating_point<resize_t>::value || (typeid(resize_t) == typeid(Eigen::bfloat16)))
+ if (std::is_floating_point<resize_t>::value || (typeid(resize_t) == typeid(Eigen::bfloat16)) ||
+ (typeid(resize_t) == typeid(half_float::half)))
{
iy = (dy >= 0.5) ? iy1 : iy0;
ix = (dx >= 0.5) ? ix1 : ix0;
@@ -236,6 +240,6 @@ DEF_INSTANTIATE_THREE_TYPE(OpResize, INT8, INT32, int16_t);
DEF_INSTANTIATE_THREE_TYPE(OpResize, INT8, INT8, int16_t);
DEF_INSTANTIATE_THREE_TYPE(OpResize, INT16, INT48, int16_t);
DEF_INSTANTIATE_THREE_TYPE(OpResize, INT16, INT16, int16_t);
-DEF_INSTANTIATE_THREE_TYPE(OpResize, FP16, FP16, float);
+DEF_INSTANTIATE_THREE_TYPE(OpResize, FP16, FP16, half_float::half);
DEF_INSTANTIATE_THREE_TYPE(OpResize, BF16, BF16, Eigen::bfloat16);
DEF_INSTANTIATE_THREE_TYPE(OpResize, FP32, FP32, float);