From ee2566914d3476b8103b88915f3b81bda8490b44 Mon Sep 17 00:00:00 2001 From: James Ward Date: Tue, 15 Nov 2022 11:36:47 +0000 Subject: FP16 improvements * Update FP16 resize to newest spec version * Correct casting to fp16 for graphs of >1 ops Change-Id: Iedff9a71eb7f72948b3c00a635bb0fd07d414bcd Signed-off-by: James Ward --- reference_model/src/arith_util.h | 15 ++++++++++----- reference_model/src/ops/image.cc | 24 ++++++++++++++---------- reference_model/src/ops/op_factory.h | 2 +- reference_model/src/tensor.cc | 5 ++--- 4 files changed, 27 insertions(+), 19 deletions(-) diff --git a/reference_model/src/arith_util.h b/reference_model/src/arith_util.h index 33bdeed..a75d7a3 100644 --- a/reference_model/src/arith_util.h +++ b/reference_model/src/arith_util.h @@ -30,17 +30,18 @@ #include #include #define __STDC_LIMIT_MACROS //enable min/max of plain data type -#include "func_debug.h" #include "func_config.h" +#include "func_debug.h" +#include "half.hpp" #include "inttypes.h" #include "tosa_generated.h" +#include +#include #include #include #include #include #include -#include -#include using namespace tosa; using namespace std; @@ -269,8 +270,12 @@ float fpTrunc(float f_in) truncateFloatToBFloat(&f_in, 1); break; case DType_FP16: - // TODO(jw): implement FP16 truncate function (no-op placeholder for now) - break; + // Cast to temporary float16 value before casting back to float32 + { + half_float::half h = half_float::half_cast(f_in); + f_in = half_float::half_cast(h); + break; + } case DType_FP32: // No-op for fp32 break; diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc index 66efee0..a1a4474 100644 --- a/reference_model/src/ops/image.cc +++ b/reference_model/src/ops/image.cc @@ -15,6 +15,7 @@ #include "image.h" #include "arith_util.h" +#include "half.hpp" #include @@ -159,7 +160,8 @@ int OpResize::eval() resize_t dy; resize_t dx; - if (std::is_floating_point::value || (typeid(resize_t) == typeid(Eigen::bfloat16))) + if (std::is_floating_point::value || (typeid(resize_t) == typeid(Eigen::bfloat16)) || + (typeid(resize_t) == typeid(half_float::half))) { dy = (resize_t)(fy - iy); dx = (resize_t)(fx - ix); @@ -190,14 +192,15 @@ int OpResize::eval() acc += (OutEigenType)v10 * dy * (1.0 - dx); acc += (OutEigenType)v11 * dy * dx; } - else if ((typeid(resize_t) == typeid(Eigen::bfloat16))) + else if ((typeid(resize_t) == typeid(Eigen::bfloat16)) || + (typeid(resize_t) == typeid(half_float::half))) { - Eigen::bfloat16 bf16_acc; - bf16_acc = (Eigen::bfloat16)v00 * (Eigen::bfloat16)(1.0 - dy) * (Eigen::bfloat16)(1.0 - dx); - bf16_acc += (Eigen::bfloat16)v01 * (Eigen::bfloat16)(1.0 - dy) * (Eigen::bfloat16)dx; - bf16_acc += (Eigen::bfloat16)v10 * (Eigen::bfloat16)dy * (Eigen::bfloat16)(1.0 - dx); - bf16_acc += (Eigen::bfloat16)v11 * (Eigen::bfloat16)dy * (Eigen::bfloat16)dx; - acc = (float)bf16_acc; + resize_t f16_acc; + f16_acc = (resize_t)v00 * (resize_t)(1.0 - dy) * (resize_t)(1.0 - dx); + f16_acc += (resize_t)v01 * (resize_t)(1.0 - dy) * (resize_t)dx; + f16_acc += (resize_t)v10 * (resize_t)dy * (resize_t)(1.0 - dx); + f16_acc += (resize_t)v11 * (resize_t)dy * (resize_t)dx; + acc = (float)f16_acc; } else { @@ -210,7 +213,8 @@ int OpResize::eval() else { ASSERT_MSG(mode == ResizeMode_NEAREST, "OpResize: invalid mode"); - if (std::is_floating_point::value || (typeid(resize_t) == typeid(Eigen::bfloat16))) + if (std::is_floating_point::value || (typeid(resize_t) == typeid(Eigen::bfloat16)) || + (typeid(resize_t) == typeid(half_float::half))) { iy = (dy >= 0.5) ? iy1 : iy0; ix = (dx >= 0.5) ? ix1 : ix0; @@ -236,6 +240,6 @@ DEF_INSTANTIATE_THREE_TYPE(OpResize, INT8, INT32, int16_t); DEF_INSTANTIATE_THREE_TYPE(OpResize, INT8, INT8, int16_t); DEF_INSTANTIATE_THREE_TYPE(OpResize, INT16, INT48, int16_t); DEF_INSTANTIATE_THREE_TYPE(OpResize, INT16, INT16, int16_t); -DEF_INSTANTIATE_THREE_TYPE(OpResize, FP16, FP16, float); +DEF_INSTANTIATE_THREE_TYPE(OpResize, FP16, FP16, half_float::half); DEF_INSTANTIATE_THREE_TYPE(OpResize, BF16, BF16, Eigen::bfloat16); DEF_INSTANTIATE_THREE_TYPE(OpResize, FP32, FP32, float); diff --git a/reference_model/src/ops/op_factory.h b/reference_model/src/ops/op_factory.h index f399bd1..f4177db 100644 --- a/reference_model/src/ops/op_factory.h +++ b/reference_model/src/ops/op_factory.h @@ -108,7 +108,7 @@ #define DEF_FACTORY_TWO_TYPE_RESIZE_FP16(OP, DTYPE1, DTYPE2) \ if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \ { \ - return new OP(sgt, attribute, id); \ + return new OP(sgt, attribute, id); \ } #define DEF_FACTORY_TWO_TYPE_RESIZE_BF16(OP, DTYPE1, DTYPE2) \ diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc index 4eaf21d..e9598c4 100644 --- a/reference_model/src/tensor.cc +++ b/reference_model/src/tensor.cc @@ -159,8 +159,7 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename) switch (dtype) { case DType_FP16: - // Convert from fp16 to fp32 - //TODO(jw): remove this once we cast to fp16 in register_fcn/eval + // Convert from fp16 to fp32 so that fp16 values can be manipulated as float for (uint32_t i=0; i < elements; i++) { fdatabuf[i] = half_float::half_cast(f16databuf[i]); } @@ -277,7 +276,7 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const free(f16databuf); return 1; } - // Convert fp32 to fp16 + // Convert fp32 to fp16 so that output file contains valid fp16 data for (uint32_t i=0; i < elements; i++) { f16databuf[i] = half_float::half_cast(fdatabuf[i]); } -- cgit v1.2.1