diff options
Diffstat (limited to 'support/Bfloat16.h')
-rw-r--r-- | support/Bfloat16.h | 54 |
1 files changed, 30 insertions, 24 deletions
diff --git a/support/Bfloat16.h b/support/Bfloat16.h index d57d8ce9ee..02772898a8 100644 --- a/support/Bfloat16.h +++ b/support/Bfloat16.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020 Arm Limited. + * Copyright (c) 2020-2022,2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -21,11 +21,12 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#ifndef ARM_COMPUTE_BFLOAT16_H -#define ARM_COMPUTE_BFLOAT16_H +#ifndef ACL_SUPPORT_BFLOAT16_H +#define ACL_SUPPORT_BFLOAT16_H #include <cstdint> - +#include <cstring> +#include <ostream> namespace arm_compute { namespace @@ -39,25 +40,24 @@ namespace inline uint16_t float_to_bf16(const float v) { const uint32_t *fromptr = reinterpret_cast<const uint32_t *>(&v); -#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) +#if defined(ARM_COMPUTE_ENABLE_BF16) uint16_t res; - __asm __volatile( - "ldr s0, [%[fromptr]]\n" - ".inst 0x1e634000\n" // BFCVT h0, s0 - "str h0, [%[toptr]]\n" - : - : [fromptr] "r"(fromptr), [toptr] "r"(&res) - : "v0", "memory"); -#else /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */ + __asm __volatile("ldr s0, [%[fromptr]]\n" + ".inst 0x1e634000\n" // BFCVT h0, s0 + "str h0, [%[toptr]]\n" + : + : [fromptr] "r"(fromptr), [toptr] "r"(&res) + : "v0", "memory"); +#else /* defined(ARM_COMPUTE_ENABLE_BF16) */ uint16_t res = (*fromptr >> 16); const uint16_t error = (*fromptr & 0x0000ffff); uint16_t bf_l = res & 0x0001; - if((error > 0x8000) || ((error == 0x8000) && (bf_l != 0))) + if ((error > 0x8000) || ((error == 0x8000) && (bf_l != 0))) { res += 1; } -#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */ +#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */ return res; } @@ -70,27 +70,25 @@ inline uint16_t float_to_bf16(const float v) inline float bf16_to_float(const uint16_t &v) { const uint32_t lv = (v << 16); - const float *fp = reinterpret_cast<const float *>(&lv); - - return *fp; -} + float fp; + memcpy(&fp, &lv, sizeof(lv)); + return fp; } +} // namespace /** Brain floating point representation class */ class bfloat16 final { public: /** Default Constructor */ - bfloat16() - : value(0) + bfloat16() : value(0) { } /** Constructor * * @param[in] v Floating-point value */ - explicit bfloat16(float v) - : value(float_to_bf16(v)) + bfloat16(float v) : value(float_to_bf16(v)) { } /** Assignment operator @@ -133,8 +131,16 @@ public: return val; } + bfloat16 &operator+=(float v) + { + value = float_to_bf16(bf16_to_float(value) + v); + return *this; + } + + friend std::ostream &operator<<(std::ostream &os, const bfloat16 &arg); + private: uint16_t value; }; } // namespace arm_compute -#endif /* ARM_COMPUTE_BFLOAT16_H */
\ No newline at end of file +#endif // ACL_SUPPORT_BFLOAT16_H |