From ce911a2f1d9cd678fb9fe82a40c86ad0c6772f5a Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Thu, 21 Mar 2024 17:01:14 +0000 Subject: Add conversions of U8 to/from BF16 and FP8 Adds type to PadAttribute and ClampAttribute so their pad_const and max_val/min_val can be deserialized according to type Adds conversion functions of U8 arrays to/from BF16/FP8 values Also, refactor and expose TosaSerializer.convertDataToUint8Vec for converting dtype/data to uint8 list for serialization And modify convertDataToUint8Vec to serialize bf16 values into 2 bytes each, and serialize fp8 values into single bytes each. Signed-off-by: Tai Ly Change-Id: I05659e8187c76d359f1cc9f71c8c23cafd0e877f --- CMakeLists.txt | 5 +- include/attribute.def | 10 +- include/float_utils.h | 533 +++++++++++++++++++++++++++++++++++ include/tosa_generated.h | 40 ++- include/tosa_serialization_handler.h | 7 + python/serializer/tosa_serializer.py | 193 +++++++------ python/tosa/ClampAttribute.py | 15 +- python/tosa/PadAttribute.py | 15 +- schema/tosa.fbs | 2 + src/tosa_serialization_handler.cpp | 120 ++++++++ 10 files changed, 841 insertions(+), 99 deletions(-) create mode 100644 include/float_utils.h diff --git a/CMakeLists.txt b/CMakeLists.txt index ac34b75..5f4f851 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -19,8 +19,8 @@ cmake_minimum_required(VERSION 3.13.4) project(TosaSerialization) -set(CMAKE_CXX_STANDARD 14 CACHE STRING "C++ standard to conform to") -set(CMAKE_CXX_STANDARD_REQUIRED YES) +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_VERBOSE_MAKEFILE ON) @@ -76,6 +76,7 @@ set(public_headers) list(APPEND public_headers include/attribute.h include/attribute.def + include/float_utils.h include/numpy_utils.h include/tosa_generated.h include/tosa_serialization_handler.h diff --git a/include/attribute.def b/include/attribute.def index 723543e..30b432d 100644 --- a/include/attribute.def +++ b/include/attribute.def @@ -52,8 +52,9 @@ DEF_ATTRIBUTE(TransposeConv, 7, bool, S, local_bound, DType, S, acc_type) -DEF_ATTRIBUTE(Pad, 1, - uint8_t, V, pad_const) +DEF_ATTRIBUTE(Pad, 2, + uint8_t, V, pad_const, + DType, S, type) DEF_ATTRIBUTE(Axis, 1, int32_t, S, axis) @@ -64,9 +65,10 @@ DEF_ATTRIBUTE(Resize, 4, int16_t, V, border, ResizeMode, S, mode) -DEF_ATTRIBUTE(Clamp, 2, +DEF_ATTRIBUTE(Clamp, 3, uint8_t, V, min_val, - uint8_t, V, max_val) + uint8_t, V, max_val, + DType, S, type) DEF_ATTRIBUTE(Rescale, 7, int32_t, S, input_zp, diff --git a/include/float_utils.h b/include/float_utils.h new file mode 100644 index 0000000..831ad74 --- /dev/null +++ b/include/float_utils.h @@ -0,0 +1,533 @@ +// Copyright (c) 2024, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TOSA_FLOAT_UTILS_H_ +#define TOSA_FLOAT_UTILS_H_ + +#include +#include +#include +#include +#if defined(__cpp_lib_bit_cast) +#include +#endif // defined(__cpp_lib_bit_cast) + +namespace tosa +{ + +namespace float_support +{ + +struct hidden +{}; + +#if defined(__cpp_lib_bit_cast) +#define BITCAST_CONSTEXPR constexpr inline + +constexpr inline int32_t get_bits(const float& f) +{ + return std::bit_cast(f); +} +constexpr inline float from_bits(const int32_t& i) +{ + return std::bit_cast(i); +} + +#else +#define BITCAST_CONSTEXPR inline + +union ufloat32 +{ + constexpr ufloat32(const float& x) + : f(x) + {} + constexpr ufloat32(const int32_t& x) + : i(x) + {} + + float f; + int32_t i; +}; + +inline int32_t get_bits(const float& f) +{ + return ufloat32(f).i; +} +inline float from_bits(const int32_t& i) +{ + return ufloat32(i).f; +} +#endif + +} // namespace float_support + +template = true> +class float_t +{ + storage_t m_data = 0; + +public: + static constexpr size_t n_exponent_bits = n_exp_bits; + static constexpr size_t n_significand_bits = sizeof(storage_t) * 8 - 1 - n_exp_bits; + static constexpr int64_t exponent_bias = (1 << (n_exp_bits - 1)) - 1; + + /// \brief Construct a floating point type with the given bit + /// representation. + static constexpr float_t from_bits(storage_t bits) + { + return float_t(float_support::hidden(), bits); + } + + /// \brief Construct a float from the given sign, exponent and significand + /// bits. + static constexpr float_t from_bits(bool pm, storage_t e, storage_t s) + { + storage_t bits = pm ? 1 : 0; + + bits <<= n_exp_bits; + bits |= e; + + bits <<= n_significand_bits; + if (with_denorm || e) + bits |= s; + + return float_t(float_support::hidden(), bits); + } + + /// \brief (Hidden) Construct a float type from a given bit pattern + constexpr float_t(const float_support::hidden&, storage_t bits) + : m_data(bits) + {} + + constexpr float_t() + : m_data(0) + {} + constexpr float_t(const float_t& other) + : m_data(other.m_data) + {} + + /// \brief Cast to a different floating point representation. + template + constexpr inline + operator float_t() const + { + using other_float_t = + float_t; + + // Shortcut for types which are fundamentally similar (e.g., bf16 -> + // fp32) + if constexpr (n_exp_bits == other_n_exp_bits && sizeof(other_storage_t) >= sizeof(storage_t) && + has_nan == other_has_nan) + { + return other_float_t::from_bits(static_cast(m_data) + << (sizeof(other_storage_t) - sizeof(storage_t)) * 8); + } + + // Get initial values for the new floating point type + const bool sign_bit = m_data < 0; + int64_t new_exponent_bits = 0; + uint64_t new_significand = 0; + + if (is_nan() || is_infinity()) + { + new_exponent_bits = (1 << other_n_exp_bits) - 1; + + if (is_nan()) + { + if constexpr (other_has_infinity) + { + // Copy across the `not_quiet bit`; set the LSB. Don't + // attempt to copy across any of the rest of the payload. + new_significand = + 0x1 | (((significand() >> (n_significand_bits - 1)) & 1) << other_float_t::n_significand_bits); + } + else + { + new_significand = (1ul << other_float_t::n_significand_bits) - 1; + } + } + else if constexpr (!other_has_infinity) + { + new_significand = (1ul << other_float_t::n_significand_bits) - (other_has_nan ? 2 : 1); + } + } + else if (!is_zero()) + { + const int64_t this_exponent_bits = exponent_bits(); + { + constexpr int64_t exponent_rebias = other_float_t::exponent_bias - exponent_bias; + new_exponent_bits = std::max(this_exponent_bits + exponent_rebias, exponent_rebias + 1); + } + new_significand = this->significand() << (64 - n_significand_bits); + + // Normalise subnormals + if (this_exponent_bits == 0) + { + // Shift the most-significant 1 out of the magnitude to convert + // it to a significand. Modify the exponent accordingly. + uint8_t shift = __builtin_clzl(new_significand) + 1; + new_exponent_bits -= shift; + new_significand <<= shift; + } + + // Align the significand for the output type + uint32_t shift = 64 - other_float_t::n_significand_bits; + const bool other_is_subnormal = new_exponent_bits <= 0; + if (other_is_subnormal) + { + shift += 1 - new_exponent_bits; + new_exponent_bits = 0; + } + + const uint64_t shift_out = shift == 64 ? new_significand : new_significand & ((1ll << shift) - 1); + new_significand = shift == 64 ? 0 : new_significand >> shift; + + // Reinsert the most-significant-one if this is a subnormal in the + // output type. + new_significand |= (other_is_subnormal ? 1ll : 0) << (64 - shift); + + // Apply rounding based on the bits shifted out of the significand + const uint64_t shift_half = 1ll << (shift - 1); + if (shift_out > shift_half || (shift_out == shift_half && (new_significand & 1))) + { + new_significand += 1; + + // Handle the case that the significand overflowed due to + // rounding + constexpr uint64_t max_significand = (1ll << other_float_t::n_significand_bits) - 1; + if (new_significand > max_significand) + { + new_significand = 0; + new_exponent_bits++; + } + } + + // Saturate to infinity if the exponent is larger than can be + // represented in the output type. This can only occur if the size + // of the exponent of the new type is not greater than the exponent + // of the old type. + if constexpr (other_n_exp_bits <= n_exp_bits) + { + constexpr int64_t inf_exp_bits = (1ll << other_n_exp_bits) - 1; + if (new_exponent_bits >= inf_exp_bits) + { + new_exponent_bits = inf_exp_bits; + new_significand = + other_has_infinity ? 0 : (1ul << other_float_t::n_significand_bits) - (other_has_nan ? 2 : 1); + } + } + } + + return other_float_t::from_bits(sign_bit, new_exponent_bits, new_significand); + } + + /// \brief Convert from a 32-bit floating point value + BITCAST_CONSTEXPR + float_t(const float& f) + { + // If this format exactly represents the binary32 format then get + // the bits from the provided float; otherwise get a binary32 + // representation and then convert to this format. + if constexpr (represents_binary32()) + m_data = float_support::get_bits(f); + else + m_data = static_cast>( + static_cast>(f)) + .m_data; + } + + /// \brief Cast to a 32-bit floating point value + BITCAST_CONSTEXPR operator float() const + { + // If this format exactly represents the binary32 format then return + // a float; otherwise get a binary32 representation and then return + // a float. + if constexpr (represents_binary32()) + return float_support::from_bits(m_data); + else + return static_cast(this->operator float_t()); + } + + /// \brief Return whether this type represents the IEEE754 binary32 + /// format + constexpr static inline bool represents_binary32() + { + return std::is_same_v && n_exp_bits == 8 && has_nan && with_denorm && with_infinity; + } + + constexpr auto operator-() const + { + return from_bits(m_data ^ (1ll << (sizeof(storage_t) * 8 - 1))); + } + + constexpr bool is_subnormal() const + { + return exponent_bits() == 0 && significand() != 0; + } + + constexpr bool is_zero() const + { + return exponent_bits() == 0 && significand() == 0; + } + + constexpr bool is_nan() const + { + return has_nan && (exponent_bits() == (1ul << n_exponent_bits) - 1) && + ((with_infinity && significand()) || + (!with_infinity && significand() == (1ul << n_significand_bits) - 1)); + } + + constexpr bool is_infinity() const + { + return with_infinity && ((exponent_bits() == (1ul << n_exponent_bits) - 1) && !significand()); + } + + constexpr inline const storage_t& bits() const + { + return m_data; + } + + /// \brief Get the exponent + constexpr inline int64_t exponent() const + { + return std::max(exponent_bits(), 1ul) - exponent_bias; + } + + /// \brief Get the bits from the exponent + constexpr inline uint64_t exponent_bits() const + { + constexpr uint64_t mask = (1ul << n_exp_bits) - 1; + return (m_data >> n_significand_bits) & mask; + } + + constexpr inline uint64_t significand() const + { + return m_data & ((1ul << n_significand_bits) - 1); + } + + constexpr inline bool operator==(const float_t& other) const + { + return !is_nan() && !other.is_nan() && ((is_zero() && other.is_zero()) || bits() == other.bits()); + } + + constexpr inline float_t& operator+=(const float_t& rhs) + { + this->m_data = static_cast(static_cast(*this) + static_cast(rhs)).bits(); + return *this; + } +}; + +// This should probably be exported so we can use it elsewhere +#undef BITCAST_CONSTEXPR + +namespace float_support +{ + +// Pre-C++23 these can't be computed as constexpr, so have to hardcode them + +template +struct digits10; // floor(log10(2) * (digits - 1) +template +struct max_digits10; // ceil(log10(2) * digits + 1) +template +struct min_exponent10; // floor(log10(2) * min_exponent) +template +struct max_exponent10; // floor(log10(2) * max_exponent) + +template <> +struct digits10<8> +{ + constexpr static inline int value = 2; +}; + +template <> +struct max_digits10<8> +{ + constexpr static inline int value = 4; +}; + +template <> +struct digits10<10> +{ + constexpr static inline int value = 2; +}; + +template <> +struct max_digits10<10> +{ + constexpr static inline int value = 5; +}; + +template <> +struct digits10<24> +{ + constexpr static inline int value = 6; +}; + +template <> +struct max_digits10<24> +{ + constexpr static inline int value = 9; +}; + +template <> +struct min_exponent10<-13> +{ + constexpr static inline int value = -3; +}; + +template <> +struct max_exponent10<16> +{ + constexpr static inline int value = 4; +}; + +template <> +struct min_exponent10<-125> +{ + constexpr static inline int value = -37; +}; + +template <> +struct max_exponent10<128> +{ + constexpr static inline int value = 38; +}; + +template +inline constexpr int digits10_v = digits10::value; +template +inline constexpr int max_digits10_v = max_digits10::value; + +template +inline constexpr int min_exponent10_v = min_exponent10::value; + +template +inline constexpr int max_exponent10_v = max_exponent10::value; + +} // namespace float_support + +} // namespace tosa + +namespace std +{ + +template +struct is_floating_point> + : std::integral_constant +{}; + +template +class numeric_limits> +{ + using this_float_t = tosa::float_t; + +public: + static constexpr bool is_specialized = true; + + static constexpr auto min() noexcept + { + return this_float_t::from_bits(false, 1, 0); + } + + static constexpr auto max() noexcept + { + return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 2, + (1 << this_float_t::n_significand_bits) - 1); + } + + static constexpr auto lowest() noexcept + { + return -max(); + } + + static constexpr int digits = this_float_t::n_significand_bits + 1; + static constexpr int digits10 = tosa::float_support::digits10_v; + static constexpr int max_digits10 = tosa::float_support::max_digits10_v; + + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr int radix = 2; + + static constexpr auto epsilon() noexcept + { + return this_float_t::from_bits(false, this_float_t::exponent_bias - this_float_t::n_significand_bits, 0); + } + + static constexpr auto round_error() noexcept + { + return this_float_t::from_bits(0, this_float_t::exponent_bias - 1, 0); + } + + static constexpr int min_exponent = (1 - this_float_t::exponent_bias) + 1; + static constexpr int min_exponent10 = tosa::float_support::min_exponent10_v; + static constexpr int max_exponent = this_float_t::exponent_bias + 1; + static constexpr int max_exponent10 = tosa::float_support::max_exponent10_v; + + static constexpr bool has_infinity = with_inf; + static constexpr bool has_quiet_NaN = has_nan; + static constexpr bool has_signaling_NaN = true; + static constexpr float_denorm_style has_denorm = with_denorm ? denorm_present : denorm_absent; + static constexpr bool has_denorm_loss = false; + + static constexpr auto infinity() noexcept + { + if constexpr (with_inf) + { + return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 1, 0); + } + else + { + return this_float_t::from_bits(false, 0, 0); + } + } + + static constexpr auto quiet_NaN() noexcept + { + return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 1, + 1 << (this_float_t::n_significand_bits - 1) | 1); + } + + static constexpr auto signaling_NaN() noexcept + { + return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 1, 1); + } + + static constexpr auto denorm_min() noexcept + { + return this_float_t::from_bits(false, 0, 1); + } + + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = false; + static constexpr bool is_modulo = false; + + static constexpr bool traps = false; + static constexpr bool tinyness_before = false; + static constexpr float_round_style round_style = round_to_nearest; +}; + +} // namespace std + +#endif // TOSA_FLOAT_UTILS_H_ diff --git a/include/tosa_generated.h b/include/tosa_generated.h index 20f6993..0798256 100644 --- a/include/tosa_generated.h +++ b/include/tosa_generated.h @@ -1008,15 +1008,20 @@ inline ::flatbuffers::Offset CreateTransposeConvAttribut struct PadAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { typedef PadAttributeBuilder Builder; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { - VT_PAD_CONST = 4 + VT_PAD_CONST = 4, + VT_TYPE = 6 }; const ::flatbuffers::Vector *pad_const() const { return GetPointer *>(VT_PAD_CONST); } + tosa::DType type() const { + return static_cast(GetField(VT_TYPE, 0)); + } bool Verify(::flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_PAD_CONST) && verifier.VerifyVector(pad_const()) && + VerifyField(verifier, VT_TYPE, 4) && verifier.EndTable(); } }; @@ -1028,6 +1033,9 @@ struct PadAttributeBuilder { void add_pad_const(::flatbuffers::Offset<::flatbuffers::Vector> pad_const) { fbb_.AddOffset(PadAttribute::VT_PAD_CONST, pad_const); } + void add_type(tosa::DType type) { + fbb_.AddElement(PadAttribute::VT_TYPE, static_cast(type), 0); + } explicit PadAttributeBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -1041,20 +1049,24 @@ struct PadAttributeBuilder { inline ::flatbuffers::Offset CreatePadAttribute( ::flatbuffers::FlatBufferBuilder &_fbb, - ::flatbuffers::Offset<::flatbuffers::Vector> pad_const = 0) { + ::flatbuffers::Offset<::flatbuffers::Vector> pad_const = 0, + tosa::DType type = tosa::DType_UNKNOWN) { PadAttributeBuilder builder_(_fbb); + builder_.add_type(type); builder_.add_pad_const(pad_const); return builder_.Finish(); } inline ::flatbuffers::Offset CreatePadAttributeDirect( ::flatbuffers::FlatBufferBuilder &_fbb, - const std::vector *pad_const = nullptr) { + const std::vector *pad_const = nullptr, + tosa::DType type = tosa::DType_UNKNOWN) { if (pad_const) { _fbb.ForceVectorAlignment(pad_const->size(), sizeof(uint8_t), 8); } auto pad_const__ = pad_const ? _fbb.CreateVector(*pad_const) : 0; return tosa::CreatePadAttribute( _fbb, - pad_const__); + pad_const__, + type); } struct AxisAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { @@ -1193,7 +1205,8 @@ struct ClampAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { typedef ClampAttributeBuilder Builder; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_MIN_VAL = 4, - VT_MAX_VAL = 6 + VT_MAX_VAL = 6, + VT_TYPE = 8 }; const ::flatbuffers::Vector *min_val() const { return GetPointer *>(VT_MIN_VAL); @@ -1201,12 +1214,16 @@ struct ClampAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { const ::flatbuffers::Vector *max_val() const { return GetPointer *>(VT_MAX_VAL); } + tosa::DType type() const { + return static_cast(GetField(VT_TYPE, 0)); + } bool Verify(::flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_MIN_VAL) && verifier.VerifyVector(min_val()) && VerifyOffset(verifier, VT_MAX_VAL) && verifier.VerifyVector(max_val()) && + VerifyField(verifier, VT_TYPE, 4) && verifier.EndTable(); } }; @@ -1221,6 +1238,9 @@ struct ClampAttributeBuilder { void add_max_val(::flatbuffers::Offset<::flatbuffers::Vector> max_val) { fbb_.AddOffset(ClampAttribute::VT_MAX_VAL, max_val); } + void add_type(tosa::DType type) { + fbb_.AddElement(ClampAttribute::VT_TYPE, static_cast(type), 0); + } explicit ClampAttributeBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -1235,8 +1255,10 @@ struct ClampAttributeBuilder { inline ::flatbuffers::Offset CreateClampAttribute( ::flatbuffers::FlatBufferBuilder &_fbb, ::flatbuffers::Offset<::flatbuffers::Vector> min_val = 0, - ::flatbuffers::Offset<::flatbuffers::Vector> max_val = 0) { + ::flatbuffers::Offset<::flatbuffers::Vector> max_val = 0, + tosa::DType type = tosa::DType_UNKNOWN) { ClampAttributeBuilder builder_(_fbb); + builder_.add_type(type); builder_.add_max_val(max_val); builder_.add_min_val(min_val); return builder_.Finish(); @@ -1245,7 +1267,8 @@ inline ::flatbuffers::Offset CreateClampAttribute( inline ::flatbuffers::Offset CreateClampAttributeDirect( ::flatbuffers::FlatBufferBuilder &_fbb, const std::vector *min_val = nullptr, - const std::vector *max_val = nullptr) { + const std::vector *max_val = nullptr, + tosa::DType type = tosa::DType_UNKNOWN) { if (min_val) { _fbb.ForceVectorAlignment(min_val->size(), sizeof(uint8_t), 8); } auto min_val__ = min_val ? _fbb.CreateVector(*min_val) : 0; if (max_val) { _fbb.ForceVectorAlignment(max_val->size(), sizeof(uint8_t), 8); } @@ -1253,7 +1276,8 @@ inline ::flatbuffers::Offset CreateClampAttributeDirect( return tosa::CreateClampAttribute( _fbb, min_val__, - max_val__); + max_val__, + type); } struct RescaleAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h index 5c53f57..f5f9e58 100644 --- a/include/tosa_serialization_handler.h +++ b/include/tosa_serialization_handler.h @@ -18,6 +18,7 @@ #include "attribute.h" #include "flatbuffers/idl.h" #include "flatbuffers/util.h" +#include "float_utils.h" #include "numpy_utils.h" #include "tosa_generated.h" #include @@ -411,6 +412,9 @@ public: tosa_err_t LoadFileSchema(const char* schema_filename); // data format conversion. little-endian. + static tosa_err_t ConvertBF16toU8(const std::vector& in, std::vector& out); + static tosa_err_t ConvertFP8E4M3toU8(const std::vector& in, std::vector& out); + static tosa_err_t ConvertFP8E5M2toU8(const std::vector& in, std::vector& out); static tosa_err_t ConvertF16toU8(const std::vector& in, std::vector& out); static tosa_err_t ConvertF32toU8(const std::vector& in, std::vector& out); static tosa_err_t ConvertI64toU8(const std::vector& in, std::vector& out); @@ -421,6 +425,9 @@ public: static tosa_err_t ConvertI4toU8(const std::vector& in, std::vector& out); static tosa_err_t ConvertBooltoU8(const std::vector& in, std::vector& out); + static tosa_err_t ConvertU8toBF16(const std::vector& in, uint32_t out_size, std::vector& out); + static tosa_err_t ConvertU8toFP8E4M3(const std::vector& in, uint32_t out_size, std::vector& out); + static tosa_err_t ConvertU8toFP8E5M2(const std::vector& in, uint32_t out_size, std::vector& out); static tosa_err_t ConvertU8toF16(const std::vector& in, uint32_t out_size, std::vector& out); static tosa_err_t ConvertU8toF32(const std::vector& in, uint32_t out_size, std::vector& out); diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py index e6ab3d0..298907e 100644 --- a/python/serializer/tosa_serializer.py +++ b/python/serializer/tosa_serializer.py @@ -17,6 +17,7 @@ import serializer.tosa_serializer as ts import json import flatbuffers import numpy as np +import struct from enum import IntEnum, unique from tosa import ( TosaGraph, @@ -204,7 +205,7 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.bools.append((a.AddLocalBound, local_bound)) self.ints.append((a.AddAccType, acc_type)) - def PadAttribute(self, serializer_builder, pad_const_val_as_bytes): + def PadAttribute(self, serializer_builder, pad_const_val_as_bytes, dtype): from tosa import PadAttribute as a, Attribute self.utype = Attribute.Attribute().PadAttribute @@ -216,6 +217,7 @@ class TosaSerializerAttribute(TosaSerializerUnion): ) self.floats.append((a.AddPadConst, serialized_pad_const_val)) + self.ints.append((a.AddType, dtype)) def AxisAttribute(self, axis): from tosa import AxisAttribute as a, Attribute @@ -236,7 +238,9 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.int16vecs.append((a.AddBorder, border)) self.ints.append((a.AddMode, mode)) - def ClampAttribute(self, serializer_builder, min_val_as_bytes, max_val_as_bytes): + def ClampAttribute( + self, serializer_builder, min_val_as_bytes, max_val_as_bytes, dtype + ): from tosa import ClampAttribute as a, Attribute self.utype = Attribute.Attribute().ClampAttribute @@ -252,6 +256,7 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.floats.append((a.AddMinVal, serialized_min_val)) self.floats.append((a.AddMaxVal, serialized_max_val)) + self.ints.append((a.AddType, dtype)) def RescaleAttribute( self, @@ -439,87 +444,7 @@ class TosaSerializerTensor: fb_name = builder.CreateString(self.name) fb_shapes = TosaSerializer.serializeInt32Vec(builder, self.shape) if self.data: - u8_data = list() - # little endianess - if self.dtype == DType.BOOL: - for val in self.data: - val_u8 = np.uint8(val) - u8_data.append(val_u8) - elif self.dtype == DType.INT4: - in_size = len(self.data) - out_size = (in_size + 1) // 2 - for i in range(out_size): - val_0 = self.data[2 * i] - if (2 * i + 1) < in_size: - val_1 = self.data[2 * i + 1] - else: - val_1 = 0 - val_i8 = (val_0 & 0xF) | ((val_1 & 0xF) << 4) - val_u8 = np.uint8(val_i8) - u8_data.append(val_u8) - elif self.dtype == DType.INT8: - for val in self.data: - val_u8 = np.array(val).astype(dtype=np.uint8) - u8_data.append(val_u8) - elif self.dtype == DType.INT16: - for val in self.data: - val_u16 = np.array(val).astype(dtype=np.uint16) - b0 = val_u16 & ByteMask - b1 = (val_u16 >> np.uint16(8)) & ByteMask - u8_data.extend([b0, b1]) - elif self.dtype == DType.INT32: - for val in self.data: - val_u32 = np.array(val).astype(dtype=np.uint32) - b0 = val_u32 & ByteMask - b1 = (val_u32 >> np.uint32(8)) & ByteMask - b2 = (val_u32 >> np.uint32(16)) & ByteMask - b3 = (val_u32 >> np.uint32(24)) & ByteMask - u8_data.extend([b0, b1, b2, b3]) - elif self.dtype == DType.INT48: - for val in self.data: - val_u64 = np.uint64(val) - b0 = val_u64 & ByteMask - b1 = (val_u64 >> np.uint64(8)) & ByteMask - b2 = (val_u64 >> np.uint64(16)) & ByteMask - b3 = (val_u64 >> np.uint64(24)) & ByteMask - b4 = (val_u64 >> np.uint64(32)) & ByteMask - b5 = (val_u64 >> np.uint64(40)) & ByteMask - u8_data.extend([b0, b1, b2, b3, b4, b5]) - elif self.dtype == DType.SHAPE: - for val in self.data: - val_u64 = np.uint64(val) - b0 = val_u64 & ByteMask - b1 = (val_u64 >> np.uint64(8)) & ByteMask - b2 = (val_u64 >> np.uint64(16)) & ByteMask - b3 = (val_u64 >> np.uint64(24)) & ByteMask - b4 = (val_u64 >> np.uint64(32)) & ByteMask - b5 = (val_u64 >> np.uint64(40)) & ByteMask - b6 = (val_u64 >> np.uint64(48)) & ByteMask - b7 = (val_u64 >> np.uint64(56)) & ByteMask - u8_data.extend([b0, b1, b2, b3, b4, b5, b6, b7]) - elif self.dtype == DType.FP16: - np_arr = np.array(self.data, dtype=np.float16) - u8_data.extend(np_arr.view(np.uint8)) - elif ( - self.dtype == DType.FP32 - or self.dtype == DType.BF16 - or self.dtype == DType.FP8E4M3 - or self.dtype == DType.FP8E5M2 - ): - # for val in self.data: - # b = struct.pack("!f", val) - # u8_data.extend([b[3], b[2], b[1], b[0]]) - np_arr = np.array(self.data, dtype=np.float32) - u8_data.extend(np_arr.view(np.uint8)) - elif self.dtype == TosaDType.DType: - # Serialize DType enum data as uint8 bytes - for val in self.data: - np_arr = np.array(self.data, dtype=np.uint32) - u8_data.extend(np_arr.view(np.uint8)) - else: - raise Exception( - "unsupported data type {}".format(DTypeNames[self.dtype]) - ) + u8_data = TosaSerializer.convertDataToUint8Vec(self.dtype, self.data) fb_data = TosaSerializer.serializeUint8Vec(builder, u8_data) TosaTensor.Start(builder) @@ -958,3 +883,105 @@ class TosaSerializer: return val else: return [val] + + @staticmethod + def convertDataToUint8Vec(dtype, data): + u8_data = list() + # little endianess + if dtype == DType.BOOL: + for val in data: + val_u8 = np.uint8(val) + u8_data.append(val_u8) + elif dtype == DType.INT4: + in_size = len(data) + out_size = (in_size + 1) // 2 + for i in range(out_size): + val_0 = data[2 * i] + if (2 * i + 1) < in_size: + val_1 = data[2 * i + 1] + else: + val_1 = 0 + val_i8 = (val_0 & 0xF) | ((val_1 & 0xF) << 4) + val_u8 = np.uint8(val_i8) + u8_data.append(val_u8) + elif dtype == DType.INT8: + for val in data: + val_u8 = np.array(val).astype(dtype=np.uint8) + u8_data.append(val_u8) + elif dtype == DType.INT16: + for val in data: + val_u16 = np.array(val).astype(dtype=np.uint16) + b0 = val_u16 & ByteMask + b1 = (val_u16 >> np.uint16(8)) & ByteMask + u8_data.extend([b0, b1]) + elif dtype == DType.INT32: + for val in data: + val_u32 = np.array(val).astype(dtype=np.uint32) + b0 = val_u32 & ByteMask + b1 = (val_u32 >> np.uint32(8)) & ByteMask + b2 = (val_u32 >> np.uint32(16)) & ByteMask + b3 = (val_u32 >> np.uint32(24)) & ByteMask + u8_data.extend([b0, b1, b2, b3]) + elif dtype == DType.INT48: + for val in data: + val_u64 = np.uint64(val) + b0 = val_u64 & ByteMask + b1 = (val_u64 >> np.uint64(8)) & ByteMask + b2 = (val_u64 >> np.uint64(16)) & ByteMask + b3 = (val_u64 >> np.uint64(24)) & ByteMask + b4 = (val_u64 >> np.uint64(32)) & ByteMask + b5 = (val_u64 >> np.uint64(40)) & ByteMask + u8_data.extend([b0, b1, b2, b3, b4, b5]) + elif dtype == DType.SHAPE: + for val in data: + val_u64 = np.uint64(val) + b0 = val_u64 & ByteMask + b1 = (val_u64 >> np.uint64(8)) & ByteMask + b2 = (val_u64 >> np.uint64(16)) & ByteMask + b3 = (val_u64 >> np.uint64(24)) & ByteMask + b4 = (val_u64 >> np.uint64(32)) & ByteMask + b5 = (val_u64 >> np.uint64(40)) & ByteMask + b6 = (val_u64 >> np.uint64(48)) & ByteMask + b7 = (val_u64 >> np.uint64(56)) & ByteMask + u8_data.extend([b0, b1, b2, b3, b4, b5, b6, b7]) + elif dtype == DType.FP16: + np_arr = np.array(data, dtype=np.float16) + u8_data.extend(np_arr.view(np.uint8)) + elif dtype == DType.FP32: + # for val in data: + # b = struct.pack("!f", val) + # u8_data.extend([b[3], b[2], b[1], b[0]]) + np_arr = np.array(data, dtype=np.float32) + u8_data.extend(np_arr.view(np.uint8)) + elif dtype == DType.BF16: + for val in data: + # convert val to little endian byte arrays b + b = struct.pack(" [ b[3], b[2], b[1], b[0] ] + # keep only most significant 2 bytes for bf16 + # in little endian ordering + u8_data.extend([b[2], b[3]]) + elif dtype == DType.FP8E4M3: + for val in data: + # convert val to fp8_bits then to single byte + f32_as_int = struct.unpack(">L", struct.pack(">f", val))[0] + f32_bits = f"{f32_as_int:032b}" + fp8_bits = f32_bits[0] + f32_bits[1:5] + f32_bits[9:12] + fp8_bytes = int(fp8_bits, 2).to_bytes(1, byteorder="little") + u8_data.extend(fp8_bytes) + elif dtype == DType.FP8E5M2: + for val in data: + # convert val to fp8_bits then to single byte + f32_as_int = struct.unpack(">L", struct.pack(">f", val))[0] + f32_bits = f"{f32_as_int:032b}" + fp8_bits = f32_bits[0] + f32_bits[1:6] + f32_bits[9:11] + fp8_bytes = int(fp8_bits, 2).to_bytes(1, byteorder="little") + u8_data.extend(fp8_bytes) + elif dtype == TosaDType.DType: + # Serialize DType enum data as uint8 bytes + for val in data: + np_arr = np.array(data, dtype=np.uint32) + u8_data.extend(np_arr.view(np.uint8)) + else: + raise Exception("unsupported data type {}".format(DTypeNames[dtype])) + return u8_data diff --git a/python/tosa/ClampAttribute.py b/python/tosa/ClampAttribute.py index 6a41498..1189acb 100644 --- a/python/tosa/ClampAttribute.py +++ b/python/tosa/ClampAttribute.py @@ -82,8 +82,15 @@ class ClampAttribute(object): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) return o == 0 + # ClampAttribute + def Type(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos) + return 0 + def ClampAttributeStart(builder): - builder.StartObject(2) + builder.StartObject(3) def Start(builder): ClampAttributeStart(builder) @@ -112,6 +119,12 @@ def ClampAttributeStartMaxValVector(builder, numElems): def StartMaxValVector(builder, numElems: int) -> int: return ClampAttributeStartMaxValVector(builder, numElems) +def ClampAttributeAddType(builder, type): + builder.PrependUint32Slot(2, type, 0) + +def AddType(builder, type): + ClampAttributeAddType(builder, type) + def ClampAttributeEnd(builder): return builder.EndObject() diff --git a/python/tosa/PadAttribute.py b/python/tosa/PadAttribute.py index 301bf17..c4084dc 100644 --- a/python/tosa/PadAttribute.py +++ b/python/tosa/PadAttribute.py @@ -55,8 +55,15 @@ class PadAttribute(object): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) return o == 0 + # PadAttribute + def Type(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos) + return 0 + def PadAttributeStart(builder): - builder.StartObject(1) + builder.StartObject(2) def Start(builder): PadAttributeStart(builder) @@ -73,6 +80,12 @@ def PadAttributeStartPadConstVector(builder, numElems): def StartPadConstVector(builder, numElems: int) -> int: return PadAttributeStartPadConstVector(builder, numElems) +def PadAttributeAddType(builder, type): + builder.PrependUint32Slot(1, type, 0) + +def AddType(builder, type): + PadAttributeAddType(builder, type) + def PadAttributeEnd(builder): return builder.EndObject() diff --git a/schema/tosa.fbs b/schema/tosa.fbs index 79b83b1..7b5948b 100644 --- a/schema/tosa.fbs +++ b/schema/tosa.fbs @@ -185,6 +185,7 @@ table TransposeConvAttribute { table PadAttribute { pad_const: [ubyte] (force_align: 8); + type: DType; } table AxisAttribute { @@ -201,6 +202,7 @@ table ResizeAttribute { table ClampAttribute { min_val: [ubyte] (force_align: 8); max_val: [ubyte] (force_align: 8); + type: DType; } table RescaleAttribute { diff --git a/src/tosa_serialization_handler.cpp b/src/tosa_serialization_handler.cpp index 749a3c8..85625cd 100644 --- a/src/tosa_serialization_handler.cpp +++ b/src/tosa_serialization_handler.cpp @@ -19,6 +19,9 @@ #include using namespace tosa; +using fp8e4m3 = tosa::float_t; +using fp8e5m2 = tosa::float_t; + TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name, const flatbuffers::Vector* shape, DType dtype, @@ -747,6 +750,51 @@ void TosaSerializationHandler::ForceAlignTensorData(std::vector& buf) } } +tosa_err_t TosaSerializationHandler::ConvertBF16toU8(const std::vector& in, std::vector& out) +{ + // Note: Converts fp32->bf16 by ignoring the least significant 16 bits + out.clear(); + for (auto val : in) + { + uint32_t* val_u32 = reinterpret_cast(&val); + uint8_t f32_byte2 = (*val_u32 >> 16) & 0xFF; + uint8_t f32_byte3 = (*val_u32 >> 24) & 0xFF; + // little endian: byte2 followed by byte3 + out.push_back(f32_byte2); + out.push_back(f32_byte3); + } + ForceAlignTensorData(out); + return TOSA_OK; +} + +tosa_err_t TosaSerializationHandler::ConvertFP8E4M3toU8(const std::vector& in, std::vector& out) +{ + // Note: Converts fp32->FP8E4M3 before converting to unint8_t + out.clear(); + for (auto val : in) + { + auto f8 = static_cast(val); + uint8_t b8 = f8.bits(); + out.push_back(b8); + } + ForceAlignTensorData(out); + return TOSA_OK; +} + +tosa_err_t TosaSerializationHandler::ConvertFP8E5M2toU8(const std::vector& in, std::vector& out) +{ + // Note: Converts fp32->FP8E5M2 before converting to uint8_t + out.clear(); + for (auto val : in) + { + auto f8 = static_cast(val); + uint8_t b8 = f8.bits(); + out.push_back(b8); + } + ForceAlignTensorData(out); + return TOSA_OK; +} + tosa_err_t TosaSerializationHandler::ConvertF16toU8(const std::vector& in, std::vector& out) { // Note: Converts fp32->fp16 before converting to uint8_t @@ -896,6 +944,78 @@ tosa_err_t TosaSerializationHandler::ConvertBooltoU8(const std::vector& in return TOSA_OK; } +tosa_err_t TosaSerializationHandler::ConvertU8toBF16(const std::vector& in, + uint32_t out_size, + std::vector& out) +{ + // Note: bf16 values returned in fp32 type + out.clear(); + if (in.size() < out_size * sizeof(int16_t)) + { + printf("TosaSerializationHandler::ConvertU8toBF16(): uint8 buffer size %ld must >= target size %ld\n", + in.size(), out_size * sizeof(int16_t)); + return TOSA_USER_ERROR; + } + + for (uint32_t i = 0; i < out_size; i++) + { + uint32_t f32_byte2 = in[i * sizeof(int16_t)]; + uint32_t f32_byte3 = in[i * sizeof(int16_t) + 1]; + uint32_t val_u32 = (f32_byte2 << 16) + (f32_byte3 << 24); + + // Reinterpret u32 bytes as fp32 + float val_f32 = *(float*)&val_u32; + out.push_back(val_f32); + } + return TOSA_OK; +} + +tosa_err_t TosaSerializationHandler::ConvertU8toFP8E4M3(const std::vector& in, + uint32_t out_size, + std::vector& out) +{ + // Note: FP8E4M3 values returned in fp32 type + out.clear(); + if (in.size() < out_size * sizeof(int8_t)) + { + printf("TosaSerializationHandler::ConvertU8toF16(): uint8 buffer size %ld must >= target size %ld\n", in.size(), + out_size * sizeof(int8_t)); + return TOSA_USER_ERROR; + } + + for (uint32_t i = 0; i < out_size; i++) + { + int8_t bits = static_cast(in[i * sizeof(int8_t)]); + auto f8 = fp8e4m3::from_bits(bits); + float val_f32 = static_cast(f8); + out.push_back(val_f32); + } + return TOSA_OK; +} + +tosa_err_t TosaSerializationHandler::ConvertU8toFP8E5M2(const std::vector& in, + uint32_t out_size, + std::vector& out) +{ + // Note: FP8E5M2 values returned in fp32 type + out.clear(); + if (in.size() < out_size * sizeof(int8_t)) + { + printf("TosaSerializationHandler::ConvertU8toF16(): uint8 buffer size %ld must >= target size %ld\n", in.size(), + out_size * sizeof(int8_t)); + return TOSA_USER_ERROR; + } + + for (uint32_t i = 0; i < out_size; i++) + { + int8_t bits = static_cast(in[i * sizeof(int8_t)]); + auto f8 = fp8e5m2::from_bits(bits); + float val_f32 = static_cast(f8); + out.push_back(val_f32); + } + return TOSA_OK; +} + tosa_err_t TosaSerializationHandler::ConvertU8toF16(const std::vector& in, uint32_t out_size, std::vector& out) -- cgit v1.2.1