diff options
Diffstat (limited to 'reference_model/src/ops')
-rw-r--r-- | reference_model/src/ops/image.cc | 24 | ||||
-rw-r--r-- | reference_model/src/ops/op_factory.h | 2 |
2 files changed, 15 insertions, 11 deletions
diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc index 66efee0..a1a4474 100644 --- a/reference_model/src/ops/image.cc +++ b/reference_model/src/ops/image.cc @@ -15,6 +15,7 @@ #include "image.h" #include "arith_util.h" +#include "half.hpp" #include <type_traits> @@ -159,7 +160,8 @@ int OpResize<InDtype, OutDtype, resize_t>::eval() resize_t dy; resize_t dx; - if (std::is_floating_point<resize_t>::value || (typeid(resize_t) == typeid(Eigen::bfloat16))) + if (std::is_floating_point<resize_t>::value || (typeid(resize_t) == typeid(Eigen::bfloat16)) || + (typeid(resize_t) == typeid(half_float::half))) { dy = (resize_t)(fy - iy); dx = (resize_t)(fx - ix); @@ -190,14 +192,15 @@ int OpResize<InDtype, OutDtype, resize_t>::eval() acc += (OutEigenType)v10 * dy * (1.0 - dx); acc += (OutEigenType)v11 * dy * dx; } - else if ((typeid(resize_t) == typeid(Eigen::bfloat16))) + else if ((typeid(resize_t) == typeid(Eigen::bfloat16)) || + (typeid(resize_t) == typeid(half_float::half))) { - Eigen::bfloat16 bf16_acc; - bf16_acc = (Eigen::bfloat16)v00 * (Eigen::bfloat16)(1.0 - dy) * (Eigen::bfloat16)(1.0 - dx); - bf16_acc += (Eigen::bfloat16)v01 * (Eigen::bfloat16)(1.0 - dy) * (Eigen::bfloat16)dx; - bf16_acc += (Eigen::bfloat16)v10 * (Eigen::bfloat16)dy * (Eigen::bfloat16)(1.0 - dx); - bf16_acc += (Eigen::bfloat16)v11 * (Eigen::bfloat16)dy * (Eigen::bfloat16)dx; - acc = (float)bf16_acc; + resize_t f16_acc; + f16_acc = (resize_t)v00 * (resize_t)(1.0 - dy) * (resize_t)(1.0 - dx); + f16_acc += (resize_t)v01 * (resize_t)(1.0 - dy) * (resize_t)dx; + f16_acc += (resize_t)v10 * (resize_t)dy * (resize_t)(1.0 - dx); + f16_acc += (resize_t)v11 * (resize_t)dy * (resize_t)dx; + acc = (float)f16_acc; } else { @@ -210,7 +213,8 @@ int OpResize<InDtype, OutDtype, resize_t>::eval() else { ASSERT_MSG(mode == ResizeMode_NEAREST, "OpResize: invalid mode"); - if (std::is_floating_point<resize_t>::value || (typeid(resize_t) == typeid(Eigen::bfloat16))) + if (std::is_floating_point<resize_t>::value || (typeid(resize_t) == typeid(Eigen::bfloat16)) || + (typeid(resize_t) == typeid(half_float::half))) { iy = (dy >= 0.5) ? iy1 : iy0; ix = (dx >= 0.5) ? ix1 : ix0; @@ -236,6 +240,6 @@ DEF_INSTANTIATE_THREE_TYPE(OpResize, INT8, INT32, int16_t); DEF_INSTANTIATE_THREE_TYPE(OpResize, INT8, INT8, int16_t); DEF_INSTANTIATE_THREE_TYPE(OpResize, INT16, INT48, int16_t); DEF_INSTANTIATE_THREE_TYPE(OpResize, INT16, INT16, int16_t); -DEF_INSTANTIATE_THREE_TYPE(OpResize, FP16, FP16, float); +DEF_INSTANTIATE_THREE_TYPE(OpResize, FP16, FP16, half_float::half); DEF_INSTANTIATE_THREE_TYPE(OpResize, BF16, BF16, Eigen::bfloat16); DEF_INSTANTIATE_THREE_TYPE(OpResize, FP32, FP32, float); diff --git a/reference_model/src/ops/op_factory.h b/reference_model/src/ops/op_factory.h index f399bd1..f4177db 100644 --- a/reference_model/src/ops/op_factory.h +++ b/reference_model/src/ops/op_factory.h @@ -108,7 +108,7 @@ #define DEF_FACTORY_TWO_TYPE_RESIZE_FP16(OP, DTYPE1, DTYPE2) \ if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \ { \ - return new OP<DType_##DTYPE1, DType_##DTYPE2, float>(sgt, attribute, id); \ + return new OP<DType_##DTYPE1, DType_##DTYPE2, half_float::half>(sgt, attribute, id); \ } #define DEF_FACTORY_TWO_TYPE_RESIZE_BF16(OP, DTYPE1, DTYPE2) \ |