From ce911a2f1d9cd678fb9fe82a40c86ad0c6772f5a Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Thu, 21 Mar 2024 17:01:14 +0000 Subject: Add conversions of U8 to/from BF16 and FP8 Adds type to PadAttribute and ClampAttribute so their pad_const and max_val/min_val can be deserialized according to type Adds conversion functions of U8 arrays to/from BF16/FP8 values Also, refactor and expose TosaSerializer.convertDataToUint8Vec for converting dtype/data to uint8 list for serialization And modify convertDataToUint8Vec to serialize bf16 values into 2 bytes each, and serialize fp8 values into single bytes each. Signed-off-by: Tai Ly Change-Id: I05659e8187c76d359f1cc9f71c8c23cafd0e877f --- include/attribute.def | 10 +- include/float_utils.h | 533 +++++++++++++++++++++++++++++++++++ include/tosa_generated.h | 40 ++- include/tosa_serialization_handler.h | 7 + 4 files changed, 578 insertions(+), 12 deletions(-) create mode 100644 include/float_utils.h (limited to 'include') diff --git a/include/attribute.def b/include/attribute.def index 723543e..30b432d 100644 --- a/include/attribute.def +++ b/include/attribute.def @@ -52,8 +52,9 @@ DEF_ATTRIBUTE(TransposeConv, 7, bool, S, local_bound, DType, S, acc_type) -DEF_ATTRIBUTE(Pad, 1, - uint8_t, V, pad_const) +DEF_ATTRIBUTE(Pad, 2, + uint8_t, V, pad_const, + DType, S, type) DEF_ATTRIBUTE(Axis, 1, int32_t, S, axis) @@ -64,9 +65,10 @@ DEF_ATTRIBUTE(Resize, 4, int16_t, V, border, ResizeMode, S, mode) -DEF_ATTRIBUTE(Clamp, 2, +DEF_ATTRIBUTE(Clamp, 3, uint8_t, V, min_val, - uint8_t, V, max_val) + uint8_t, V, max_val, + DType, S, type) DEF_ATTRIBUTE(Rescale, 7, int32_t, S, input_zp, diff --git a/include/float_utils.h b/include/float_utils.h new file mode 100644 index 0000000..831ad74 --- /dev/null +++ b/include/float_utils.h @@ -0,0 +1,533 @@ +// Copyright (c) 2024, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TOSA_FLOAT_UTILS_H_ +#define TOSA_FLOAT_UTILS_H_ + +#include +#include +#include +#include +#if defined(__cpp_lib_bit_cast) +#include +#endif // defined(__cpp_lib_bit_cast) + +namespace tosa +{ + +namespace float_support +{ + +struct hidden +{}; + +#if defined(__cpp_lib_bit_cast) +#define BITCAST_CONSTEXPR constexpr inline + +constexpr inline int32_t get_bits(const float& f) +{ + return std::bit_cast(f); +} +constexpr inline float from_bits(const int32_t& i) +{ + return std::bit_cast(i); +} + +#else +#define BITCAST_CONSTEXPR inline + +union ufloat32 +{ + constexpr ufloat32(const float& x) + : f(x) + {} + constexpr ufloat32(const int32_t& x) + : i(x) + {} + + float f; + int32_t i; +}; + +inline int32_t get_bits(const float& f) +{ + return ufloat32(f).i; +} +inline float from_bits(const int32_t& i) +{ + return ufloat32(i).f; +} +#endif + +} // namespace float_support + +template = true> +class float_t +{ + storage_t m_data = 0; + +public: + static constexpr size_t n_exponent_bits = n_exp_bits; + static constexpr size_t n_significand_bits = sizeof(storage_t) * 8 - 1 - n_exp_bits; + static constexpr int64_t exponent_bias = (1 << (n_exp_bits - 1)) - 1; + + /// \brief Construct a floating point type with the given bit + /// representation. + static constexpr float_t from_bits(storage_t bits) + { + return float_t(float_support::hidden(), bits); + } + + /// \brief Construct a float from the given sign, exponent and significand + /// bits. + static constexpr float_t from_bits(bool pm, storage_t e, storage_t s) + { + storage_t bits = pm ? 1 : 0; + + bits <<= n_exp_bits; + bits |= e; + + bits <<= n_significand_bits; + if (with_denorm || e) + bits |= s; + + return float_t(float_support::hidden(), bits); + } + + /// \brief (Hidden) Construct a float type from a given bit pattern + constexpr float_t(const float_support::hidden&, storage_t bits) + : m_data(bits) + {} + + constexpr float_t() + : m_data(0) + {} + constexpr float_t(const float_t& other) + : m_data(other.m_data) + {} + + /// \brief Cast to a different floating point representation. + template + constexpr inline + operator float_t() const + { + using other_float_t = + float_t; + + // Shortcut for types which are fundamentally similar (e.g., bf16 -> + // fp32) + if constexpr (n_exp_bits == other_n_exp_bits && sizeof(other_storage_t) >= sizeof(storage_t) && + has_nan == other_has_nan) + { + return other_float_t::from_bits(static_cast(m_data) + << (sizeof(other_storage_t) - sizeof(storage_t)) * 8); + } + + // Get initial values for the new floating point type + const bool sign_bit = m_data < 0; + int64_t new_exponent_bits = 0; + uint64_t new_significand = 0; + + if (is_nan() || is_infinity()) + { + new_exponent_bits = (1 << other_n_exp_bits) - 1; + + if (is_nan()) + { + if constexpr (other_has_infinity) + { + // Copy across the `not_quiet bit`; set the LSB. Don't + // attempt to copy across any of the rest of the payload. + new_significand = + 0x1 | (((significand() >> (n_significand_bits - 1)) & 1) << other_float_t::n_significand_bits); + } + else + { + new_significand = (1ul << other_float_t::n_significand_bits) - 1; + } + } + else if constexpr (!other_has_infinity) + { + new_significand = (1ul << other_float_t::n_significand_bits) - (other_has_nan ? 2 : 1); + } + } + else if (!is_zero()) + { + const int64_t this_exponent_bits = exponent_bits(); + { + constexpr int64_t exponent_rebias = other_float_t::exponent_bias - exponent_bias; + new_exponent_bits = std::max(this_exponent_bits + exponent_rebias, exponent_rebias + 1); + } + new_significand = this->significand() << (64 - n_significand_bits); + + // Normalise subnormals + if (this_exponent_bits == 0) + { + // Shift the most-significant 1 out of the magnitude to convert + // it to a significand. Modify the exponent accordingly. + uint8_t shift = __builtin_clzl(new_significand) + 1; + new_exponent_bits -= shift; + new_significand <<= shift; + } + + // Align the significand for the output type + uint32_t shift = 64 - other_float_t::n_significand_bits; + const bool other_is_subnormal = new_exponent_bits <= 0; + if (other_is_subnormal) + { + shift += 1 - new_exponent_bits; + new_exponent_bits = 0; + } + + const uint64_t shift_out = shift == 64 ? new_significand : new_significand & ((1ll << shift) - 1); + new_significand = shift == 64 ? 0 : new_significand >> shift; + + // Reinsert the most-significant-one if this is a subnormal in the + // output type. + new_significand |= (other_is_subnormal ? 1ll : 0) << (64 - shift); + + // Apply rounding based on the bits shifted out of the significand + const uint64_t shift_half = 1ll << (shift - 1); + if (shift_out > shift_half || (shift_out == shift_half && (new_significand & 1))) + { + new_significand += 1; + + // Handle the case that the significand overflowed due to + // rounding + constexpr uint64_t max_significand = (1ll << other_float_t::n_significand_bits) - 1; + if (new_significand > max_significand) + { + new_significand = 0; + new_exponent_bits++; + } + } + + // Saturate to infinity if the exponent is larger than can be + // represented in the output type. This can only occur if the size + // of the exponent of the new type is not greater than the exponent + // of the old type. + if constexpr (other_n_exp_bits <= n_exp_bits) + { + constexpr int64_t inf_exp_bits = (1ll << other_n_exp_bits) - 1; + if (new_exponent_bits >= inf_exp_bits) + { + new_exponent_bits = inf_exp_bits; + new_significand = + other_has_infinity ? 0 : (1ul << other_float_t::n_significand_bits) - (other_has_nan ? 2 : 1); + } + } + } + + return other_float_t::from_bits(sign_bit, new_exponent_bits, new_significand); + } + + /// \brief Convert from a 32-bit floating point value + BITCAST_CONSTEXPR + float_t(const float& f) + { + // If this format exactly represents the binary32 format then get + // the bits from the provided float; otherwise get a binary32 + // representation and then convert to this format. + if constexpr (represents_binary32()) + m_data = float_support::get_bits(f); + else + m_data = static_cast>( + static_cast>(f)) + .m_data; + } + + /// \brief Cast to a 32-bit floating point value + BITCAST_CONSTEXPR operator float() const + { + // If this format exactly represents the binary32 format then return + // a float; otherwise get a binary32 representation and then return + // a float. + if constexpr (represents_binary32()) + return float_support::from_bits(m_data); + else + return static_cast(this->operator float_t()); + } + + /// \brief Return whether this type represents the IEEE754 binary32 + /// format + constexpr static inline bool represents_binary32() + { + return std::is_same_v && n_exp_bits == 8 && has_nan && with_denorm && with_infinity; + } + + constexpr auto operator-() const + { + return from_bits(m_data ^ (1ll << (sizeof(storage_t) * 8 - 1))); + } + + constexpr bool is_subnormal() const + { + return exponent_bits() == 0 && significand() != 0; + } + + constexpr bool is_zero() const + { + return exponent_bits() == 0 && significand() == 0; + } + + constexpr bool is_nan() const + { + return has_nan && (exponent_bits() == (1ul << n_exponent_bits) - 1) && + ((with_infinity && significand()) || + (!with_infinity && significand() == (1ul << n_significand_bits) - 1)); + } + + constexpr bool is_infinity() const + { + return with_infinity && ((exponent_bits() == (1ul << n_exponent_bits) - 1) && !significand()); + } + + constexpr inline const storage_t& bits() const + { + return m_data; + } + + /// \brief Get the exponent + constexpr inline int64_t exponent() const + { + return std::max(exponent_bits(), 1ul) - exponent_bias; + } + + /// \brief Get the bits from the exponent + constexpr inline uint64_t exponent_bits() const + { + constexpr uint64_t mask = (1ul << n_exp_bits) - 1; + return (m_data >> n_significand_bits) & mask; + } + + constexpr inline uint64_t significand() const + { + return m_data & ((1ul << n_significand_bits) - 1); + } + + constexpr inline bool operator==(const float_t& other) const + { + return !is_nan() && !other.is_nan() && ((is_zero() && other.is_zero()) || bits() == other.bits()); + } + + constexpr inline float_t& operator+=(const float_t& rhs) + { + this->m_data = static_cast(static_cast(*this) + static_cast(rhs)).bits(); + return *this; + } +}; + +// This should probably be exported so we can use it elsewhere +#undef BITCAST_CONSTEXPR + +namespace float_support +{ + +// Pre-C++23 these can't be computed as constexpr, so have to hardcode them + +template +struct digits10; // floor(log10(2) * (digits - 1) +template +struct max_digits10; // ceil(log10(2) * digits + 1) +template +struct min_exponent10; // floor(log10(2) * min_exponent) +template +struct max_exponent10; // floor(log10(2) * max_exponent) + +template <> +struct digits10<8> +{ + constexpr static inline int value = 2; +}; + +template <> +struct max_digits10<8> +{ + constexpr static inline int value = 4; +}; + +template <> +struct digits10<10> +{ + constexpr static inline int value = 2; +}; + +template <> +struct max_digits10<10> +{ + constexpr static inline int value = 5; +}; + +template <> +struct digits10<24> +{ + constexpr static inline int value = 6; +}; + +template <> +struct max_digits10<24> +{ + constexpr static inline int value = 9; +}; + +template <> +struct min_exponent10<-13> +{ + constexpr static inline int value = -3; +}; + +template <> +struct max_exponent10<16> +{ + constexpr static inline int value = 4; +}; + +template <> +struct min_exponent10<-125> +{ + constexpr static inline int value = -37; +}; + +template <> +struct max_exponent10<128> +{ + constexpr static inline int value = 38; +}; + +template +inline constexpr int digits10_v = digits10::value; +template +inline constexpr int max_digits10_v = max_digits10::value; + +template +inline constexpr int min_exponent10_v = min_exponent10::value; + +template +inline constexpr int max_exponent10_v = max_exponent10::value; + +} // namespace float_support + +} // namespace tosa + +namespace std +{ + +template +struct is_floating_point> + : std::integral_constant +{}; + +template +class numeric_limits> +{ + using this_float_t = tosa::float_t; + +public: + static constexpr bool is_specialized = true; + + static constexpr auto min() noexcept + { + return this_float_t::from_bits(false, 1, 0); + } + + static constexpr auto max() noexcept + { + return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 2, + (1 << this_float_t::n_significand_bits) - 1); + } + + static constexpr auto lowest() noexcept + { + return -max(); + } + + static constexpr int digits = this_float_t::n_significand_bits + 1; + static constexpr int digits10 = tosa::float_support::digits10_v; + static constexpr int max_digits10 = tosa::float_support::max_digits10_v; + + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr int radix = 2; + + static constexpr auto epsilon() noexcept + { + return this_float_t::from_bits(false, this_float_t::exponent_bias - this_float_t::n_significand_bits, 0); + } + + static constexpr auto round_error() noexcept + { + return this_float_t::from_bits(0, this_float_t::exponent_bias - 1, 0); + } + + static constexpr int min_exponent = (1 - this_float_t::exponent_bias) + 1; + static constexpr int min_exponent10 = tosa::float_support::min_exponent10_v; + static constexpr int max_exponent = this_float_t::exponent_bias + 1; + static constexpr int max_exponent10 = tosa::float_support::max_exponent10_v; + + static constexpr bool has_infinity = with_inf; + static constexpr bool has_quiet_NaN = has_nan; + static constexpr bool has_signaling_NaN = true; + static constexpr float_denorm_style has_denorm = with_denorm ? denorm_present : denorm_absent; + static constexpr bool has_denorm_loss = false; + + static constexpr auto infinity() noexcept + { + if constexpr (with_inf) + { + return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 1, 0); + } + else + { + return this_float_t::from_bits(false, 0, 0); + } + } + + static constexpr auto quiet_NaN() noexcept + { + return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 1, + 1 << (this_float_t::n_significand_bits - 1) | 1); + } + + static constexpr auto signaling_NaN() noexcept + { + return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 1, 1); + } + + static constexpr auto denorm_min() noexcept + { + return this_float_t::from_bits(false, 0, 1); + } + + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = false; + static constexpr bool is_modulo = false; + + static constexpr bool traps = false; + static constexpr bool tinyness_before = false; + static constexpr float_round_style round_style = round_to_nearest; +}; + +} // namespace std + +#endif // TOSA_FLOAT_UTILS_H_ diff --git a/include/tosa_generated.h b/include/tosa_generated.h index 20f6993..0798256 100644 --- a/include/tosa_generated.h +++ b/include/tosa_generated.h @@ -1008,15 +1008,20 @@ inline ::flatbuffers::Offset CreateTransposeConvAttribut struct PadAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { typedef PadAttributeBuilder Builder; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { - VT_PAD_CONST = 4 + VT_PAD_CONST = 4, + VT_TYPE = 6 }; const ::flatbuffers::Vector *pad_const() const { return GetPointer *>(VT_PAD_CONST); } + tosa::DType type() const { + return static_cast(GetField(VT_TYPE, 0)); + } bool Verify(::flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_PAD_CONST) && verifier.VerifyVector(pad_const()) && + VerifyField(verifier, VT_TYPE, 4) && verifier.EndTable(); } }; @@ -1028,6 +1033,9 @@ struct PadAttributeBuilder { void add_pad_const(::flatbuffers::Offset<::flatbuffers::Vector> pad_const) { fbb_.AddOffset(PadAttribute::VT_PAD_CONST, pad_const); } + void add_type(tosa::DType type) { + fbb_.AddElement(PadAttribute::VT_TYPE, static_cast(type), 0); + } explicit PadAttributeBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -1041,20 +1049,24 @@ struct PadAttributeBuilder { inline ::flatbuffers::Offset CreatePadAttribute( ::flatbuffers::FlatBufferBuilder &_fbb, - ::flatbuffers::Offset<::flatbuffers::Vector> pad_const = 0) { + ::flatbuffers::Offset<::flatbuffers::Vector> pad_const = 0, + tosa::DType type = tosa::DType_UNKNOWN) { PadAttributeBuilder builder_(_fbb); + builder_.add_type(type); builder_.add_pad_const(pad_const); return builder_.Finish(); } inline ::flatbuffers::Offset CreatePadAttributeDirect( ::flatbuffers::FlatBufferBuilder &_fbb, - const std::vector *pad_const = nullptr) { + const std::vector *pad_const = nullptr, + tosa::DType type = tosa::DType_UNKNOWN) { if (pad_const) { _fbb.ForceVectorAlignment(pad_const->size(), sizeof(uint8_t), 8); } auto pad_const__ = pad_const ? _fbb.CreateVector(*pad_const) : 0; return tosa::CreatePadAttribute( _fbb, - pad_const__); + pad_const__, + type); } struct AxisAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { @@ -1193,7 +1205,8 @@ struct ClampAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { typedef ClampAttributeBuilder Builder; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_MIN_VAL = 4, - VT_MAX_VAL = 6 + VT_MAX_VAL = 6, + VT_TYPE = 8 }; const ::flatbuffers::Vector *min_val() const { return GetPointer *>(VT_MIN_VAL); @@ -1201,12 +1214,16 @@ struct ClampAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { const ::flatbuffers::Vector *max_val() const { return GetPointer *>(VT_MAX_VAL); } + tosa::DType type() const { + return static_cast(GetField(VT_TYPE, 0)); + } bool Verify(::flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_MIN_VAL) && verifier.VerifyVector(min_val()) && VerifyOffset(verifier, VT_MAX_VAL) && verifier.VerifyVector(max_val()) && + VerifyField(verifier, VT_TYPE, 4) && verifier.EndTable(); } }; @@ -1221,6 +1238,9 @@ struct ClampAttributeBuilder { void add_max_val(::flatbuffers::Offset<::flatbuffers::Vector> max_val) { fbb_.AddOffset(ClampAttribute::VT_MAX_VAL, max_val); } + void add_type(tosa::DType type) { + fbb_.AddElement(ClampAttribute::VT_TYPE, static_cast(type), 0); + } explicit ClampAttributeBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -1235,8 +1255,10 @@ struct ClampAttributeBuilder { inline ::flatbuffers::Offset CreateClampAttribute( ::flatbuffers::FlatBufferBuilder &_fbb, ::flatbuffers::Offset<::flatbuffers::Vector> min_val = 0, - ::flatbuffers::Offset<::flatbuffers::Vector> max_val = 0) { + ::flatbuffers::Offset<::flatbuffers::Vector> max_val = 0, + tosa::DType type = tosa::DType_UNKNOWN) { ClampAttributeBuilder builder_(_fbb); + builder_.add_type(type); builder_.add_max_val(max_val); builder_.add_min_val(min_val); return builder_.Finish(); @@ -1245,7 +1267,8 @@ inline ::flatbuffers::Offset CreateClampAttribute( inline ::flatbuffers::Offset CreateClampAttributeDirect( ::flatbuffers::FlatBufferBuilder &_fbb, const std::vector *min_val = nullptr, - const std::vector *max_val = nullptr) { + const std::vector *max_val = nullptr, + tosa::DType type = tosa::DType_UNKNOWN) { if (min_val) { _fbb.ForceVectorAlignment(min_val->size(), sizeof(uint8_t), 8); } auto min_val__ = min_val ? _fbb.CreateVector(*min_val) : 0; if (max_val) { _fbb.ForceVectorAlignment(max_val->size(), sizeof(uint8_t), 8); } @@ -1253,7 +1276,8 @@ inline ::flatbuffers::Offset CreateClampAttributeDirect( return tosa::CreateClampAttribute( _fbb, min_val__, - max_val__); + max_val__, + type); } struct RescaleAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h index 5c53f57..f5f9e58 100644 --- a/include/tosa_serialization_handler.h +++ b/include/tosa_serialization_handler.h @@ -18,6 +18,7 @@ #include "attribute.h" #include "flatbuffers/idl.h" #include "flatbuffers/util.h" +#include "float_utils.h" #include "numpy_utils.h" #include "tosa_generated.h" #include @@ -411,6 +412,9 @@ public: tosa_err_t LoadFileSchema(const char* schema_filename); // data format conversion. little-endian. + static tosa_err_t ConvertBF16toU8(const std::vector& in, std::vector& out); + static tosa_err_t ConvertFP8E4M3toU8(const std::vector& in, std::vector& out); + static tosa_err_t ConvertFP8E5M2toU8(const std::vector& in, std::vector& out); static tosa_err_t ConvertF16toU8(const std::vector& in, std::vector& out); static tosa_err_t ConvertF32toU8(const std::vector& in, std::vector& out); static tosa_err_t ConvertI64toU8(const std::vector& in, std::vector& out); @@ -421,6 +425,9 @@ public: static tosa_err_t ConvertI4toU8(const std::vector& in, std::vector& out); static tosa_err_t ConvertBooltoU8(const std::vector& in, std::vector& out); + static tosa_err_t ConvertU8toBF16(const std::vector& in, uint32_t out_size, std::vector& out); + static tosa_err_t ConvertU8toFP8E4M3(const std::vector& in, uint32_t out_size, std::vector& out); + static tosa_err_t ConvertU8toFP8E5M2(const std::vector& in, uint32_t out_size, std::vector& out); static tosa_err_t ConvertU8toF16(const std::vector& in, uint32_t out_size, std::vector& out); static tosa_err_t ConvertU8toF32(const std::vector& in, uint32_t out_size, std::vector& out); -- cgit v1.2.1