aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJames Ward <james.ward@arm.com>2022-11-15 11:36:47 +0000
committerEric Kunze <eric.kunze@arm.com>2022-11-29 15:55:26 +0000
commitee2566914d3476b8103b88915f3b81bda8490b44 (patch)
tree0d0dd56adafb3a65d896a192eeb736200eec8f06
parent542dd3b8da39440026fa9e809eebd0a3b79cf95d (diff)
downloadreference_model-ee2566914d3476b8103b88915f3b81bda8490b44.tar.gz
FP16 improvements
* Update FP16 resize to newest spec version * Correct casting to fp16 for graphs of >1 ops Change-Id: Iedff9a71eb7f72948b3c00a635bb0fd07d414bcd Signed-off-by: James Ward <james.ward@arm.com>
-rw-r--r--reference_model/src/arith_util.h15
-rw-r--r--reference_model/src/ops/image.cc24
-rw-r--r--reference_model/src/ops/op_factory.h2
-rw-r--r--reference_model/src/tensor.cc5
4 files changed, 27 insertions, 19 deletions
diff --git a/reference_model/src/arith_util.h b/reference_model/src/arith_util.h
index 33bdeed..a75d7a3 100644
--- a/reference_model/src/arith_util.h
+++ b/reference_model/src/arith_util.h
@@ -30,17 +30,18 @@
#include <fenv.h>
#include <math.h>
#define __STDC_LIMIT_MACROS //enable min/max of plain data type
-#include "func_debug.h"
#include "func_config.h"
+#include "func_debug.h"
+#include "half.hpp"
#include "inttypes.h"
#include "tosa_generated.h"
+#include <Eigen/Core>
+#include <bitset>
#include <cassert>
#include <iostream>
#include <limits>
#include <stdint.h>
#include <typeinfo>
-#include <Eigen/Core>
-#include <bitset>
using namespace tosa;
using namespace std;
@@ -269,8 +270,12 @@ float fpTrunc(float f_in)
truncateFloatToBFloat(&f_in, 1);
break;
case DType_FP16:
- // TODO(jw): implement FP16 truncate function (no-op placeholder for now)
- break;
+ // Cast to temporary float16 value before casting back to float32
+ {
+ half_float::half h = half_float::half_cast<half_float::half, float>(f_in);
+ f_in = half_float::half_cast<float, half_float::half>(h);
+ break;
+ }
case DType_FP32:
// No-op for fp32
break;
diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc
index 66efee0..a1a4474 100644
--- a/reference_model/src/ops/image.cc
+++ b/reference_model/src/ops/image.cc
@@ -15,6 +15,7 @@
#include "image.h"
#include "arith_util.h"
+#include "half.hpp"
#include <type_traits>
@@ -159,7 +160,8 @@ int OpResize<InDtype, OutDtype, resize_t>::eval()
resize_t dy;
resize_t dx;
- if (std::is_floating_point<resize_t>::value || (typeid(resize_t) == typeid(Eigen::bfloat16)))
+ if (std::is_floating_point<resize_t>::value || (typeid(resize_t) == typeid(Eigen::bfloat16)) ||
+ (typeid(resize_t) == typeid(half_float::half)))
{
dy = (resize_t)(fy - iy);
dx = (resize_t)(fx - ix);
@@ -190,14 +192,15 @@ int OpResize<InDtype, OutDtype, resize_t>::eval()
acc += (OutEigenType)v10 * dy * (1.0 - dx);
acc += (OutEigenType)v11 * dy * dx;
}
- else if ((typeid(resize_t) == typeid(Eigen::bfloat16)))
+ else if ((typeid(resize_t) == typeid(Eigen::bfloat16)) ||
+ (typeid(resize_t) == typeid(half_float::half)))
{
- Eigen::bfloat16 bf16_acc;
- bf16_acc = (Eigen::bfloat16)v00 * (Eigen::bfloat16)(1.0 - dy) * (Eigen::bfloat16)(1.0 - dx);
- bf16_acc += (Eigen::bfloat16)v01 * (Eigen::bfloat16)(1.0 - dy) * (Eigen::bfloat16)dx;
- bf16_acc += (Eigen::bfloat16)v10 * (Eigen::bfloat16)dy * (Eigen::bfloat16)(1.0 - dx);
- bf16_acc += (Eigen::bfloat16)v11 * (Eigen::bfloat16)dy * (Eigen::bfloat16)dx;
- acc = (float)bf16_acc;
+ resize_t f16_acc;
+ f16_acc = (resize_t)v00 * (resize_t)(1.0 - dy) * (resize_t)(1.0 - dx);
+ f16_acc += (resize_t)v01 * (resize_t)(1.0 - dy) * (resize_t)dx;
+ f16_acc += (resize_t)v10 * (resize_t)dy * (resize_t)(1.0 - dx);
+ f16_acc += (resize_t)v11 * (resize_t)dy * (resize_t)dx;
+ acc = (float)f16_acc;
}
else
{
@@ -210,7 +213,8 @@ int OpResize<InDtype, OutDtype, resize_t>::eval()
else
{
ASSERT_MSG(mode == ResizeMode_NEAREST, "OpResize: invalid mode");
- if (std::is_floating_point<resize_t>::value || (typeid(resize_t) == typeid(Eigen::bfloat16)))
+ if (std::is_floating_point<resize_t>::value || (typeid(resize_t) == typeid(Eigen::bfloat16)) ||
+ (typeid(resize_t) == typeid(half_float::half)))
{
iy = (dy >= 0.5) ? iy1 : iy0;
ix = (dx >= 0.5) ? ix1 : ix0;
@@ -236,6 +240,6 @@ DEF_INSTANTIATE_THREE_TYPE(OpResize, INT8, INT32, int16_t);
DEF_INSTANTIATE_THREE_TYPE(OpResize, INT8, INT8, int16_t);
DEF_INSTANTIATE_THREE_TYPE(OpResize, INT16, INT48, int16_t);
DEF_INSTANTIATE_THREE_TYPE(OpResize, INT16, INT16, int16_t);
-DEF_INSTANTIATE_THREE_TYPE(OpResize, FP16, FP16, float);
+DEF_INSTANTIATE_THREE_TYPE(OpResize, FP16, FP16, half_float::half);
DEF_INSTANTIATE_THREE_TYPE(OpResize, BF16, BF16, Eigen::bfloat16);
DEF_INSTANTIATE_THREE_TYPE(OpResize, FP32, FP32, float);
diff --git a/reference_model/src/ops/op_factory.h b/reference_model/src/ops/op_factory.h
index f399bd1..f4177db 100644
--- a/reference_model/src/ops/op_factory.h
+++ b/reference_model/src/ops/op_factory.h
@@ -108,7 +108,7 @@
#define DEF_FACTORY_TWO_TYPE_RESIZE_FP16(OP, DTYPE1, DTYPE2) \
if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
{ \
- return new OP<DType_##DTYPE1, DType_##DTYPE2, float>(sgt, attribute, id); \
+ return new OP<DType_##DTYPE1, DType_##DTYPE2, half_float::half>(sgt, attribute, id); \
}
#define DEF_FACTORY_TWO_TYPE_RESIZE_BF16(OP, DTYPE1, DTYPE2) \
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc
index 4eaf21d..e9598c4 100644
--- a/reference_model/src/tensor.cc
+++ b/reference_model/src/tensor.cc
@@ -159,8 +159,7 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
switch (dtype)
{
case DType_FP16:
- // Convert from fp16 to fp32
- //TODO(jw): remove this once we cast to fp16 in register_fcn/eval
+ // Convert from fp16 to fp32 so that fp16 values can be manipulated as float
for (uint32_t i=0; i < elements; i++) {
fdatabuf[i] = half_float::half_cast<float, half_float::half>(f16databuf[i]);
}
@@ -277,7 +276,7 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const
free(f16databuf);
return 1;
}
- // Convert fp32 to fp16
+ // Convert fp32 to fp16 so that output file contains valid fp16 data
for (uint32_t i=0; i < elements; i++) {
f16databuf[i] = half_float::half_cast<half_float::half, float>(fdatabuf[i]);
}