aboutsummaryrefslogtreecommitdiff
path: root/support/Bfloat16.h
diff options
context:
space:
mode:
Diffstat (limited to 'support/Bfloat16.h')
-rw-r--r--support/Bfloat16.h47
1 files changed, 26 insertions, 21 deletions
diff --git a/support/Bfloat16.h b/support/Bfloat16.h
index 173f2d16e2..02772898a8 100644
--- a/support/Bfloat16.h
+++ b/support/Bfloat16.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2020-2021 Arm Limited.
+ * Copyright (c) 2020-2022,2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,12 +21,12 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_BFLOAT16_H
-#define ARM_COMPUTE_BFLOAT16_H
+#ifndef ACL_SUPPORT_BFLOAT16_H
+#define ACL_SUPPORT_BFLOAT16_H
#include <cstdint>
#include <cstring>
-
+#include <ostream>
namespace arm_compute
{
namespace
@@ -40,25 +40,24 @@ namespace
inline uint16_t float_to_bf16(const float v)
{
const uint32_t *fromptr = reinterpret_cast<const uint32_t *>(&v);
-#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)
+#if defined(ARM_COMPUTE_ENABLE_BF16)
uint16_t res;
- __asm __volatile(
- "ldr s0, [%[fromptr]]\n"
- ".inst 0x1e634000\n" // BFCVT h0, s0
- "str h0, [%[toptr]]\n"
- :
- : [fromptr] "r"(fromptr), [toptr] "r"(&res)
- : "v0", "memory");
-#else /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */
+ __asm __volatile("ldr s0, [%[fromptr]]\n"
+ ".inst 0x1e634000\n" // BFCVT h0, s0
+ "str h0, [%[toptr]]\n"
+ :
+ : [fromptr] "r"(fromptr), [toptr] "r"(&res)
+ : "v0", "memory");
+#else /* defined(ARM_COMPUTE_ENABLE_BF16) */
uint16_t res = (*fromptr >> 16);
const uint16_t error = (*fromptr & 0x0000ffff);
uint16_t bf_l = res & 0x0001;
- if((error > 0x8000) || ((error == 0x8000) && (bf_l != 0)))
+ if ((error > 0x8000) || ((error == 0x8000) && (bf_l != 0)))
{
res += 1;
}
-#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */
+#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */
return res;
}
@@ -75,23 +74,21 @@ inline float bf16_to_float(const uint16_t &v)
memcpy(&fp, &lv, sizeof(lv));
return fp;
}
-}
+} // namespace
/** Brain floating point representation class */
class bfloat16 final
{
public:
/** Default Constructor */
- bfloat16()
- : value(0)
+ bfloat16() : value(0)
{
}
/** Constructor
*
* @param[in] v Floating-point value
*/
- explicit bfloat16(float v)
- : value(float_to_bf16(v))
+ bfloat16(float v) : value(float_to_bf16(v))
{
}
/** Assignment operator
@@ -134,8 +131,16 @@ public:
return val;
}
+ bfloat16 &operator+=(float v)
+ {
+ value = float_to_bf16(bf16_to_float(value) + v);
+ return *this;
+ }
+
+ friend std::ostream &operator<<(std::ostream &os, const bfloat16 &arg);
+
private:
uint16_t value;
};
} // namespace arm_compute
-#endif /* ARM_COMPUTE_BFLOAT16_H */
+#endif // ACL_SUPPORT_BFLOAT16_H