// Copyright (c) 2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef TOSA_FLOAT_UTILS_H_ #define TOSA_FLOAT_UTILS_H_ #include #include #include #include #if defined(__cpp_lib_bit_cast) #include #endif // defined(__cpp_lib_bit_cast) namespace tosa { namespace float_support { struct hidden {}; #if defined(__cpp_lib_bit_cast) #define BITCAST_CONSTEXPR constexpr inline constexpr inline int32_t get_bits(const float& f) { return std::bit_cast(f); } constexpr inline float from_bits(const int32_t& i) { return std::bit_cast(i); } #else #define BITCAST_CONSTEXPR inline union ufloat32 { constexpr ufloat32(const float& x) : f(x) {} constexpr ufloat32(const int32_t& x) : i(x) {} float f; int32_t i; }; inline int32_t get_bits(const float& f) { return ufloat32(f).i; } inline float from_bits(const int32_t& i) { return ufloat32(i).f; } #endif } // namespace float_support template = true> class float_t { storage_t m_data = 0; public: static constexpr size_t n_exponent_bits = n_exp_bits; static constexpr size_t n_significand_bits = sizeof(storage_t) * 8 - 1 - n_exp_bits; static constexpr int64_t exponent_bias = (1 << (n_exp_bits - 1)) - 1; /// \brief Construct a floating point type with the given bit /// representation. static constexpr float_t from_bits(storage_t bits) { return float_t(float_support::hidden(), bits); } /// \brief Construct a float from the given sign, exponent and significand /// bits. static constexpr float_t from_bits(bool pm, storage_t e, storage_t s) { storage_t bits = pm ? 1 : 0; bits <<= n_exp_bits; bits |= e; bits <<= n_significand_bits; if (with_denorm || e) bits |= s; return float_t(float_support::hidden(), bits); } /// \brief (Hidden) Construct a float type from a given bit pattern constexpr float_t(const float_support::hidden&, storage_t bits) : m_data(bits) {} constexpr float_t() : m_data(0) {} constexpr float_t(const float_t& other) : m_data(other.m_data) {} /// \brief Cast to a different floating point representation. template constexpr inline operator float_t() const { using other_float_t = float_t; // Shortcut for types which are fundamentally similar (e.g., bf16 -> // fp32) if constexpr (n_exp_bits == other_n_exp_bits && sizeof(other_storage_t) >= sizeof(storage_t) && has_nan == other_has_nan) { return other_float_t::from_bits(static_cast(m_data) << (sizeof(other_storage_t) - sizeof(storage_t)) * 8); } // Get initial values for the new floating point type const bool sign_bit = m_data < 0; int64_t new_exponent_bits = 0; uint64_t new_significand = 0; if (is_nan() || is_infinity()) { new_exponent_bits = (1 << other_n_exp_bits) - 1; if (is_nan()) { if constexpr (other_has_infinity) { // Copy across the `not_quiet bit`; set the LSB. Don't // attempt to copy across any of the rest of the payload. new_significand = 0x1 | (((significand() >> (n_significand_bits - 1)) & 1) << other_float_t::n_significand_bits); } else { new_significand = (1ul << other_float_t::n_significand_bits) - 1; } } else if constexpr (!other_has_infinity) { new_significand = (1ul << other_float_t::n_significand_bits) - (other_has_nan ? 2 : 1); } } else if (!is_zero()) { const int64_t this_exponent_bits = exponent_bits(); { constexpr int64_t exponent_rebias = other_float_t::exponent_bias - exponent_bias; new_exponent_bits = std::max(this_exponent_bits + exponent_rebias, exponent_rebias + 1); } new_significand = this->significand() << (64 - n_significand_bits); // Normalise subnormals if (this_exponent_bits == 0) { // Shift the most-significant 1 out of the magnitude to convert // it to a significand. Modify the exponent accordingly. uint8_t shift = __builtin_clzl(new_significand) + 1; new_exponent_bits -= shift; new_significand <<= shift; } // Align the significand for the output type uint32_t shift = 64 - other_float_t::n_significand_bits; const bool other_is_subnormal = new_exponent_bits <= 0; if (other_is_subnormal) { shift += 1 - new_exponent_bits; new_exponent_bits = 0; } const uint64_t shift_out = shift == 64 ? new_significand : new_significand & ((1ll << shift) - 1); new_significand = shift == 64 ? 0 : new_significand >> shift; // Reinsert the most-significant-one if this is a subnormal in the // output type. new_significand |= (other_is_subnormal ? 1ll : 0) << (64 - shift); // Apply rounding based on the bits shifted out of the significand const uint64_t shift_half = 1ll << (shift - 1); if (shift_out > shift_half || (shift_out == shift_half && (new_significand & 1))) { new_significand += 1; // Handle the case that the significand overflowed due to // rounding constexpr uint64_t max_significand = (1ll << other_float_t::n_significand_bits) - 1; if (new_significand > max_significand) { new_significand = 0; new_exponent_bits++; } } // Saturate to infinity if the exponent is larger than can be // represented in the output type. This can only occur if the size // of the exponent of the new type is not greater than the exponent // of the old type. if constexpr (other_n_exp_bits <= n_exp_bits) { constexpr int64_t inf_exp_bits = (1ll << other_n_exp_bits) - 1; if (new_exponent_bits >= inf_exp_bits) { new_exponent_bits = inf_exp_bits; new_significand = other_has_infinity ? 0 : (1ul << other_float_t::n_significand_bits) - (other_has_nan ? 2 : 1); } } } return other_float_t::from_bits(sign_bit, new_exponent_bits, new_significand); } /// \brief Convert from a 32-bit floating point value BITCAST_CONSTEXPR float_t(const float& f) { // If this format exactly represents the binary32 format then get // the bits from the provided float; otherwise get a binary32 // representation and then convert to this format. if constexpr (represents_binary32()) m_data = float_support::get_bits(f); else m_data = static_cast>( static_cast>(f)) .m_data; } /// \brief Cast to a 32-bit floating point value BITCAST_CONSTEXPR operator float() const { // If this format exactly represents the binary32 format then return // a float; otherwise get a binary32 representation and then return // a float. if constexpr (represents_binary32()) return float_support::from_bits(m_data); else return static_cast(this->operator float_t()); } /// \brief Return whether this type represents the IEEE754 binary32 /// format constexpr static inline bool represents_binary32() { return std::is_same_v && n_exp_bits == 8 && has_nan && with_denorm && with_infinity; } constexpr auto operator-() const { return from_bits(m_data ^ (1ll << (sizeof(storage_t) * 8 - 1))); } constexpr bool is_subnormal() const { return exponent_bits() == 0 && significand() != 0; } constexpr bool is_zero() const { return exponent_bits() == 0 && significand() == 0; } constexpr bool is_nan() const { return has_nan && (exponent_bits() == (1ul << n_exponent_bits) - 1) && ((with_infinity && significand()) || (!with_infinity && significand() == (1ul << n_significand_bits) - 1)); } constexpr bool is_infinity() const { return with_infinity && ((exponent_bits() == (1ul << n_exponent_bits) - 1) && !significand()); } constexpr inline const storage_t& bits() const { return m_data; } /// \brief Get the exponent constexpr inline int64_t exponent() const { return std::max(exponent_bits(), 1ul) - exponent_bias; } /// \brief Get the bits from the exponent constexpr inline uint64_t exponent_bits() const { constexpr uint64_t mask = (1ul << n_exp_bits) - 1; return (m_data >> n_significand_bits) & mask; } constexpr inline uint64_t significand() const { return m_data & ((1ul << n_significand_bits) - 1); } constexpr inline bool operator==(const float_t& other) const { return !is_nan() && !other.is_nan() && ((is_zero() && other.is_zero()) || bits() == other.bits()); } constexpr inline float_t& operator+=(const float_t& rhs) { this->m_data = static_cast(static_cast(*this) + static_cast(rhs)).bits(); return *this; } }; // This should probably be exported so we can use it elsewhere #undef BITCAST_CONSTEXPR namespace float_support { // Pre-C++23 these can't be computed as constexpr, so have to hardcode them template struct digits10; // floor(log10(2) * (digits - 1) template struct max_digits10; // ceil(log10(2) * digits + 1) template struct min_exponent10; // floor(log10(2) * min_exponent) template struct max_exponent10; // floor(log10(2) * max_exponent) template <> struct digits10<8> { constexpr static inline int value = 2; }; template <> struct max_digits10<8> { constexpr static inline int value = 4; }; template <> struct digits10<10> { constexpr static inline int value = 2; }; template <> struct max_digits10<10> { constexpr static inline int value = 5; }; template <> struct digits10<24> { constexpr static inline int value = 6; }; template <> struct max_digits10<24> { constexpr static inline int value = 9; }; template <> struct min_exponent10<-13> { constexpr static inline int value = -3; }; template <> struct max_exponent10<16> { constexpr static inline int value = 4; }; template <> struct min_exponent10<-125> { constexpr static inline int value = -37; }; template <> struct max_exponent10<128> { constexpr static inline int value = 38; }; template inline constexpr int digits10_v = digits10::value; template inline constexpr int max_digits10_v = max_digits10::value; template inline constexpr int min_exponent10_v = min_exponent10::value; template inline constexpr int max_exponent10_v = max_exponent10::value; } // namespace float_support } // namespace tosa namespace std { template struct is_floating_point> : std::integral_constant {}; template class numeric_limits> { using this_float_t = tosa::float_t; public: static constexpr bool is_specialized = true; static constexpr auto min() noexcept { return this_float_t::from_bits(false, 1, 0); } static constexpr auto max() noexcept { return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 2, (1 << this_float_t::n_significand_bits) - 1); } static constexpr auto lowest() noexcept { return -max(); } static constexpr int digits = this_float_t::n_significand_bits + 1; static constexpr int digits10 = tosa::float_support::digits10_v; static constexpr int max_digits10 = tosa::float_support::max_digits10_v; static constexpr bool is_signed = true; static constexpr bool is_integer = false; static constexpr bool is_exact = false; static constexpr int radix = 2; static constexpr auto epsilon() noexcept { return this_float_t::from_bits(false, this_float_t::exponent_bias - this_float_t::n_significand_bits, 0); } static constexpr auto round_error() noexcept { return this_float_t::from_bits(0, this_float_t::exponent_bias - 1, 0); } static constexpr int min_exponent = (1 - this_float_t::exponent_bias) + 1; static constexpr int min_exponent10 = tosa::float_support::min_exponent10_v; static constexpr int max_exponent = this_float_t::exponent_bias + 1; static constexpr int max_exponent10 = tosa::float_support::max_exponent10_v; static constexpr bool has_infinity = with_inf; static constexpr bool has_quiet_NaN = has_nan; static constexpr bool has_signaling_NaN = true; static constexpr float_denorm_style has_denorm = with_denorm ? denorm_present : denorm_absent; static constexpr bool has_denorm_loss = false; static constexpr auto infinity() noexcept { if constexpr (with_inf) { return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 1, 0); } else { return this_float_t::from_bits(false, 0, 0); } } static constexpr auto quiet_NaN() noexcept { return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 1, 1 << (this_float_t::n_significand_bits - 1) | 1); } static constexpr auto signaling_NaN() noexcept { return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 1, 1); } static constexpr auto denorm_min() noexcept { return this_float_t::from_bits(false, 0, 1); } static constexpr bool is_iec559 = false; static constexpr bool is_bounded = false; static constexpr bool is_modulo = false; static constexpr bool traps = false; static constexpr bool tinyness_before = false; static constexpr float_round_style round_style = round_to_nearest; }; } // namespace std #endif // TOSA_FLOAT_UTILS_H_