From 24dbc420aae556649f50e645bd94489dab2cc75a Mon Sep 17 00:00:00 2001 From: James Ward Date: Wed, 19 Oct 2022 12:20:31 +0100 Subject: Add BF16 support to reference model * Upgrade Eigen to 3.4.0 (for bfloat16 support) and add work- arounds for reduce.any() and reduce.all() bugs (introduced between 3.3.7 and 3.4.0) * Truncation to bfloat16 now performed in eval() methods Signed-off-by: James Ward Signed-off-by: Jeremy Johnson Change-Id: If5f5c988d76d3d30790acf3b97081726b89205fe --- reference_model/src/arith_util.h | 89 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) (limited to 'reference_model/src/arith_util.h') diff --git a/reference_model/src/arith_util.h b/reference_model/src/arith_util.h index 554a7a2..33bdeed 100644 --- a/reference_model/src/arith_util.h +++ b/reference_model/src/arith_util.h @@ -31,13 +31,18 @@ #include #define __STDC_LIMIT_MACROS //enable min/max of plain data type #include "func_debug.h" +#include "func_config.h" #include "inttypes.h" +#include "tosa_generated.h" #include #include #include #include #include +#include +#include +using namespace tosa; using namespace std; inline size_t _count_one(uint64_t val) @@ -191,4 +196,88 @@ constexpr T saturate(const uint32_t width, const intmax_t value) // clang-format on } +inline void float_trunc_bytes(float* src) +{ + /* Set the least significant two bytes to zero for the input float value.*/ + char src_as_bytes[sizeof(float)]; + memcpy(src_as_bytes, src, sizeof(float)); + + if (g_func_config.float_is_big_endian) + { + src_as_bytes[2] = '\000'; + src_as_bytes[3] = '\000'; + } + else + { + src_as_bytes[0] = '\000'; + src_as_bytes[1] = '\000'; + } + + memcpy(src, &src_as_bytes, sizeof(float)); +} + +inline void truncateFloatToBFloat(float* src, int64_t size) { + /* Set the least significant two bytes to zero for each float + value in the input src buffer. */ + ASSERT_MEM(src); + ASSERT_MSG(size > 0, "Size of src (representing number of values in src) must be a positive integer."); + for (; size != 0; src++, size--) + { + float_trunc_bytes(src); + } +} + +inline bool checkValidBFloat(float src) +{ + /* Checks if the least significant two bytes are zero. */ + ASSERT_MEM(src); + char src_as_bytes[sizeof(float)]; + memcpy(src_as_bytes, &src, sizeof(float)); + + if (g_func_config.float_is_big_endian) + { + return (src_as_bytes[2] == '\000' && src_as_bytes[3] == '\000'); + } + else + { + return (src_as_bytes[0] == '\000' && src_as_bytes[1] == '\000'); + } +} + +inline bool float_is_big_endian() +{ + /* Compares float values 1.0 and -1.0 by checking whether the + negation causes the first or the last byte to change. + First byte changing would indicate the float representation + is big-endian.*/ + float f = 1.0; + char f_as_bytes[sizeof(float)]; + memcpy(f_as_bytes, &f, sizeof(float)); + f = -f; + char f_neg_as_bytes[sizeof(float)]; + memcpy(f_neg_as_bytes, &f, sizeof(float)); + return f_as_bytes[0] != f_neg_as_bytes[0]; +} + +template +float fpTrunc(float f_in) +{ + /* Truncates a float value based on the DType it represents.*/ + switch (Dtype) + { + case DType_BF16: + truncateFloatToBFloat(&f_in, 1); + break; + case DType_FP16: + // TODO(jw): implement FP16 truncate function (no-op placeholder for now) + break; + case DType_FP32: + // No-op for fp32 + break; + default: + ASSERT_MSG(false, "DType %s should not be float-truncated.", EnumNameDType(Dtype)); + } + return f_in; +} + #endif /* _ARITH_UTIL_H */ -- cgit v1.2.1