diff options
author | James Ward <james.ward@arm.com> | 2022-11-15 11:36:47 +0000 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2022-11-29 15:55:26 +0000 |
commit | ee2566914d3476b8103b88915f3b81bda8490b44 (patch) | |
tree | 0d0dd56adafb3a65d896a192eeb736200eec8f06 /reference_model/src/ops/image.cc | |
parent | 542dd3b8da39440026fa9e809eebd0a3b79cf95d (diff) | |
download | reference_model-ee2566914d3476b8103b88915f3b81bda8490b44.tar.gz |
FP16 improvements
* Update FP16 resize to newest spec version
* Correct casting to fp16 for graphs of >1 ops
Change-Id: Iedff9a71eb7f72948b3c00a635bb0fd07d414bcd
Signed-off-by: James Ward <james.ward@arm.com>
Diffstat (limited to 'reference_model/src/ops/image.cc')
-rw-r--r-- | reference_model/src/ops/image.cc | 24 |
1 files changed, 14 insertions, 10 deletions
diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc index 66efee0..a1a4474 100644 --- a/reference_model/src/ops/image.cc +++ b/reference_model/src/ops/image.cc @@ -15,6 +15,7 @@ #include "image.h" #include "arith_util.h" +#include "half.hpp" #include <type_traits> @@ -159,7 +160,8 @@ int OpResize<InDtype, OutDtype, resize_t>::eval() resize_t dy; resize_t dx; - if (std::is_floating_point<resize_t>::value || (typeid(resize_t) == typeid(Eigen::bfloat16))) + if (std::is_floating_point<resize_t>::value || (typeid(resize_t) == typeid(Eigen::bfloat16)) || + (typeid(resize_t) == typeid(half_float::half))) { dy = (resize_t)(fy - iy); dx = (resize_t)(fx - ix); @@ -190,14 +192,15 @@ int OpResize<InDtype, OutDtype, resize_t>::eval() acc += (OutEigenType)v10 * dy * (1.0 - dx); acc += (OutEigenType)v11 * dy * dx; } - else if ((typeid(resize_t) == typeid(Eigen::bfloat16))) + else if ((typeid(resize_t) == typeid(Eigen::bfloat16)) || + (typeid(resize_t) == typeid(half_float::half))) { - Eigen::bfloat16 bf16_acc; - bf16_acc = (Eigen::bfloat16)v00 * (Eigen::bfloat16)(1.0 - dy) * (Eigen::bfloat16)(1.0 - dx); - bf16_acc += (Eigen::bfloat16)v01 * (Eigen::bfloat16)(1.0 - dy) * (Eigen::bfloat16)dx; - bf16_acc += (Eigen::bfloat16)v10 * (Eigen::bfloat16)dy * (Eigen::bfloat16)(1.0 - dx); - bf16_acc += (Eigen::bfloat16)v11 * (Eigen::bfloat16)dy * (Eigen::bfloat16)dx; - acc = (float)bf16_acc; + resize_t f16_acc; + f16_acc = (resize_t)v00 * (resize_t)(1.0 - dy) * (resize_t)(1.0 - dx); + f16_acc += (resize_t)v01 * (resize_t)(1.0 - dy) * (resize_t)dx; + f16_acc += (resize_t)v10 * (resize_t)dy * (resize_t)(1.0 - dx); + f16_acc += (resize_t)v11 * (resize_t)dy * (resize_t)dx; + acc = (float)f16_acc; } else { @@ -210,7 +213,8 @@ int OpResize<InDtype, OutDtype, resize_t>::eval() else { ASSERT_MSG(mode == ResizeMode_NEAREST, "OpResize: invalid mode"); - if (std::is_floating_point<resize_t>::value || (typeid(resize_t) == typeid(Eigen::bfloat16))) + if (std::is_floating_point<resize_t>::value || (typeid(resize_t) == typeid(Eigen::bfloat16)) || + (typeid(resize_t) == typeid(half_float::half))) { iy = (dy >= 0.5) ? iy1 : iy0; ix = (dx >= 0.5) ? ix1 : ix0; @@ -236,6 +240,6 @@ DEF_INSTANTIATE_THREE_TYPE(OpResize, INT8, INT32, int16_t); DEF_INSTANTIATE_THREE_TYPE(OpResize, INT8, INT8, int16_t); DEF_INSTANTIATE_THREE_TYPE(OpResize, INT16, INT48, int16_t); DEF_INSTANTIATE_THREE_TYPE(OpResize, INT16, INT16, int16_t); -DEF_INSTANTIATE_THREE_TYPE(OpResize, FP16, FP16, float); +DEF_INSTANTIATE_THREE_TYPE(OpResize, FP16, FP16, half_float::half); DEF_INSTANTIATE_THREE_TYPE(OpResize, BF16, BF16, Eigen::bfloat16); DEF_INSTANTIATE_THREE_TYPE(OpResize, FP32, FP32, float); |