From 520b7ca51f1aa2835d45ca7266a07b4028d449d2 Mon Sep 17 00:00:00 2001 From: Won Jeon Date: Fri, 19 Apr 2024 14:21:00 +0000 Subject: Update float8 code to support non-saturating mode Signed-off-by: Won Jeon Change-Id: I786aca0a2f137cebd446a3a71c8d6fe186286957 --- CMakeLists.txt | 2 +- include/cfloat.h | 837 +++++++++++++++++++++++++++++++++++ include/float_utils.h | 533 ---------------------- include/tosa_serialization_handler.h | 2 +- src/tosa_serialization_handler.cpp | 4 +- 5 files changed, 841 insertions(+), 537 deletions(-) create mode 100644 include/cfloat.h delete mode 100644 include/float_utils.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 5f4f851..679603d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -76,7 +76,7 @@ set(public_headers) list(APPEND public_headers include/attribute.h include/attribute.def - include/float_utils.h + include/cfloat.h include/numpy_utils.h include/tosa_generated.h include/tosa_serialization_handler.h diff --git a/include/cfloat.h b/include/cfloat.h new file mode 100644 index 0000000..0cf4896 --- /dev/null +++ b/include/cfloat.h @@ -0,0 +1,837 @@ +// Copyright (c) 2022-2024, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef CT_CFLOAT_H +#define CT_CFLOAT_H +#include +#include +#include +#include +#include +#if defined(__cpp_lib_bit_cast) +#include +#endif // defined(__cpp_lib_bit_cast) + +namespace ct +{ +/// \brief Bitfield specification of the features provided of a specified +/// floating point type. +enum class FloatFeatures +{ + None = 0x0, + HasNaN = 0x1, ///< The type can represent NaN values + HasInf = 0x2, ///< The type can represent Infinity + HasDenorms = 0x4, ///< The type can represent denormal/subnormal values +}; + +constexpr FloatFeatures operator&(const FloatFeatures& a, const FloatFeatures& b) +{ + using T = std::underlying_type_t; + return static_cast(static_cast(a) & static_cast(b)); +} + +constexpr FloatFeatures operator|(const FloatFeatures& a, const FloatFeatures& b) +{ + using T = std::underlying_type_t; + return static_cast(static_cast(a) | static_cast(b)); +} + +constexpr FloatFeatures& operator|=(FloatFeatures& a, const FloatFeatures& b) +{ + a = a | b; + return a; +} + +namespace float_support +{ +struct hidden +{}; + +/// \brief Get the number of bytes required to store the given number of +/// bits. +/// +/// NOTE This is distinct from the number of bytes required to represent +/// the number of bits - a power of two number of bytes will always be +/// returned by this method. +constexpr size_t get_storage_bytes(const size_t n_bits) +{ + const size_t n_bytes = (n_bits + 7) / 8; + size_t storage_bytes = 1; + for (; storage_bytes < n_bytes; storage_bytes <<= 1) + ; + return storage_bytes; +} + +/// \brief Utility method to convert from an older representation of the +/// floating-point features to the FloatFeatures bitfield. +constexpr FloatFeatures get_float_flags(bool has_nan, bool has_denorm, bool has_inf) +{ + FloatFeatures r = FloatFeatures::None; + + if (has_nan) + r |= FloatFeatures::HasNaN; + + if (has_denorm) + r |= FloatFeatures::HasDenorms; + + if (has_inf) + r |= FloatFeatures::HasInf; + + return r; +} + +/// \brief Shorthand for all support features +static constexpr FloatFeatures AllFeats = get_float_flags(true, true, true); + +// Map from a number of storage bytes to a suitable storage type +template +struct storage_type; + +#define STORAGE_TYPE(T) \ + template <> \ + struct storage_type \ + { \ + using type = T; \ + } +STORAGE_TYPE(int8_t); +STORAGE_TYPE(int16_t); +STORAGE_TYPE(int32_t); +STORAGE_TYPE(int64_t); +#undef STORAGE_TYPE + +template +using storage_type_t = typename storage_type::type; + +#if defined(__cpp_lib_bit_cast) +#define BITCAST_CONSTEXPR constexpr inline + +// If bit_cast is available then use it + +constexpr inline int32_t get_bits(const float& f) +{ + return std::bit_cast(f); +} +constexpr inline float from_bits(const int32_t& i) +{ + return std::bit_cast(i); +} + +#else +#define BITCAST_CONSTEXPR inline + +// Otherwise `memcpy` is the safe (non-UB) of achieving the same result + +inline int32_t get_bits(const float& f) +{ + int32_t i; + std::memcpy(&i, &f, sizeof(float)); + return i; +} + +inline float from_bits(const int32_t& i) +{ + float f; + std::memcpy(&f, &i, sizeof(float)); + return f; +} +#endif + +} // namespace float_support + +/// \brief Overflow mode for narrowing floating-point casts. +/// +/// Determine the behaviour for values which cannot be represented by the +/// destination type. +enum class OverflowMode +{ + Saturate, ///< Map to the largest representable value + Overflow ///< Map to infinity (if available) or NaN +}; + +/// Functor for casting cfloat_advanced +/// +/// Specific casting behavior can be specified when constructing the +/// functor. +/// +/// By default, OVERFLOW mode is used when the destination type has either +/// infinity or NaN representations. Otherwise SATURATE mode is used. It is +/// illegal to specify OVERFLOW mode for a type which has neither infinity +/// or NaN representations - this will result in a compilation error. +template +class cfloat_cast +{ + constexpr static FloatFeatures in_feats = in_type::features; + constexpr static FloatFeatures out_feats = out_type::features; + constexpr static size_t in_bits = in_type::n_bits; + constexpr static size_t in_exp_bits = in_type::n_exponent_bits; + constexpr static size_t out_bits = out_type::n_bits; + constexpr static size_t out_exp_bits = out_type::n_exponent_bits; + +public: + constexpr cfloat_cast() + { + // SATURATE mode MUST be specified if the destination type does not + // have either NaN or infinity representations. + static_assert(overflow_mode == OverflowMode::Saturate || out_type::has_nan || out_type::has_inf); + } + + /// \brief Cast from `in` to the given `out_type` + // + // This code relies on an understanding of the storage format used by + // `cfloat_advanced`. See the documentation of that class for further + // details. + constexpr out_type operator()(const in_type& in) const + { + // Shortcut for types which differ only in the number of significand + // bits, and where the output type is wider than the input type. For + // example, bfloat16 and binary32. + if constexpr (in_exp_bits == out_exp_bits && out_bits >= in_bits && in_feats == out_feats) + { + return out_type::from_bits(static_cast(in.bits()) << (out_bits - in_bits)); + } + + // Get initial values for the new floating point type + const bool sign_bit = in.sign(); + int64_t new_exponent_bits = 0; + uint64_t new_significand = 0; + + if (in.is_nan() || in.is_infinity()) + { + new_exponent_bits = (UINT64_C(1) << out_exp_bits) - 1; + + if (in.is_nan()) + { + if constexpr (out_type::has_inf) + { + // Copy across the `not_quiet bit`; set the LSB. + // Don't attempt to copy across any of the rest of + // the payload. + new_significand = 0x1 | (((in.significand() >> (in_type::n_significand_bits - 1)) & 1) + << out_type::n_significand_bits); + } + else + { + new_significand = (UINT64_C(1) << out_type::n_significand_bits) - 1; + } + } + else if constexpr (out_type::has_inf && overflow_mode == OverflowMode::Saturate) + { + new_exponent_bits -= 1; + new_significand = (UINT64_C(1) << out_type::n_significand_bits) - 1; + } + else if constexpr (!out_type::has_inf && overflow_mode == OverflowMode::Saturate) + { + new_significand = (UINT64_C(1) << out_type::n_significand_bits) - (out_type::has_nan ? 2 : 1); + } + else if constexpr (!out_type::has_inf && overflow_mode == OverflowMode::Overflow) + { + new_significand = (UINT64_C(1) << out_type::n_significand_bits) - 1; + } + } + else if (!in.is_zero()) + { + const int64_t this_exponent_bits = in.exponent_bits(); + { + constexpr int64_t exponent_rebias = out_type::exponent_bias - in_type::exponent_bias; + new_exponent_bits = std::max(this_exponent_bits + exponent_rebias, exponent_rebias + 1); + } + new_significand = in.significand() << (64 - in_type::n_significand_bits); + + // Normalise subnormals + if (this_exponent_bits == 0) + { + // Shift the most-significant 1 out of the magnitude to + // convert it to a significand. Modify the exponent + // accordingly. + uint8_t shift = __builtin_clzl(new_significand) + 1; + new_exponent_bits -= shift; + new_significand <<= shift; + } + + // Apply overflow to out-of-range values; this must occur before + // rounding, as out-of-range values could be rounded down to the + // largest representable value. + if constexpr (overflow_mode == OverflowMode::Overflow) + { + // Determine the maximum value of exponent, and unrounded + // significand. + constexpr bool inf_and_nan = out_type::has_nan && out_type::has_inf; + constexpr int64_t max_exp_bits = (INT64_C(1) << out_exp_bits) - (inf_and_nan ? 2 : 1); + constexpr uint64_t max_significand = + ((UINT64_C(1) << out_type::n_significand_bits) - (inf_and_nan ? 1 : 2)) + << (64 - out_type::n_significand_bits); + + // If the exponent is strictly larger than the largest + // possible, or the exponent is equal to the largest + // possible AND the (unrounded) significand is strictly + // larger than the largest possible then return an + // appropriate overflow value. + if (new_exponent_bits > max_exp_bits || + (new_exponent_bits == max_exp_bits && new_significand > max_significand)) + { + if constexpr (out_type::has_inf) + return out_type::infinity(sign_bit); + else + return out_type::NaN(); + } + } + + // Align the significand for the output type + uint32_t shift = 64 - out_type::n_significand_bits; + const bool other_is_subnormal = new_exponent_bits <= 0; + if (other_is_subnormal) + { + shift += 1 - new_exponent_bits; + new_exponent_bits = 0; + } + + const uint64_t shift_out = shift == 64 ? new_significand : new_significand & ((UINT64_C(1) << shift) - 1); + new_significand = shift == 64 ? 0 : new_significand >> shift; + + // Reinsert the most-significant-one if this is a subnormal + // in the output type. + new_significand |= (other_is_subnormal ? UINT64_C(1) : 0) << (64 - shift); + + // Apply rounding based on the bits shifted out of the + // significand + const uint64_t shift_half = UINT64_C(1) << (shift - 1); + if (shift_out > shift_half || (shift_out == shift_half && (new_significand & 1))) + { + new_significand += 1; + + // Handle the case that the significand overflowed due + // to rounding + constexpr uint64_t max_significand = (UINT64_C(1) << out_type::n_significand_bits) - 1; + if (new_significand > max_significand) + { + new_significand = 0; + new_exponent_bits++; + } + } + + // Saturate or overflow if the value is larger than can be + // represented in the output type. This can only occur if the + // size of the exponent of the new type is not greater than the + // exponent of the old type. + if constexpr (out_exp_bits <= in_exp_bits) + { + constexpr int64_t inf_exp_bits = (INT64_C(1) << out_exp_bits) - 1; + if (new_exponent_bits >= inf_exp_bits) + { + if constexpr (out_type::has_inf && overflow_mode == OverflowMode::Overflow) + { + // If the output type has a representation of + // infinity, and we are in OVERFLOW Mode, then + // return infinity. + new_exponent_bits = inf_exp_bits; + new_significand = 0; + } + else if constexpr (out_type::has_inf) + { + // If the output type has a representation of + // infinity, and we are in SATURATE mode, then + // return the largest representable real number. + new_exponent_bits = inf_exp_bits - 1; + new_significand = (UINT64_C(1) << out_type::n_significand_bits) - 1; + } + else if (new_exponent_bits > inf_exp_bits) + { + if constexpr (overflow_mode == OverflowMode::Overflow) + return out_type::NaN(); + else + return out_type::max(sign_bit); + } + else + { + constexpr uint64_t max_significand = + (UINT64_C(1) << out_type::n_significand_bits) - (out_type::has_nan ? 2 : 1); + if (new_significand > max_significand) + { + if constexpr (overflow_mode == OverflowMode::Saturate) + new_significand = max_significand; + else + return out_type::NaN(); + } + } + } + } + } + + return out_type::from_bits(sign_bit, new_exponent_bits, new_significand); + } +}; + +/// \brief Bit-accurate representation storage of IEEE754 compliant and +/// derived floating point types. +/// +/// Template parameters allow for specification of the number of bits, the +/// number of exponent bits, and the features of the floating point types. +/// The number of significand bits is `n_bits - n_exponent_bits - 1`. It is +/// not possible to represent a signless type, such as FP8 E8M0. +/// +/// For an imaginary 7-bit type, FP7 E4M2; the storage for various values +/// given different floating point features is given below: +/// +/// Value All features No infinity No features +/// -------------------------- ------------ ----------- ----------- +/// Positive zero +0 00 0000 00 As before As before +/// Negative zero -0 11 0000 00 As before As before +/// Positive/negative infinity SS 1111 00 N/A N/A +/// Signalling NaN SS 1111 01 SS 1111 11 N/A +/// Quiet NaN SS 1111 11 N/A N/A +/// Largest normal SS 1110 11 SS 1111 10 SS 1111 11 +/// Smallest normal SS 0001 00 As before SS 0000 01 +/// Largest denormal SS 0000 11 SS 0000 11 N/A +/// +/// Note that the sign bit is extended to fill the storage type. +template +class cfloat_advanced +{ +public: + using storage_t = float_support::storage_type_t; + + static constexpr size_t n_bits = _n_bits; + static constexpr size_t n_exponent_bits = n_exp_bits; + static constexpr size_t n_significand_bits = n_bits - (1 + n_exp_bits); + static constexpr int64_t exponent_bias = (INT64_C(1) << (n_exp_bits - 1)) - 1; + + static constexpr FloatFeatures features = Feats; + static constexpr bool has_nan = (Feats & FloatFeatures::HasNaN) != FloatFeatures::None; + static constexpr bool has_inf = (Feats & FloatFeatures::HasInf) != FloatFeatures::None; + static constexpr bool has_denorms = (Feats & FloatFeatures::HasDenorms) != FloatFeatures::None; + + /// \brief Construct a floating point type with the given bit + /// representation. + static constexpr cfloat_advanced from_bits(storage_t bits) + { + return cfloat_advanced(float_support::hidden(), bits); + } + + /// \brief Construct a float from the given sign, exponent and + /// significand bits. + static constexpr cfloat_advanced from_bits(bool pm, storage_t e, storage_t s) + { + storage_t bits = pm ? -1 : 0; + + bits <<= n_exp_bits; + bits |= e; + + bits <<= n_significand_bits; + if (has_denorms || e) + bits |= s; + + return cfloat_advanced(float_support::hidden(), bits); + } + + /// \brief (Hidden) Construct a float type from a given bit pattern + constexpr cfloat_advanced(const float_support::hidden&, storage_t bits) + : m_data(bits) + {} + + constexpr cfloat_advanced() + : m_data(0) + {} + constexpr cfloat_advanced(const cfloat_advanced& other) + : m_data(other.m_data) + {} + + constexpr cfloat_advanced& operator=(const cfloat_advanced& other) + { + this->m_data = other.m_data; + return *this; + } + + constexpr cfloat_advanced& operator=(cfloat_advanced&& other) + { + this->m_data = other.m_data; + return *this; + } + + /// \brief Get a NaN representation + static constexpr cfloat_advanced NaN() + { + static_assert(has_nan); + + // NaN is always encoded with all 1s in the exponent. + // If Inf exists, then NaN is encoded as a non-zero significand; if + // Inf doesn't exist then NaN is encoded as all ones in the + // significand. + constexpr uint64_t exp_bits = (UINT64_C(1) << n_exponent_bits) - 1; + constexpr uint64_t sig_bits = has_inf ? 1 : (UINT64_C(1) << n_significand_bits) - 1; + return cfloat_advanced::from_bits(false, exp_bits, sig_bits); + } + + /// \brief Get a representation of infinity + static constexpr cfloat_advanced infinity(const bool& sign) + { + static_assert(has_inf); + + // Inf is always encoded with all 1s in the exponent, and all zeros + // in the significand. + return cfloat_advanced::from_bits(sign, (UINT64_C(1) << n_exponent_bits) - 1, 0); + } + + /// \brief Get the largest representable value + static constexpr cfloat_advanced max(const bool& sign) + { + if constexpr (has_nan && has_inf) + { + // Where we have NaN and Infinity, exponents all `1` corresponds + // to some of these values. + return from_bits(false, (UINT64_C(1) << n_exponent_bits) - 2, (UINT64_C(1) << n_significand_bits) - 1); + } + else if constexpr (has_nan || has_inf) + { + // Where we have either NaN or infinity (but not both), + // exponents all `1` AND significand all `1` corresponds to the + // special value. + return from_bits(false, (UINT64_C(1) << n_exponent_bits) - 1, (UINT64_C(1) << n_significand_bits) - 2); + } + else + { + // With no special values to encode, the maximum value is + // encoded as all `1`s. + return from_bits(false, (UINT64_C(1) << n_exponent_bits) - 1, (UINT64_C(1) << n_significand_bits) - 1); + } + } + + /// \brief Cast to a different floating point representation. + template + constexpr inline operator cfloat_advanced() const + { + using out_type = cfloat_advanced; + return cfloat_cast().operator()(*this); + } + + /// \brief Convert from a 32-bit floating point value + BITCAST_CONSTEXPR + cfloat_advanced(const float& f) + { + // If this format exactly represents the binary32 format then get + // the bits from the provided float; otherwise get a binary32 + // representation and then convert to this format. + if constexpr (represents_binary32()) + m_data = float_support::get_bits(f); + else + m_data = + static_cast>(static_cast>(f)).m_data; + } + + /// \brief Cast to a 32-bit floating point value + BITCAST_CONSTEXPR operator float() const + { + // If this format exactly represents the binary32 format then return + // a float; otherwise get a binary32 representation and then return + // a float. + if constexpr (represents_binary32()) + return float_support::from_bits(m_data); + else + return static_cast(this->operator cfloat_advanced<32, 8>()); + } + + /// \brief Return whether this type represents the IEEE754 binary32 + /// format + constexpr static inline bool represents_binary32() + { + return std::is_same_v && n_exp_bits == 8 && Feats == float_support::AllFeats; + } + + constexpr auto operator-() const + { + constexpr storage_t sign_bits = + static_cast(std::numeric_limits>::max() << (n_bits - 1)); + return from_bits(m_data ^ sign_bits); + } + + constexpr bool is_subnormal() const + { + return exponent_bits() == 0 && significand() != 0; + } + + constexpr bool is_zero() const + { + return exponent_bits() == 0 && significand() == 0; + } + + constexpr bool is_nan() const + { + return has_nan && (exponent_bits() == (UINT64_C(1) << n_exponent_bits) - 1) && + ((has_inf && significand()) || (!has_inf && significand() == (UINT64_C(1) << n_significand_bits) - 1)); + } + + constexpr bool is_infinity() const + { + return has_inf && ((exponent_bits() == (UINT64_C(1) << n_exponent_bits) - 1) && (significand() == 0)); + } + + constexpr inline const storage_t& bits() const + { + return m_data; + } + + /// \brief Get the exponent + constexpr inline int64_t exponent() const + { + return std::max(exponent_bits(), INT64_C(1)) - exponent_bias; + } + + /// \brief Get the sign bit + constexpr inline bool sign() const + { + return (m_data >> (n_bits - 1)) & 0x1; + } + + /// \brief Get the bits from the exponent + constexpr inline uint64_t exponent_bits() const + { + constexpr uint64_t mask = (UINT64_C(1) << n_exp_bits) - 1; + return (m_data >> n_significand_bits) & mask; + } + + constexpr inline uint64_t significand() const + { + return m_data & ((UINT64_C(1) << n_significand_bits) - 1); + } + + constexpr inline bool operator==(const cfloat_advanced& other) const + { + return !is_nan() && !other.is_nan() && // Neither operand is NaN + ((is_zero() && other.is_zero()) || (m_data == other.m_data)); + } + + constexpr inline bool operator!=(const cfloat_advanced& other) const + { + return !(*this == other); + } + + constexpr inline cfloat_advanced& operator+=(const cfloat_advanced& rhs) + { + this->m_data = static_cast(static_cast(*this) + static_cast(rhs)).bits(); + return *this; + } + +private: + storage_t m_data = 0; +}; + +// This should probably be exported so we can use it elsewhere +#undef BITCAST_CONSTEXPR + +/// \brief Wrapper to maintain API compatibility with older code, which was +/// limited to power-of-two sizes of floats. +template = true> +using cfloat = cfloat_advanced; + +namespace float_support +{ +// Pre-C++23 these can't be computed as constexpr, so have to hardcode +// them + +template +struct digits10; // floor(log10(2) * (digits - 1) +template +struct max_digits10; // ceil(log10(2) * digits + 1) +template +struct min_exponent10; // floor(log10(2) * min_exponent) +template +struct max_exponent10; // floor(log10(2) * max_exponent) + +template <> +struct digits10<8> +{ + constexpr static inline int value = 2; +}; + +template <> +struct max_digits10<8> +{ + constexpr static inline int value = 4; +}; + +template <> +struct digits10<10> +{ + constexpr static inline int value = 2; +}; + +template <> +struct max_digits10<10> +{ + constexpr static inline int value = 5; +}; + +template <> +struct digits10<24> +{ + constexpr static inline int value = 6; +}; + +template <> +struct max_digits10<24> +{ + constexpr static inline int value = 9; +}; + +template <> +struct min_exponent10<-13> +{ + constexpr static inline int value = -3; +}; + +template <> +struct max_exponent10<16> +{ + constexpr static inline int value = 4; +}; + +template <> +struct min_exponent10<-125> +{ + constexpr static inline int value = -37; +}; + +template <> +struct max_exponent10<128> +{ + constexpr static inline int value = 38; +}; + +template +inline constexpr int digits10_v = digits10::value; +template +inline constexpr int max_digits10_v = max_digits10::value; + +template +inline constexpr int min_exponent10_v = min_exponent10::value; + +template +inline constexpr int max_exponent10_v = max_exponent10::value; + +} // namespace float_support + +} // namespace ct + +namespace std +{ + +template +struct is_floating_point> : std::integral_constant +{}; + +template +class numeric_limits> +{ + using this_cfloat = ct::cfloat_advanced; + +public: + static constexpr bool is_specialized = true; + + static constexpr auto min() noexcept + { + return this_cfloat::from_bits(false, 1, 0); + } + + static constexpr auto max() noexcept + { + return this_cfloat::max(false); + } + static constexpr auto lowest() noexcept + { + return -max(); + } + + static constexpr int digits = this_cfloat::n_significand_bits + 1; + static constexpr int digits10 = ct::float_support::digits10_v; + static constexpr int max_digits10 = ct::float_support::max_digits10_v; + + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr int radix = 2; + + static constexpr auto epsilon() noexcept + { + return this_cfloat::from_bits(false, this_cfloat::exponent_bias - this_cfloat::n_significand_bits, 0); + } + + static constexpr auto round_error() noexcept + { + return this_cfloat::from_bits(0, this_cfloat::exponent_bias - 1, 0); + } + + static constexpr int min_exponent = (1 - this_cfloat::exponent_bias) + 1; + static constexpr int min_exponent10 = ct::float_support::min_exponent10_v; + static constexpr int max_exponent = this_cfloat::exponent_bias + 1; + static constexpr int max_exponent10 = ct::float_support::max_exponent10_v; + + static constexpr bool has_infinity = this_cfloat::has_inf; + static constexpr bool has_quiet_NaN = this_cfloat::has_nan && this_cfloat::has_inf; + static constexpr bool has_signaling_NaN = this_cfloat::has_nan; + static constexpr float_denorm_style has_denorm = this_cfloat::has_denorms ? denorm_present : denorm_absent; + static constexpr bool has_denorm_loss = false; + + static constexpr auto infinity() noexcept + { + if constexpr (this_cfloat::has_inf) + { + return this_cfloat::infinity(false); + } + else + { + return this_cfloat::from_bits(false, 0, 0); + } + } + + static constexpr auto quiet_NaN() noexcept + { + const uint64_t exp_bits = (UINT64_C(1) << this_cfloat::n_exponent_bits) - 1; + const uint64_t sig_bits = this_cfloat::has_inf ? (UINT64_C(1) << (this_cfloat::n_significand_bits - 1)) | 1 + : (UINT64_C(1) << this_cfloat::n_significand_bits) - 1; + return this_cfloat::from_bits(false, exp_bits, sig_bits); + } + + static constexpr auto signaling_NaN() noexcept + { + const uint64_t exp_bits = (UINT64_C(1) << this_cfloat::n_exponent_bits) - 1; + const uint64_t sig_bits = this_cfloat::has_inf ? 1 : (UINT64_C(1) << this_cfloat::n_significand_bits) - 1; + return this_cfloat::from_bits(false, exp_bits, sig_bits); + } + + static constexpr auto denorm_min() noexcept + { + return this_cfloat::from_bits(false, 0, 1); + } + + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = false; + static constexpr bool is_modulo = false; + + static constexpr bool traps = false; + static constexpr bool tinyness_before = false; + static constexpr float_round_style round_style = round_to_nearest; +}; + +} // namespace std + +#endif // CT_CFLOAT_H diff --git a/include/float_utils.h b/include/float_utils.h deleted file mode 100644 index 831ad74..0000000 --- a/include/float_utils.h +++ /dev/null @@ -1,533 +0,0 @@ -// Copyright (c) 2024, ARM Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TOSA_FLOAT_UTILS_H_ -#define TOSA_FLOAT_UTILS_H_ - -#include -#include -#include -#include -#if defined(__cpp_lib_bit_cast) -#include -#endif // defined(__cpp_lib_bit_cast) - -namespace tosa -{ - -namespace float_support -{ - -struct hidden -{}; - -#if defined(__cpp_lib_bit_cast) -#define BITCAST_CONSTEXPR constexpr inline - -constexpr inline int32_t get_bits(const float& f) -{ - return std::bit_cast(f); -} -constexpr inline float from_bits(const int32_t& i) -{ - return std::bit_cast(i); -} - -#else -#define BITCAST_CONSTEXPR inline - -union ufloat32 -{ - constexpr ufloat32(const float& x) - : f(x) - {} - constexpr ufloat32(const int32_t& x) - : i(x) - {} - - float f; - int32_t i; -}; - -inline int32_t get_bits(const float& f) -{ - return ufloat32(f).i; -} -inline float from_bits(const int32_t& i) -{ - return ufloat32(i).f; -} -#endif - -} // namespace float_support - -template = true> -class float_t -{ - storage_t m_data = 0; - -public: - static constexpr size_t n_exponent_bits = n_exp_bits; - static constexpr size_t n_significand_bits = sizeof(storage_t) * 8 - 1 - n_exp_bits; - static constexpr int64_t exponent_bias = (1 << (n_exp_bits - 1)) - 1; - - /// \brief Construct a floating point type with the given bit - /// representation. - static constexpr float_t from_bits(storage_t bits) - { - return float_t(float_support::hidden(), bits); - } - - /// \brief Construct a float from the given sign, exponent and significand - /// bits. - static constexpr float_t from_bits(bool pm, storage_t e, storage_t s) - { - storage_t bits = pm ? 1 : 0; - - bits <<= n_exp_bits; - bits |= e; - - bits <<= n_significand_bits; - if (with_denorm || e) - bits |= s; - - return float_t(float_support::hidden(), bits); - } - - /// \brief (Hidden) Construct a float type from a given bit pattern - constexpr float_t(const float_support::hidden&, storage_t bits) - : m_data(bits) - {} - - constexpr float_t() - : m_data(0) - {} - constexpr float_t(const float_t& other) - : m_data(other.m_data) - {} - - /// \brief Cast to a different floating point representation. - template - constexpr inline - operator float_t() const - { - using other_float_t = - float_t; - - // Shortcut for types which are fundamentally similar (e.g., bf16 -> - // fp32) - if constexpr (n_exp_bits == other_n_exp_bits && sizeof(other_storage_t) >= sizeof(storage_t) && - has_nan == other_has_nan) - { - return other_float_t::from_bits(static_cast(m_data) - << (sizeof(other_storage_t) - sizeof(storage_t)) * 8); - } - - // Get initial values for the new floating point type - const bool sign_bit = m_data < 0; - int64_t new_exponent_bits = 0; - uint64_t new_significand = 0; - - if (is_nan() || is_infinity()) - { - new_exponent_bits = (1 << other_n_exp_bits) - 1; - - if (is_nan()) - { - if constexpr (other_has_infinity) - { - // Copy across the `not_quiet bit`; set the LSB. Don't - // attempt to copy across any of the rest of the payload. - new_significand = - 0x1 | (((significand() >> (n_significand_bits - 1)) & 1) << other_float_t::n_significand_bits); - } - else - { - new_significand = (1ul << other_float_t::n_significand_bits) - 1; - } - } - else if constexpr (!other_has_infinity) - { - new_significand = (1ul << other_float_t::n_significand_bits) - (other_has_nan ? 2 : 1); - } - } - else if (!is_zero()) - { - const int64_t this_exponent_bits = exponent_bits(); - { - constexpr int64_t exponent_rebias = other_float_t::exponent_bias - exponent_bias; - new_exponent_bits = std::max(this_exponent_bits + exponent_rebias, exponent_rebias + 1); - } - new_significand = this->significand() << (64 - n_significand_bits); - - // Normalise subnormals - if (this_exponent_bits == 0) - { - // Shift the most-significant 1 out of the magnitude to convert - // it to a significand. Modify the exponent accordingly. - uint8_t shift = __builtin_clzl(new_significand) + 1; - new_exponent_bits -= shift; - new_significand <<= shift; - } - - // Align the significand for the output type - uint32_t shift = 64 - other_float_t::n_significand_bits; - const bool other_is_subnormal = new_exponent_bits <= 0; - if (other_is_subnormal) - { - shift += 1 - new_exponent_bits; - new_exponent_bits = 0; - } - - const uint64_t shift_out = shift == 64 ? new_significand : new_significand & ((1ll << shift) - 1); - new_significand = shift == 64 ? 0 : new_significand >> shift; - - // Reinsert the most-significant-one if this is a subnormal in the - // output type. - new_significand |= (other_is_subnormal ? 1ll : 0) << (64 - shift); - - // Apply rounding based on the bits shifted out of the significand - const uint64_t shift_half = 1ll << (shift - 1); - if (shift_out > shift_half || (shift_out == shift_half && (new_significand & 1))) - { - new_significand += 1; - - // Handle the case that the significand overflowed due to - // rounding - constexpr uint64_t max_significand = (1ll << other_float_t::n_significand_bits) - 1; - if (new_significand > max_significand) - { - new_significand = 0; - new_exponent_bits++; - } - } - - // Saturate to infinity if the exponent is larger than can be - // represented in the output type. This can only occur if the size - // of the exponent of the new type is not greater than the exponent - // of the old type. - if constexpr (other_n_exp_bits <= n_exp_bits) - { - constexpr int64_t inf_exp_bits = (1ll << other_n_exp_bits) - 1; - if (new_exponent_bits >= inf_exp_bits) - { - new_exponent_bits = inf_exp_bits; - new_significand = - other_has_infinity ? 0 : (1ul << other_float_t::n_significand_bits) - (other_has_nan ? 2 : 1); - } - } - } - - return other_float_t::from_bits(sign_bit, new_exponent_bits, new_significand); - } - - /// \brief Convert from a 32-bit floating point value - BITCAST_CONSTEXPR - float_t(const float& f) - { - // If this format exactly represents the binary32 format then get - // the bits from the provided float; otherwise get a binary32 - // representation and then convert to this format. - if constexpr (represents_binary32()) - m_data = float_support::get_bits(f); - else - m_data = static_cast>( - static_cast>(f)) - .m_data; - } - - /// \brief Cast to a 32-bit floating point value - BITCAST_CONSTEXPR operator float() const - { - // If this format exactly represents the binary32 format then return - // a float; otherwise get a binary32 representation and then return - // a float. - if constexpr (represents_binary32()) - return float_support::from_bits(m_data); - else - return static_cast(this->operator float_t()); - } - - /// \brief Return whether this type represents the IEEE754 binary32 - /// format - constexpr static inline bool represents_binary32() - { - return std::is_same_v && n_exp_bits == 8 && has_nan && with_denorm && with_infinity; - } - - constexpr auto operator-() const - { - return from_bits(m_data ^ (1ll << (sizeof(storage_t) * 8 - 1))); - } - - constexpr bool is_subnormal() const - { - return exponent_bits() == 0 && significand() != 0; - } - - constexpr bool is_zero() const - { - return exponent_bits() == 0 && significand() == 0; - } - - constexpr bool is_nan() const - { - return has_nan && (exponent_bits() == (1ul << n_exponent_bits) - 1) && - ((with_infinity && significand()) || - (!with_infinity && significand() == (1ul << n_significand_bits) - 1)); - } - - constexpr bool is_infinity() const - { - return with_infinity && ((exponent_bits() == (1ul << n_exponent_bits) - 1) && !significand()); - } - - constexpr inline const storage_t& bits() const - { - return m_data; - } - - /// \brief Get the exponent - constexpr inline int64_t exponent() const - { - return std::max(exponent_bits(), 1ul) - exponent_bias; - } - - /// \brief Get the bits from the exponent - constexpr inline uint64_t exponent_bits() const - { - constexpr uint64_t mask = (1ul << n_exp_bits) - 1; - return (m_data >> n_significand_bits) & mask; - } - - constexpr inline uint64_t significand() const - { - return m_data & ((1ul << n_significand_bits) - 1); - } - - constexpr inline bool operator==(const float_t& other) const - { - return !is_nan() && !other.is_nan() && ((is_zero() && other.is_zero()) || bits() == other.bits()); - } - - constexpr inline float_t& operator+=(const float_t& rhs) - { - this->m_data = static_cast(static_cast(*this) + static_cast(rhs)).bits(); - return *this; - } -}; - -// This should probably be exported so we can use it elsewhere -#undef BITCAST_CONSTEXPR - -namespace float_support -{ - -// Pre-C++23 these can't be computed as constexpr, so have to hardcode them - -template -struct digits10; // floor(log10(2) * (digits - 1) -template -struct max_digits10; // ceil(log10(2) * digits + 1) -template -struct min_exponent10; // floor(log10(2) * min_exponent) -template -struct max_exponent10; // floor(log10(2) * max_exponent) - -template <> -struct digits10<8> -{ - constexpr static inline int value = 2; -}; - -template <> -struct max_digits10<8> -{ - constexpr static inline int value = 4; -}; - -template <> -struct digits10<10> -{ - constexpr static inline int value = 2; -}; - -template <> -struct max_digits10<10> -{ - constexpr static inline int value = 5; -}; - -template <> -struct digits10<24> -{ - constexpr static inline int value = 6; -}; - -template <> -struct max_digits10<24> -{ - constexpr static inline int value = 9; -}; - -template <> -struct min_exponent10<-13> -{ - constexpr static inline int value = -3; -}; - -template <> -struct max_exponent10<16> -{ - constexpr static inline int value = 4; -}; - -template <> -struct min_exponent10<-125> -{ - constexpr static inline int value = -37; -}; - -template <> -struct max_exponent10<128> -{ - constexpr static inline int value = 38; -}; - -template -inline constexpr int digits10_v = digits10::value; -template -inline constexpr int max_digits10_v = max_digits10::value; - -template -inline constexpr int min_exponent10_v = min_exponent10::value; - -template -inline constexpr int max_exponent10_v = max_exponent10::value; - -} // namespace float_support - -} // namespace tosa - -namespace std -{ - -template -struct is_floating_point> - : std::integral_constant -{}; - -template -class numeric_limits> -{ - using this_float_t = tosa::float_t; - -public: - static constexpr bool is_specialized = true; - - static constexpr auto min() noexcept - { - return this_float_t::from_bits(false, 1, 0); - } - - static constexpr auto max() noexcept - { - return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 2, - (1 << this_float_t::n_significand_bits) - 1); - } - - static constexpr auto lowest() noexcept - { - return -max(); - } - - static constexpr int digits = this_float_t::n_significand_bits + 1; - static constexpr int digits10 = tosa::float_support::digits10_v; - static constexpr int max_digits10 = tosa::float_support::max_digits10_v; - - static constexpr bool is_signed = true; - static constexpr bool is_integer = false; - static constexpr bool is_exact = false; - static constexpr int radix = 2; - - static constexpr auto epsilon() noexcept - { - return this_float_t::from_bits(false, this_float_t::exponent_bias - this_float_t::n_significand_bits, 0); - } - - static constexpr auto round_error() noexcept - { - return this_float_t::from_bits(0, this_float_t::exponent_bias - 1, 0); - } - - static constexpr int min_exponent = (1 - this_float_t::exponent_bias) + 1; - static constexpr int min_exponent10 = tosa::float_support::min_exponent10_v; - static constexpr int max_exponent = this_float_t::exponent_bias + 1; - static constexpr int max_exponent10 = tosa::float_support::max_exponent10_v; - - static constexpr bool has_infinity = with_inf; - static constexpr bool has_quiet_NaN = has_nan; - static constexpr bool has_signaling_NaN = true; - static constexpr float_denorm_style has_denorm = with_denorm ? denorm_present : denorm_absent; - static constexpr bool has_denorm_loss = false; - - static constexpr auto infinity() noexcept - { - if constexpr (with_inf) - { - return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 1, 0); - } - else - { - return this_float_t::from_bits(false, 0, 0); - } - } - - static constexpr auto quiet_NaN() noexcept - { - return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 1, - 1 << (this_float_t::n_significand_bits - 1) | 1); - } - - static constexpr auto signaling_NaN() noexcept - { - return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 1, 1); - } - - static constexpr auto denorm_min() noexcept - { - return this_float_t::from_bits(false, 0, 1); - } - - static constexpr bool is_iec559 = false; - static constexpr bool is_bounded = false; - static constexpr bool is_modulo = false; - - static constexpr bool traps = false; - static constexpr bool tinyness_before = false; - static constexpr float_round_style round_style = round_to_nearest; -}; - -} // namespace std - -#endif // TOSA_FLOAT_UTILS_H_ diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h index f5f9e58..1f8310e 100644 --- a/include/tosa_serialization_handler.h +++ b/include/tosa_serialization_handler.h @@ -16,9 +16,9 @@ #ifndef _TOSA_SERIALIZATION_HANDLER_H #define _TOSA_SERIALIZATION_HANDLER_H #include "attribute.h" +#include "cfloat.h" #include "flatbuffers/idl.h" #include "flatbuffers/util.h" -#include "float_utils.h" #include "numpy_utils.h" #include "tosa_generated.h" #include diff --git a/src/tosa_serialization_handler.cpp b/src/tosa_serialization_handler.cpp index 85625cd..0ce6211 100644 --- a/src/tosa_serialization_handler.cpp +++ b/src/tosa_serialization_handler.cpp @@ -19,8 +19,8 @@ #include using namespace tosa; -using fp8e4m3 = tosa::float_t; -using fp8e5m2 = tosa::float_t; +using fp8e4m3 = ct::cfloat; +using fp8e5m2 = ct::cfloat; TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name, const flatbuffers::Vector* shape, -- cgit v1.2.1