diff options
26 files changed, 1021 insertions, 790 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index 5f4f851..679603d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -76,7 +76,7 @@ set(public_headers) list(APPEND public_headers include/attribute.h include/attribute.def - include/float_utils.h + include/cfloat.h include/numpy_utils.h include/tosa_generated.h include/tosa_serialization_handler.h diff --git a/include/attribute.def b/include/attribute.def index 30b432d..0e97629 100644 --- a/include/attribute.def +++ b/include/attribute.def @@ -43,18 +43,16 @@ DEF_ATTRIBUTE(Conv, 7, bool, S, local_bound, DType, S, acc_type) -DEF_ATTRIBUTE(TransposeConv, 7, +DEF_ATTRIBUTE(TransposeConv, 6, int32_t, V, out_pad, int32_t, V, stride, - int32_t, V, output_shape, int32_t, S, input_zp, int32_t, S, weight_zp, bool, S, local_bound, DType, S, acc_type) -DEF_ATTRIBUTE(Pad, 2, - uint8_t, V, pad_const, - DType, S, type) +DEF_ATTRIBUTE(Pad, 1, + uint8_t, V, pad_const) DEF_ATTRIBUTE(Axis, 1, int32_t, S, axis) @@ -65,10 +63,9 @@ DEF_ATTRIBUTE(Resize, 4, int16_t, V, border, ResizeMode, S, mode) -DEF_ATTRIBUTE(Clamp, 3, +DEF_ATTRIBUTE(Clamp, 2, uint8_t, V, min_val, - uint8_t, V, max_val, - DType, S, type) + uint8_t, V, max_val) DEF_ATTRIBUTE(Rescale, 7, int32_t, S, input_zp, diff --git a/include/cfloat.h b/include/cfloat.h new file mode 100644 index 0000000..cbbe09a --- /dev/null +++ b/include/cfloat.h @@ -0,0 +1,861 @@ +// Copyright (c) 2022-2024, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
#ifndef CT_CFLOAT_H
#define CT_CFLOAT_H
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <limits>
#include <type_traits>
#if defined(__cpp_lib_bit_cast)
#include <bit>
#endif // defined(__cpp_lib_bit_cast)

namespace ct
{
/// \brief Bitfield specification of the features provided by a specified
/// floating point type.
enum class FloatFeatures
{
    None       = 0x0,
    HasNaN     = 0x1,    ///< The type can represent NaN values
    HasInf     = 0x2,    ///< The type can represent Infinity
    HasDenorms = 0x4,    ///< The type can represent denormal/subnormal values
};

/// \brief Bitwise-AND of two feature sets; used to test for a feature.
constexpr FloatFeatures operator&(const FloatFeatures& a, const FloatFeatures& b)
{
    using T = std::underlying_type_t<FloatFeatures>;
    return static_cast<FloatFeatures>(static_cast<T>(a) & static_cast<T>(b));
}

/// \brief Bitwise-OR of two feature sets; used to combine features.
constexpr FloatFeatures operator|(const FloatFeatures& a, const FloatFeatures& b)
{
    using T = std::underlying_type_t<FloatFeatures>;
    return static_cast<FloatFeatures>(static_cast<T>(a) | static_cast<T>(b));
}

constexpr FloatFeatures& operator|=(FloatFeatures& a, const FloatFeatures& b)
{
    a = a | b;
    return a;
}

namespace float_support
{
// Tag type used to disambiguate the raw-bits constructor of the float
// classes below while keeping it publicly accessible.
struct hidden
{};

/// \brief Get the number of bytes required to store the given number of
/// bits.
///
/// NOTE This is distinct from the number of bytes required to represent
/// the number of bits - a power of two number of bytes will always be
/// returned by this method.
constexpr size_t get_storage_bytes(const size_t n_bits)
{
    const size_t n_bytes = (n_bits + 7) / 8;
    size_t storage_bytes = 1;
    for (; storage_bytes < n_bytes; storage_bytes <<= 1)
        ;
    return storage_bytes;
}

/// \brief Utility method to convert from an older representation of the
/// floating-point features to the FloatFeatures bitfield.
constexpr FloatFeatures get_float_flags(bool has_nan, bool has_denorm, bool has_inf)
{
    FloatFeatures r = FloatFeatures::None;

    if (has_nan)
        r |= FloatFeatures::HasNaN;

    if (has_denorm)
        r |= FloatFeatures::HasDenorms;

    if (has_inf)
        r |= FloatFeatures::HasInf;

    return r;
}

/// \brief Shorthand for all support features
static constexpr FloatFeatures AllFeats = get_float_flags(true, true, true);

// Map from a number of storage bytes to a suitable storage type
template <size_t n_bytes>
struct storage_type;

#define STORAGE_TYPE(T)                                                                                                \
    template <>                                                                                                        \
    struct storage_type<sizeof(T)>                                                                                     \
    {                                                                                                                  \
        using type = T;                                                                                                \
    }
STORAGE_TYPE(int8_t);
STORAGE_TYPE(int16_t);
STORAGE_TYPE(int32_t);
STORAGE_TYPE(int64_t);
#undef STORAGE_TYPE

template <size_t n_storage_bytes>
using storage_type_t = typename storage_type<n_storage_bytes>::type;

#if defined(__cpp_lib_bit_cast)
#define BITCAST_CONSTEXPR constexpr inline

// If bit_cast is available then use it

constexpr inline int32_t get_bits(const float& f)
{
    return std::bit_cast<int32_t>(f);
}
constexpr inline float from_bits(const int32_t& i)
{
    return std::bit_cast<float>(i);
}

#else
#define BITCAST_CONSTEXPR inline

// Otherwise `memcpy` is the safe (non-UB) way of achieving the same result

inline int32_t get_bits(const float& f)
{
    int32_t i;
    std::memcpy(&i, &f, sizeof(float));
    return i;
}

inline float from_bits(const int32_t& i)
{
    float f;
    std::memcpy(&f, &i, sizeof(float));
    return f;
}
#endif

}    // namespace float_support

/// \brief Overflow mode for narrowing floating-point casts.
///
/// Determine the behaviour for values which cannot be represented by the
/// destination type.
enum class OverflowMode
{
    Saturate,    ///< Map to the largest representable value
    Overflow     ///< Map to infinity (if available) or NaN
};

/// Functor for casting cfloat_advanced
///
/// Specific casting behavior can be specified when constructing the
/// functor.
///
/// By default, OVERFLOW mode is used when the destination type has either
/// infinity or NaN representations. Otherwise SATURATE mode is used. It is
/// illegal to specify OVERFLOW mode for a type which has neither infinity
/// nor NaN representations - this will result in a compilation error.
template <class in_type,
          class out_type,
          OverflowMode overflow_mode =
              (out_type::has_nan || out_type::has_inf) ? OverflowMode::Overflow : OverflowMode::Saturate>
class cfloat_cast
{
    constexpr static FloatFeatures in_feats  = in_type::features;
    constexpr static FloatFeatures out_feats = out_type::features;
    constexpr static size_t in_bits          = in_type::n_bits;
    constexpr static size_t in_exp_bits      = in_type::n_exponent_bits;
    constexpr static size_t out_bits         = out_type::n_bits;
    constexpr static size_t out_exp_bits     = out_type::n_exponent_bits;

public:
    constexpr cfloat_cast()
    {
        // SATURATE mode MUST be specified if the destination type does not
        // have either NaN or infinity representations.
        static_assert(overflow_mode == OverflowMode::Saturate || out_type::has_nan || out_type::has_inf);
    }

    /// \brief Cast from `in` to the given `out_type`
    //
    // This code relies on an understanding of the storage format used by
    // `cfloat_advanced`. See the documentation of that class for further
    // details.
    constexpr out_type operator()(const in_type& in) const
    {
        // Shortcut for types which differ only in the number of significand
        // bits, and where the output type is wider than the input type. For
        // example, bfloat16 and binary32.
        if constexpr (in_exp_bits == out_exp_bits && out_bits >= in_bits && in_feats == out_feats)
        {
            return out_type::from_bits(static_cast<typename out_type::storage_t>(in.bits()) << (out_bits - in_bits));
        }

        // Get initial values for the new floating point type
        const bool sign_bit       = in.sign();
        int64_t new_exponent_bits = 0;
        uint64_t new_significand  = 0;

        if (in.is_nan() || in.is_infinity())
        {
            // The mapping of infinity to the destination type depends upon
            // the overflow mode and the features of the destination type.
            // OVERFLOW mode is the "expected" behaviour, in which exception
            // values (NaN and infinity) map to themselves in the
            // destination type (assuming they exist). In SATURATION mode,
            // infinity maps to the largest absolute value of the
            // destination type _even if_ an infinity encoding is available.
            // See the FP8 specification document.
            //
            // By default, exceptional values are encoded with an all-1
            // exponent field.
            new_exponent_bits = (UINT64_C(1) << out_exp_bits) - 1;

            if (in.is_nan())
            {
                // NaN always maps to NaN if it's available.
                //
                // NB: if the type has both NaN AND Infinity support, then
                // the entirety of the significand can be used to encode
                // different values of NaN (excepting significand = 0,
                // which is reserved for infinity). This makes it possible
                // to encode both quiet and signalling varieties.
                // Generally, the LSB of the significand represents "not
                // quiet". However, when there is only 1 NaN encoding
                // (which is generally the case when infinity is not
                // supported), then there cannot be separate quiet and
                // signalling varieties of NaN.
                if constexpr (out_type::has_inf)
                {
                    // Copy across the `not_quiet` bit; set the LSB.
                    // Don't attempt to copy across any of the rest of
                    // the payload.
                    new_significand = 0x1 | (((in.significand() >> (in_type::n_significand_bits - 1)) & 1)
                                             << out_type::n_significand_bits);
                }
                else
                {
                    new_significand = (UINT64_C(1) << out_type::n_significand_bits) - 1;
                }
            }
            else if constexpr (overflow_mode == OverflowMode::Saturate)
            {
                // In SATURATE mode, infinity in the input maps to the
                // largest absolute value in the output type; even if
                // infinity is available. This is in compliance with Table 3
                // of the FP8 specification.
                return out_type::max(sign_bit);
            }
            else if constexpr (!out_type::has_inf && overflow_mode == OverflowMode::Overflow)
            {
                // In OVERFLOW mode, infinities in the input type map to NaN
                // in the output type, if infinity is not available.
                new_significand = (UINT64_C(1) << out_type::n_significand_bits) - 1;
            }
        }
        else if (!in.is_zero())
        {
            const int64_t this_exponent_bits = in.exponent_bits();
            {
                constexpr int64_t exponent_rebias = out_type::exponent_bias - in_type::exponent_bias;
                new_exponent_bits = std::max(this_exponent_bits + exponent_rebias, exponent_rebias + 1);
            }
            new_significand = in.significand() << (64 - in_type::n_significand_bits);

            // Normalise subnormals
            if (this_exponent_bits == 0)
            {
                // Shift the most-significant 1 out of the magnitude to
                // convert it to a significand. Modify the exponent
                // accordingly.
                uint8_t shift = __builtin_clzl(new_significand) + 1;
                new_exponent_bits -= shift;
                new_significand <<= shift;
            }

            // Apply overflow to out-of-range values; this must occur before
            // rounding, as out-of-range values could be rounded down to the
            // largest representable value.
            if constexpr (overflow_mode == OverflowMode::Overflow)
            {
                // Determine the maximum value of exponent, and unrounded
                // significand.
                constexpr bool inf_and_nan     = out_type::has_nan && out_type::has_inf;
                constexpr int64_t max_exp_bits = (INT64_C(1) << out_exp_bits) - (inf_and_nan ? 2 : 1);
                constexpr uint64_t max_significand =
                    ((UINT64_C(1) << out_type::n_significand_bits) - (inf_and_nan ? 1 : 2))
                    << (64 - out_type::n_significand_bits);

                // If the exponent is strictly larger than the largest
                // possible, or the exponent is equal to the largest
                // possible AND the (unrounded) significand is strictly
                // larger than the largest possible then return an
                // appropriate overflow value.
                if (new_exponent_bits > max_exp_bits ||
                    (new_exponent_bits == max_exp_bits && new_significand > max_significand))
                {
                    if constexpr (out_type::has_inf)
                        return out_type::infinity(sign_bit);
                    else
                        return out_type::NaN();
                }
            }

            // Align the significand for the output type
            uint32_t shift                = 64 - out_type::n_significand_bits;
            const bool other_is_subnormal = new_exponent_bits <= 0;
            if (other_is_subnormal)
            {
                shift += 1 - new_exponent_bits;
                new_exponent_bits = 0;
            }

            const uint64_t shift_out = shift == 64 ? new_significand : new_significand & ((UINT64_C(1) << shift) - 1);
            new_significand          = shift == 64 ? 0 : new_significand >> shift;

            // Reinsert the most-significant-one if this is a subnormal
            // in the output type.
            new_significand |= (other_is_subnormal ? UINT64_C(1) : 0) << (64 - shift);

            // Apply rounding based on the bits shifted out of the
            // significand
            const uint64_t shift_half = UINT64_C(1) << (shift - 1);
            if (shift_out > shift_half || (shift_out == shift_half && (new_significand & 1)))
            {
                new_significand += 1;

                // Handle the case that the significand overflowed due
                // to rounding
                constexpr uint64_t max_significand = (UINT64_C(1) << out_type::n_significand_bits) - 1;
                if (new_significand > max_significand)
                {
                    new_significand = 0;
                    new_exponent_bits++;
                }
            }

            // Saturate or overflow if the value is larger than can be
            // represented in the output type. This can only occur if the
            // size of the exponent of the new type is not greater than the
            // exponent of the old type.
            if constexpr (out_exp_bits <= in_exp_bits)
            {
                constexpr int64_t inf_exp_bits = (INT64_C(1) << out_exp_bits) - 1;
                if (new_exponent_bits >= inf_exp_bits)
                {
                    if constexpr (out_type::has_inf && overflow_mode == OverflowMode::Overflow)
                    {
                        // If the output type has a representation of
                        // infinity, and we are in OVERFLOW mode, then
                        // return infinity.
                        new_exponent_bits = inf_exp_bits;
                        new_significand   = 0;
                    }
                    else if constexpr (out_type::has_inf)
                    {
                        // If the output type has a representation of
                        // infinity, and we are in SATURATE mode, then
                        // return the largest representable real number.
                        new_exponent_bits = inf_exp_bits - 1;
                        new_significand   = (UINT64_C(1) << out_type::n_significand_bits) - 1;
                    }
                    else if (new_exponent_bits > inf_exp_bits)
                    {
                        if constexpr (overflow_mode == OverflowMode::Overflow)
                            return out_type::NaN();
                        else
                            return out_type::max(sign_bit);
                    }
                    else
                    {
                        constexpr uint64_t max_significand =
                            (UINT64_C(1) << out_type::n_significand_bits) - (out_type::has_nan ? 2 : 1);
                        if (new_significand > max_significand)
                        {
                            if constexpr (overflow_mode == OverflowMode::Saturate)
                                new_significand = max_significand;
                            else
                                return out_type::NaN();
                        }
                    }
                }
            }
        }

        return out_type::from_bits(sign_bit, new_exponent_bits, new_significand);
    }
};

/// \brief Bit-accurate representation storage of IEEE754 compliant and
/// derived floating point types.
///
/// Template parameters allow for specification of the number of bits, the
/// number of exponent bits, and the features of the floating point types.
/// The number of significand bits is `n_bits - n_exponent_bits - 1`. It is
/// not possible to represent a signless type, such as FP8 E8M0.
///
/// For an imaginary 7-bit type, FP7 E4M2; the storage for various values
/// given different floating point features is given below:
///
/// Value                       All features   No infinity   No features
/// --------------------------  ------------   -----------   -----------
/// Positive zero +0            00 0000 00     As before     As before
/// Negative zero -0            11 0000 00     As before     As before
/// Positive/negative infinity  SS 1111 00     N/A           N/A
/// Signalling NaN              SS 1111 01     SS 1111 11    N/A
/// Quiet NaN                   SS 1111 11     N/A           N/A
/// Largest normal              SS 1110 11     SS 1111 10    SS 1111 11
/// Smallest normal             SS 0001 00     As before     SS 0000 01
/// Largest denormal            SS 0000 11     SS 0000 11    N/A
///
/// Note that the sign bit is extended to fill the storage type.
template <size_t _n_bits, size_t n_exp_bits, FloatFeatures Feats = float_support::AllFeats>
class cfloat_advanced
{
public:
    using storage_t = float_support::storage_type_t<float_support::get_storage_bytes(_n_bits)>;

    static constexpr size_t n_bits             = _n_bits;
    static constexpr size_t n_exponent_bits    = n_exp_bits;
    static constexpr size_t n_significand_bits = n_bits - (1 + n_exp_bits);
    static constexpr int64_t exponent_bias     = (INT64_C(1) << (n_exp_bits - 1)) - 1;

    static constexpr FloatFeatures features = Feats;
    static constexpr bool has_nan           = (Feats & FloatFeatures::HasNaN) != FloatFeatures::None;
    static constexpr bool has_inf           = (Feats & FloatFeatures::HasInf) != FloatFeatures::None;
    static constexpr bool has_denorms       = (Feats & FloatFeatures::HasDenorms) != FloatFeatures::None;

    /// \brief Construct a floating point type with the given bit
    /// representation.
    static constexpr cfloat_advanced from_bits(storage_t bits)
    {
        return cfloat_advanced(float_support::hidden(), bits);
    }

    /// \brief Construct a float from the given sign, exponent and
    /// significand bits.
    static constexpr cfloat_advanced from_bits(bool pm, storage_t e, storage_t s)
    {
        // storage_t is signed, so -1 is all-ones; after the shifts below
        // this sign-extends the sign bit through the unused upper bits of
        // the storage type (see the class documentation table).
        storage_t bits = pm ? -1 : 0;

        bits <<= n_exp_bits;
        bits |= e;

        bits <<= n_significand_bits;
        // Significand bits are dropped for subnormal encodings (e == 0)
        // when the type does not support denormals.
        if (has_denorms || e)
            bits |= s;

        return cfloat_advanced(float_support::hidden(), bits);
    }

    /// \brief (Hidden) Construct a float type from a given bit pattern
    constexpr cfloat_advanced(const float_support::hidden&, storage_t bits)
        : m_data(bits)
    {}

    constexpr cfloat_advanced()
        : m_data(0)
    {}
    constexpr cfloat_advanced(const cfloat_advanced& other)
        : m_data(other.m_data)
    {}

    constexpr cfloat_advanced& operator=(const cfloat_advanced& other)
    {
        this->m_data = other.m_data;
        return *this;
    }

    constexpr cfloat_advanced& operator=(cfloat_advanced&& other)
    {
        this->m_data = other.m_data;
        return *this;
    }

    /// \brief Get a NaN representation
    static constexpr cfloat_advanced NaN()
    {
        static_assert(has_nan);

        // NaN is always encoded with all 1s in the exponent.
        // If Inf exists, then NaN is encoded as a non-zero significand; if
        // Inf doesn't exist then NaN is encoded as all ones in the
        // significand.
        constexpr uint64_t exp_bits = (UINT64_C(1) << n_exponent_bits) - 1;
        constexpr uint64_t sig_bits = has_inf ? 1 : (UINT64_C(1) << n_significand_bits) - 1;
        return cfloat_advanced::from_bits(false, exp_bits, sig_bits);
    }

    /// \brief Get a representation of infinity
    static constexpr cfloat_advanced infinity(const bool& sign)
    {
        static_assert(has_inf);

        // Inf is always encoded with all 1s in the exponent, and all zeros
        // in the significand.
        return cfloat_advanced::from_bits(sign, (UINT64_C(1) << n_exponent_bits) - 1, 0);
    }

    /// \brief Get the largest representable value
    static constexpr cfloat_advanced max(const bool& sign)
    {
        if constexpr (has_nan && has_inf)
        {
            // Where we have NaN and Infinity, exponents all `1` corresponds
            // to some of these values.
            return from_bits(sign, (UINT64_C(1) << n_exponent_bits) - 2, (UINT64_C(1) << n_significand_bits) - 1);
        }
        else if constexpr (has_nan || has_inf)
        {
            // Where we have either NaN or infinity (but not both),
            // exponents all `1` AND significand all `1` corresponds to the
            // special value.
            return from_bits(sign, (UINT64_C(1) << n_exponent_bits) - 1, (UINT64_C(1) << n_significand_bits) - 2);
        }
        else
        {
            // With no special values to encode, the maximum value is
            // encoded as all `1`s.
            return from_bits(sign, (UINT64_C(1) << n_exponent_bits) - 1, (UINT64_C(1) << n_significand_bits) - 1);
        }
    }

    /// \brief Cast to a different floating point representation.
    template <size_t out_n_bits, size_t out_n_exp_bits, FloatFeatures OutFeats>
    constexpr inline operator cfloat_advanced<out_n_bits, out_n_exp_bits, OutFeats>() const
    {
        using out_type = cfloat_advanced<out_n_bits, out_n_exp_bits, OutFeats>;
        return cfloat_cast<cfloat_advanced, out_type>().operator()(*this);
    }

    /// \brief Convert from a 32-bit floating point value
    BITCAST_CONSTEXPR
    cfloat_advanced(const float& f)
    {
        // If this format exactly represents the binary32 format then get
        // the bits from the provided float; otherwise get a binary32
        // representation and then convert to this format.
        if constexpr (represents_binary32())
            m_data = float_support::get_bits(f);
        else
            m_data =
                static_cast<cfloat_advanced<n_bits, n_exp_bits, Feats>>(static_cast<cfloat_advanced<32, 8>>(f)).m_data;
    }

    /// \brief Cast to a 32-bit floating point value
    BITCAST_CONSTEXPR operator float() const
    {
        // If this format exactly represents the binary32 format then return
        // a float; otherwise get a binary32 representation and then return
        // a float.
        if constexpr (represents_binary32())
            return float_support::from_bits(m_data);
        else
            return static_cast<float>(this->operator cfloat_advanced<32, 8>());
    }

    /// \brief Return whether this type represents the IEEE754 binary32
    /// format
    constexpr static inline bool represents_binary32()
    {
        return std::is_same_v<storage_t, int32_t> && n_exp_bits == 8 && Feats == float_support::AllFeats;
    }

    constexpr auto operator-() const
    {
        // Flip the sign bit and all the (sign-extension) bits above it.
        constexpr storage_t sign_bits =
            static_cast<storage_t>(std::numeric_limits<std::make_unsigned_t<storage_t>>::max() << (n_bits - 1));
        return from_bits(m_data ^ sign_bits);
    }

    constexpr bool is_subnormal() const
    {
        return exponent_bits() == 0 && significand() != 0;
    }

    constexpr bool is_zero() const
    {
        return exponent_bits() == 0 && significand() == 0;
    }

    constexpr bool is_nan() const
    {
        return has_nan && (exponent_bits() == (UINT64_C(1) << n_exponent_bits) - 1) &&
               ((has_inf && significand()) || (!has_inf && significand() == (UINT64_C(1) << n_significand_bits) - 1));
    }

    constexpr bool is_infinity() const
    {
        return has_inf && ((exponent_bits() == (UINT64_C(1) << n_exponent_bits) - 1) && (significand() == 0));
    }

    constexpr inline const storage_t& bits() const
    {
        return m_data;
    }

    /// \brief Get the exponent
    constexpr inline int64_t exponent() const
    {
        return std::max<int64_t>(exponent_bits(), INT64_C(1)) - exponent_bias;
    }

    /// \brief Get the sign bit
    constexpr inline bool sign() const
    {
        return (m_data >> (n_bits - 1)) & 0x1;
    }

    /// \brief Get the bits from the exponent
    constexpr inline uint64_t exponent_bits() const
    {
        constexpr uint64_t mask = (UINT64_C(1) << n_exp_bits) - 1;
        return (m_data >> n_significand_bits) & mask;
    }

    constexpr inline uint64_t significand() const
    {
        return m_data & ((UINT64_C(1) << n_significand_bits) - 1);
    }

    constexpr inline bool operator==(const cfloat_advanced& other) const
    {
        return !is_nan() && !other.is_nan() &&    // Neither operand is NaN
               ((is_zero() && other.is_zero()) || (m_data == other.m_data));
    }

    constexpr inline bool operator!=(const cfloat_advanced& other) const
    {
        return !(*this == other);
    }

    constexpr inline cfloat_advanced& operator+=(const cfloat_advanced& rhs)
    {
        // Addition is performed by round-tripping through binary32.
        this->m_data = static_cast<cfloat_advanced>(static_cast<float>(*this) + static_cast<float>(rhs)).bits();
        return *this;
    }

private:
    storage_t m_data = 0;
};

// This should probably be exported so we can use it elsewhere
#undef BITCAST_CONSTEXPR

/// \brief Wrapper to maintain API compatibility with older code, which was
/// limited to power-of-two sizes of floats.
template <typename storage_t,
          size_t n_exp_bits,
          bool has_nan,
          bool with_denorm,
          bool with_infinity,
          std::enable_if_t<(n_exp_bits + 1 < sizeof(storage_t) * 8), bool> = true>
using cfloat = cfloat_advanced<sizeof(storage_t) * 8,
                               n_exp_bits,
                               float_support::get_float_flags(has_nan, with_denorm, with_infinity)>;

namespace float_support
{
// Pre-C++23 these can't be computed as constexpr, so have to hardcode
// them

template <int>
struct digits10;    // floor(log10(2) * (digits - 1))
template <int>
struct max_digits10;    // ceil(log10(2) * digits + 1)
template <int>
struct min_exponent10;    // floor(log10(2) * min_exponent)
template <int>
struct max_exponent10;    // floor(log10(2) * max_exponent)

template <>
struct digits10<8>
{
    constexpr static inline int value = 2;
};

template <>
struct max_digits10<8>
{
    constexpr static inline int value = 4;
};

template <>
struct digits10<10>
{
    constexpr static inline int value = 2;
};

template <>
struct max_digits10<10>
{
    constexpr static inline int value = 5;
};

template <>
struct digits10<24>
{
    constexpr static inline int value = 6;
};

template <>
struct max_digits10<24>
{
    constexpr static inline int value = 9;
};

template <>
struct min_exponent10<-13>
{
    constexpr static inline int value = -3;
};

template <>
struct max_exponent10<16>
{
    constexpr static inline int value = 4;
};

template <>
struct min_exponent10<-125>
{
    constexpr static inline int value = -37;
};

template <>
struct max_exponent10<128>
{
    constexpr static inline int value = 38;
};

template <int d>
inline constexpr int digits10_v = digits10<d>::value;
template <int d>
inline constexpr int max_digits10_v = max_digits10<d>::value;

template <int e>
inline constexpr int min_exponent10_v = min_exponent10<e>::value;

template <int e>
inline constexpr int max_exponent10_v = max_exponent10<e>::value;

}    // namespace float_support

}    // namespace ct

namespace std
{

template <size_t n_bits, size_t n_exp_bits, ct::FloatFeatures Feats>
struct is_floating_point<ct::cfloat_advanced<n_bits, n_exp_bits, Feats>> : std::integral_constant<bool, true>
{};

template <size_t n_bits, size_t n_exp_bits, ct::FloatFeatures Feats>
class numeric_limits<ct::cfloat_advanced<n_bits, n_exp_bits, Feats>>
{
    using this_cfloat = ct::cfloat_advanced<n_bits, n_exp_bits, Feats>;

public:
    static constexpr bool is_specialized = true;

    static constexpr auto min() noexcept
    {
        return this_cfloat::from_bits(false, 1, 0);
    }

    static constexpr auto max() noexcept
    {
        return this_cfloat::max(false);
    }
    static constexpr auto lowest() noexcept
    {
        return -max();
    }

    static constexpr int digits       = this_cfloat::n_significand_bits + 1;
    static constexpr int digits10     = ct::float_support::digits10_v<digits>;
    static constexpr int max_digits10 = ct::float_support::max_digits10_v<digits>;

    static constexpr bool is_signed  = true;
    static constexpr bool is_integer = false;
    static constexpr bool is_exact   = false;
    static constexpr int radix       = 2;

    static constexpr auto epsilon() noexcept
    {
        return this_cfloat::from_bits(false, this_cfloat::exponent_bias - this_cfloat::n_significand_bits, 0);
    }

    static constexpr auto round_error() noexcept
    {
        return this_cfloat::from_bits(0, this_cfloat::exponent_bias - 1, 0);
    }

    static constexpr int min_exponent   = (1 - this_cfloat::exponent_bias) + 1;
    static constexpr int min_exponent10 = ct::float_support::min_exponent10_v<min_exponent>;
    static constexpr int max_exponent   = this_cfloat::exponent_bias + 1;
    static constexpr int max_exponent10 = ct::float_support::max_exponent10_v<max_exponent>;

    static constexpr bool has_infinity             = this_cfloat::has_inf;
    static constexpr bool has_quiet_NaN            = this_cfloat::has_nan && this_cfloat::has_inf;
    static constexpr bool has_signaling_NaN        = this_cfloat::has_nan;
    static constexpr float_denorm_style has_denorm = this_cfloat::has_denorms ? denorm_present : denorm_absent;
    static constexpr bool has_denorm_loss          = false;

    static constexpr auto infinity() noexcept
    {
        if constexpr (this_cfloat::has_inf)
        {
            return this_cfloat::infinity(false);
        }
        else
        {
            // No infinity encoding available; return positive zero.
            return this_cfloat::from_bits(false, 0, 0);
        }
    }

    static constexpr auto quiet_NaN() noexcept
    {
        const uint64_t exp_bits = (UINT64_C(1) << this_cfloat::n_exponent_bits) - 1;
        const uint64_t sig_bits = this_cfloat::has_inf ? (UINT64_C(1) << (this_cfloat::n_significand_bits - 1)) | 1
                                                       : (UINT64_C(1) << this_cfloat::n_significand_bits) - 1;
        return this_cfloat::from_bits(false, exp_bits, sig_bits);
    }

    static constexpr auto signaling_NaN() noexcept
    {
        const uint64_t exp_bits = (UINT64_C(1) << this_cfloat::n_exponent_bits) - 1;
        const uint64_t sig_bits = this_cfloat::has_inf ? 1 : (UINT64_C(1) << this_cfloat::n_significand_bits) - 1;
        return this_cfloat::from_bits(false, exp_bits, sig_bits);
    }

    static constexpr auto denorm_min() noexcept
    {
        return this_cfloat::from_bits(false, 0, 1);
    }

    static constexpr bool is_iec559  = false;
    static constexpr bool is_bounded = false;
    static constexpr bool is_modulo  = false;

    static constexpr bool traps                    = false;
    static constexpr bool tinyness_before          = false;
    static constexpr float_round_style round_style = round_to_nearest;
};

}    // namespace std

#endif // CT_CFLOAT_H
diff --git a/include/float_utils.h b/include/float_utils.h
deleted file mode 100644
index 831ad74..0000000
--- a/include/float_utils.h
+++ /dev/null
@@ -1,533 +0,0 @@
-// Copyright (c) 2024, ARM Limited.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
- -#ifndef TOSA_FLOAT_UTILS_H_ -#define TOSA_FLOAT_UTILS_H_ - -#include <algorithm> -#include <cstdint> -#include <limits> -#include <type_traits> -#if defined(__cpp_lib_bit_cast) -#include <bit> -#endif // defined(__cpp_lib_bit_cast) - -namespace tosa -{ - -namespace float_support -{ - -struct hidden -{}; - -#if defined(__cpp_lib_bit_cast) -#define BITCAST_CONSTEXPR constexpr inline - -constexpr inline int32_t get_bits(const float& f) -{ - return std::bit_cast<int32_t>(f); -} -constexpr inline float from_bits(const int32_t& i) -{ - return std::bit_cast<float>(i); -} - -#else -#define BITCAST_CONSTEXPR inline - -union ufloat32 -{ - constexpr ufloat32(const float& x) - : f(x) - {} - constexpr ufloat32(const int32_t& x) - : i(x) - {} - - float f; - int32_t i; -}; - -inline int32_t get_bits(const float& f) -{ - return ufloat32(f).i; -} -inline float from_bits(const int32_t& i) -{ - return ufloat32(i).f; -} -#endif - -} // namespace float_support - -template <typename storage_t, - size_t n_exp_bits, - bool has_nan, - bool with_denorm, - bool with_infinity, - std::enable_if_t<(n_exp_bits + 1 < sizeof(storage_t) * 8), bool> = true> -class float_t -{ - storage_t m_data = 0; - -public: - static constexpr size_t n_exponent_bits = n_exp_bits; - static constexpr size_t n_significand_bits = sizeof(storage_t) * 8 - 1 - n_exp_bits; - static constexpr int64_t exponent_bias = (1 << (n_exp_bits - 1)) - 1; - - /// \brief Construct a floating point type with the given bit - /// representation. - static constexpr float_t from_bits(storage_t bits) - { - return float_t(float_support::hidden(), bits); - } - - /// \brief Construct a float from the given sign, exponent and significand - /// bits. - static constexpr float_t from_bits(bool pm, storage_t e, storage_t s) - { - storage_t bits = pm ? 
1 : 0; - - bits <<= n_exp_bits; - bits |= e; - - bits <<= n_significand_bits; - if (with_denorm || e) - bits |= s; - - return float_t(float_support::hidden(), bits); - } - - /// \brief (Hidden) Construct a float type from a given bit pattern - constexpr float_t(const float_support::hidden&, storage_t bits) - : m_data(bits) - {} - - constexpr float_t() - : m_data(0) - {} - constexpr float_t(const float_t& other) - : m_data(other.m_data) - {} - - /// \brief Cast to a different floating point representation. - template <typename other_storage_t, - size_t other_n_exp_bits, - bool other_has_nan, - bool other_has_denorm, - bool other_has_infinity> - constexpr inline - operator float_t<other_storage_t, other_n_exp_bits, other_has_nan, other_has_denorm, other_has_infinity>() const - { - using other_float_t = - float_t<other_storage_t, other_n_exp_bits, other_has_nan, other_has_denorm, other_has_infinity>; - - // Shortcut for types which are fundamentally similar (e.g., bf16 -> - // fp32) - if constexpr (n_exp_bits == other_n_exp_bits && sizeof(other_storage_t) >= sizeof(storage_t) && - has_nan == other_has_nan) - { - return other_float_t::from_bits(static_cast<other_storage_t>(m_data) - << (sizeof(other_storage_t) - sizeof(storage_t)) * 8); - } - - // Get initial values for the new floating point type - const bool sign_bit = m_data < 0; - int64_t new_exponent_bits = 0; - uint64_t new_significand = 0; - - if (is_nan() || is_infinity()) - { - new_exponent_bits = (1 << other_n_exp_bits) - 1; - - if (is_nan()) - { - if constexpr (other_has_infinity) - { - // Copy across the `not_quiet bit`; set the LSB. Don't - // attempt to copy across any of the rest of the payload. 
- new_significand = - 0x1 | (((significand() >> (n_significand_bits - 1)) & 1) << other_float_t::n_significand_bits); - } - else - { - new_significand = (1ul << other_float_t::n_significand_bits) - 1; - } - } - else if constexpr (!other_has_infinity) - { - new_significand = (1ul << other_float_t::n_significand_bits) - (other_has_nan ? 2 : 1); - } - } - else if (!is_zero()) - { - const int64_t this_exponent_bits = exponent_bits(); - { - constexpr int64_t exponent_rebias = other_float_t::exponent_bias - exponent_bias; - new_exponent_bits = std::max(this_exponent_bits + exponent_rebias, exponent_rebias + 1); - } - new_significand = this->significand() << (64 - n_significand_bits); - - // Normalise subnormals - if (this_exponent_bits == 0) - { - // Shift the most-significant 1 out of the magnitude to convert - // it to a significand. Modify the exponent accordingly. - uint8_t shift = __builtin_clzl(new_significand) + 1; - new_exponent_bits -= shift; - new_significand <<= shift; - } - - // Align the significand for the output type - uint32_t shift = 64 - other_float_t::n_significand_bits; - const bool other_is_subnormal = new_exponent_bits <= 0; - if (other_is_subnormal) - { - shift += 1 - new_exponent_bits; - new_exponent_bits = 0; - } - - const uint64_t shift_out = shift == 64 ? new_significand : new_significand & ((1ll << shift) - 1); - new_significand = shift == 64 ? 0 : new_significand >> shift; - - // Reinsert the most-significant-one if this is a subnormal in the - // output type. - new_significand |= (other_is_subnormal ? 
1ll : 0) << (64 - shift); - - // Apply rounding based on the bits shifted out of the significand - const uint64_t shift_half = 1ll << (shift - 1); - if (shift_out > shift_half || (shift_out == shift_half && (new_significand & 1))) - { - new_significand += 1; - - // Handle the case that the significand overflowed due to - // rounding - constexpr uint64_t max_significand = (1ll << other_float_t::n_significand_bits) - 1; - if (new_significand > max_significand) - { - new_significand = 0; - new_exponent_bits++; - } - } - - // Saturate to infinity if the exponent is larger than can be - // represented in the output type. This can only occur if the size - // of the exponent of the new type is not greater than the exponent - // of the old type. - if constexpr (other_n_exp_bits <= n_exp_bits) - { - constexpr int64_t inf_exp_bits = (1ll << other_n_exp_bits) - 1; - if (new_exponent_bits >= inf_exp_bits) - { - new_exponent_bits = inf_exp_bits; - new_significand = - other_has_infinity ? 0 : (1ul << other_float_t::n_significand_bits) - (other_has_nan ? 2 : 1); - } - } - } - - return other_float_t::from_bits(sign_bit, new_exponent_bits, new_significand); - } - - /// \brief Convert from a 32-bit floating point value - BITCAST_CONSTEXPR - float_t(const float& f) - { - // If this format exactly represents the binary32 format then get - // the bits from the provided float; otherwise get a binary32 - // representation and then convert to this format. - if constexpr (represents_binary32()) - m_data = float_support::get_bits(f); - else - m_data = static_cast<float_t<storage_t, n_exp_bits, has_nan, with_denorm, with_infinity>>( - static_cast<float_t<int32_t, 8, true, true, true>>(f)) - .m_data; - } - - /// \brief Cast to a 32-bit floating point value - BITCAST_CONSTEXPR operator float() const - { - // If this format exactly represents the binary32 format then return - // a float; otherwise get a binary32 representation and then return - // a float. 
- if constexpr (represents_binary32()) - return float_support::from_bits(m_data); - else - return static_cast<float>(this->operator float_t<int32_t, 8, true, true, true>()); - } - - /// \brief Return whether this type represents the IEEE754 binary32 - /// format - constexpr static inline bool represents_binary32() - { - return std::is_same_v<storage_t, int32_t> && n_exp_bits == 8 && has_nan && with_denorm && with_infinity; - } - - constexpr auto operator-() const - { - return from_bits(m_data ^ (1ll << (sizeof(storage_t) * 8 - 1))); - } - - constexpr bool is_subnormal() const - { - return exponent_bits() == 0 && significand() != 0; - } - - constexpr bool is_zero() const - { - return exponent_bits() == 0 && significand() == 0; - } - - constexpr bool is_nan() const - { - return has_nan && (exponent_bits() == (1ul << n_exponent_bits) - 1) && - ((with_infinity && significand()) || - (!with_infinity && significand() == (1ul << n_significand_bits) - 1)); - } - - constexpr bool is_infinity() const - { - return with_infinity && ((exponent_bits() == (1ul << n_exponent_bits) - 1) && !significand()); - } - - constexpr inline const storage_t& bits() const - { - return m_data; - } - - /// \brief Get the exponent - constexpr inline int64_t exponent() const - { - return std::max<int64_t>(exponent_bits(), 1ul) - exponent_bias; - } - - /// \brief Get the bits from the exponent - constexpr inline uint64_t exponent_bits() const - { - constexpr uint64_t mask = (1ul << n_exp_bits) - 1; - return (m_data >> n_significand_bits) & mask; - } - - constexpr inline uint64_t significand() const - { - return m_data & ((1ul << n_significand_bits) - 1); - } - - constexpr inline bool operator==(const float_t& other) const - { - return !is_nan() && !other.is_nan() && ((is_zero() && other.is_zero()) || bits() == other.bits()); - } - - constexpr inline float_t& operator+=(const float_t& rhs) - { - this->m_data = static_cast<float_t>(static_cast<float>(*this) + static_cast<float>(rhs)).bits(); - return 
*this; - } -}; - -// This should probably be exported so we can use it elsewhere -#undef BITCAST_CONSTEXPR - -namespace float_support -{ - -// Pre-C++23 these can't be computed as constexpr, so have to hardcode them - -template <int> -struct digits10; // floor(log10(2) * (digits - 1) -template <int> -struct max_digits10; // ceil(log10(2) * digits + 1) -template <int> -struct min_exponent10; // floor(log10(2) * min_exponent) -template <int> -struct max_exponent10; // floor(log10(2) * max_exponent) - -template <> -struct digits10<8> -{ - constexpr static inline int value = 2; -}; - -template <> -struct max_digits10<8> -{ - constexpr static inline int value = 4; -}; - -template <> -struct digits10<10> -{ - constexpr static inline int value = 2; -}; - -template <> -struct max_digits10<10> -{ - constexpr static inline int value = 5; -}; - -template <> -struct digits10<24> -{ - constexpr static inline int value = 6; -}; - -template <> -struct max_digits10<24> -{ - constexpr static inline int value = 9; -}; - -template <> -struct min_exponent10<-13> -{ - constexpr static inline int value = -3; -}; - -template <> -struct max_exponent10<16> -{ - constexpr static inline int value = 4; -}; - -template <> -struct min_exponent10<-125> -{ - constexpr static inline int value = -37; -}; - -template <> -struct max_exponent10<128> -{ - constexpr static inline int value = 38; -}; - -template <int d> -inline constexpr int digits10_v = digits10<d>::value; -template <int d> -inline constexpr int max_digits10_v = max_digits10<d>::value; - -template <int e> -inline constexpr int min_exponent10_v = min_exponent10<e>::value; - -template <int e> -inline constexpr int max_exponent10_v = max_exponent10<e>::value; - -} // namespace float_support - -} // namespace tosa - -namespace std -{ - -template <typename storage_t, size_t n_exp_bits, bool has_nan, bool has_denorm, bool has_inf> -struct is_floating_point<tosa::float_t<storage_t, n_exp_bits, has_nan, has_denorm, has_inf>> - : 
std::integral_constant<bool, true> -{}; - -template <typename storage_t, size_t n_exp_bits, bool has_nan, bool with_denorm, bool with_inf> -class numeric_limits<tosa::float_t<storage_t, n_exp_bits, has_nan, with_denorm, with_inf>> -{ - using this_float_t = tosa::float_t<storage_t, n_exp_bits, has_nan, with_denorm, with_inf>; - -public: - static constexpr bool is_specialized = true; - - static constexpr auto min() noexcept - { - return this_float_t::from_bits(false, 1, 0); - } - - static constexpr auto max() noexcept - { - return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 2, - (1 << this_float_t::n_significand_bits) - 1); - } - - static constexpr auto lowest() noexcept - { - return -max(); - } - - static constexpr int digits = this_float_t::n_significand_bits + 1; - static constexpr int digits10 = tosa::float_support::digits10_v<digits>; - static constexpr int max_digits10 = tosa::float_support::max_digits10_v<digits>; - - static constexpr bool is_signed = true; - static constexpr bool is_integer = false; - static constexpr bool is_exact = false; - static constexpr int radix = 2; - - static constexpr auto epsilon() noexcept - { - return this_float_t::from_bits(false, this_float_t::exponent_bias - this_float_t::n_significand_bits, 0); - } - - static constexpr auto round_error() noexcept - { - return this_float_t::from_bits(0, this_float_t::exponent_bias - 1, 0); - } - - static constexpr int min_exponent = (1 - this_float_t::exponent_bias) + 1; - static constexpr int min_exponent10 = tosa::float_support::min_exponent10_v<min_exponent>; - static constexpr int max_exponent = this_float_t::exponent_bias + 1; - static constexpr int max_exponent10 = tosa::float_support::max_exponent10_v<max_exponent>; - - static constexpr bool has_infinity = with_inf; - static constexpr bool has_quiet_NaN = has_nan; - static constexpr bool has_signaling_NaN = true; - static constexpr float_denorm_style has_denorm = with_denorm ? 
denorm_present : denorm_absent; - static constexpr bool has_denorm_loss = false; - - static constexpr auto infinity() noexcept - { - if constexpr (with_inf) - { - return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 1, 0); - } - else - { - return this_float_t::from_bits(false, 0, 0); - } - } - - static constexpr auto quiet_NaN() noexcept - { - return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 1, - 1 << (this_float_t::n_significand_bits - 1) | 1); - } - - static constexpr auto signaling_NaN() noexcept - { - return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 1, 1); - } - - static constexpr auto denorm_min() noexcept - { - return this_float_t::from_bits(false, 0, 1); - } - - static constexpr bool is_iec559 = false; - static constexpr bool is_bounded = false; - static constexpr bool is_modulo = false; - - static constexpr bool traps = false; - static constexpr bool tinyness_before = false; - static constexpr float_round_style round_style = round_to_nearest; -}; - -} // namespace std - -#endif // TOSA_FLOAT_UTILS_H_ diff --git a/include/numpy_utils.h b/include/numpy_utils.h index 60cf77e..ade2f2d 100644 --- a/include/numpy_utils.h +++ b/include/numpy_utils.h @@ -24,8 +24,13 @@ #include <cstring> #include <vector> +#include "cfloat.h" #include "half.hpp" +using bf16 = ct::cfloat<int16_t, 8, true, true, true>; +using fp8e4m3 = ct::cfloat<int8_t, 4, true, true, false>; +using fp8e5m2 = ct::cfloat<int8_t, 5, true, true, true>; + class NumpyUtilities { public: @@ -85,6 +90,18 @@ public: { return "'<f2'"; } + if (std::is_same<T, bf16>::value) + { + return "'<V2'"; + } + if (std::is_same<T, fp8e4m3>::value) + { + return "'<V1'"; + } + if (std::is_same<T, fp8e5m2>::value) + { + return "'<f1'"; + } assert(false && "unsupported Dtype"); }; diff --git a/include/tosa_generated.h b/include/tosa_generated.h index 0798256..c907c89 100644 --- a/include/tosa_generated.h +++ b/include/tosa_generated.h @@ -8,9 
+8,9 @@ // Ensure the included flatbuffers.h is the same version as when this file was // generated, otherwise it may not be compatible. -static_assert(FLATBUFFERS_VERSION_MAJOR == 23 && - FLATBUFFERS_VERSION_MINOR == 5 && - FLATBUFFERS_VERSION_REVISION == 26, +static_assert(FLATBUFFERS_VERSION_MAJOR == 24 && + FLATBUFFERS_VERSION_MINOR == 3 && + FLATBUFFERS_VERSION_REVISION == 7, "Non-compatible flatbuffers version included"); namespace tosa { @@ -883,11 +883,10 @@ struct TransposeConvAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::T enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_OUT_PAD = 4, VT_STRIDE = 6, - VT_OUTPUT_SHAPE = 8, - VT_INPUT_ZP = 10, - VT_WEIGHT_ZP = 12, - VT_LOCAL_BOUND = 14, - VT_ACC_TYPE = 16 + VT_INPUT_ZP = 8, + VT_WEIGHT_ZP = 10, + VT_LOCAL_BOUND = 12, + VT_ACC_TYPE = 14 }; const ::flatbuffers::Vector<int32_t> *out_pad() const { return GetPointer<const ::flatbuffers::Vector<int32_t> *>(VT_OUT_PAD); @@ -895,9 +894,6 @@ struct TransposeConvAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::T const ::flatbuffers::Vector<int32_t> *stride() const { return GetPointer<const ::flatbuffers::Vector<int32_t> *>(VT_STRIDE); } - const ::flatbuffers::Vector<int32_t> *output_shape() const { - return GetPointer<const ::flatbuffers::Vector<int32_t> *>(VT_OUTPUT_SHAPE); - } int32_t input_zp() const { return GetField<int32_t>(VT_INPUT_ZP, 0); } @@ -916,8 +912,6 @@ struct TransposeConvAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::T verifier.VerifyVector(out_pad()) && VerifyOffset(verifier, VT_STRIDE) && verifier.VerifyVector(stride()) && - VerifyOffset(verifier, VT_OUTPUT_SHAPE) && - verifier.VerifyVector(output_shape()) && VerifyField<int32_t>(verifier, VT_INPUT_ZP, 4) && VerifyField<int32_t>(verifier, VT_WEIGHT_ZP, 4) && VerifyField<uint8_t>(verifier, VT_LOCAL_BOUND, 1) && @@ -936,9 +930,6 @@ struct TransposeConvAttributeBuilder { void add_stride(::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> stride) 
{ fbb_.AddOffset(TransposeConvAttribute::VT_STRIDE, stride); } - void add_output_shape(::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> output_shape) { - fbb_.AddOffset(TransposeConvAttribute::VT_OUTPUT_SHAPE, output_shape); - } void add_input_zp(int32_t input_zp) { fbb_.AddElement<int32_t>(TransposeConvAttribute::VT_INPUT_ZP, input_zp, 0); } @@ -966,7 +957,6 @@ inline ::flatbuffers::Offset<TransposeConvAttribute> CreateTransposeConvAttribut ::flatbuffers::FlatBufferBuilder &_fbb, ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> out_pad = 0, ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> stride = 0, - ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> output_shape = 0, int32_t input_zp = 0, int32_t weight_zp = 0, bool local_bound = false, @@ -975,7 +965,6 @@ inline ::flatbuffers::Offset<TransposeConvAttribute> CreateTransposeConvAttribut builder_.add_acc_type(acc_type); builder_.add_weight_zp(weight_zp); builder_.add_input_zp(input_zp); - builder_.add_output_shape(output_shape); builder_.add_stride(stride); builder_.add_out_pad(out_pad); builder_.add_local_bound(local_bound); @@ -986,19 +975,16 @@ inline ::flatbuffers::Offset<TransposeConvAttribute> CreateTransposeConvAttribut ::flatbuffers::FlatBufferBuilder &_fbb, const std::vector<int32_t> *out_pad = nullptr, const std::vector<int32_t> *stride = nullptr, - const std::vector<int32_t> *output_shape = nullptr, int32_t input_zp = 0, int32_t weight_zp = 0, bool local_bound = false, tosa::DType acc_type = tosa::DType_UNKNOWN) { auto out_pad__ = out_pad ? _fbb.CreateVector<int32_t>(*out_pad) : 0; auto stride__ = stride ? _fbb.CreateVector<int32_t>(*stride) : 0; - auto output_shape__ = output_shape ? 
_fbb.CreateVector<int32_t>(*output_shape) : 0; return tosa::CreateTransposeConvAttribute( _fbb, out_pad__, stride__, - output_shape__, input_zp, weight_zp, local_bound, @@ -1008,20 +994,15 @@ inline ::flatbuffers::Offset<TransposeConvAttribute> CreateTransposeConvAttribut struct PadAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { typedef PadAttributeBuilder Builder; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { - VT_PAD_CONST = 4, - VT_TYPE = 6 + VT_PAD_CONST = 4 }; const ::flatbuffers::Vector<uint8_t> *pad_const() const { return GetPointer<const ::flatbuffers::Vector<uint8_t> *>(VT_PAD_CONST); } - tosa::DType type() const { - return static_cast<tosa::DType>(GetField<uint32_t>(VT_TYPE, 0)); - } bool Verify(::flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_PAD_CONST) && verifier.VerifyVector(pad_const()) && - VerifyField<uint32_t>(verifier, VT_TYPE, 4) && verifier.EndTable(); } }; @@ -1033,9 +1014,6 @@ struct PadAttributeBuilder { void add_pad_const(::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> pad_const) { fbb_.AddOffset(PadAttribute::VT_PAD_CONST, pad_const); } - void add_type(tosa::DType type) { - fbb_.AddElement<uint32_t>(PadAttribute::VT_TYPE, static_cast<uint32_t>(type), 0); - } explicit PadAttributeBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -1049,24 +1027,20 @@ struct PadAttributeBuilder { inline ::flatbuffers::Offset<PadAttribute> CreatePadAttribute( ::flatbuffers::FlatBufferBuilder &_fbb, - ::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> pad_const = 0, - tosa::DType type = tosa::DType_UNKNOWN) { + ::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> pad_const = 0) { PadAttributeBuilder builder_(_fbb); - builder_.add_type(type); builder_.add_pad_const(pad_const); return builder_.Finish(); } inline ::flatbuffers::Offset<PadAttribute> CreatePadAttributeDirect( ::flatbuffers::FlatBufferBuilder &_fbb, - 
const std::vector<uint8_t> *pad_const = nullptr, - tosa::DType type = tosa::DType_UNKNOWN) { + const std::vector<uint8_t> *pad_const = nullptr) { if (pad_const) { _fbb.ForceVectorAlignment(pad_const->size(), sizeof(uint8_t), 8); } auto pad_const__ = pad_const ? _fbb.CreateVector<uint8_t>(*pad_const) : 0; return tosa::CreatePadAttribute( _fbb, - pad_const__, - type); + pad_const__); } struct AxisAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { @@ -1205,8 +1179,7 @@ struct ClampAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { typedef ClampAttributeBuilder Builder; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_MIN_VAL = 4, - VT_MAX_VAL = 6, - VT_TYPE = 8 + VT_MAX_VAL = 6 }; const ::flatbuffers::Vector<uint8_t> *min_val() const { return GetPointer<const ::flatbuffers::Vector<uint8_t> *>(VT_MIN_VAL); @@ -1214,16 +1187,12 @@ struct ClampAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { const ::flatbuffers::Vector<uint8_t> *max_val() const { return GetPointer<const ::flatbuffers::Vector<uint8_t> *>(VT_MAX_VAL); } - tosa::DType type() const { - return static_cast<tosa::DType>(GetField<uint32_t>(VT_TYPE, 0)); - } bool Verify(::flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_MIN_VAL) && verifier.VerifyVector(min_val()) && VerifyOffset(verifier, VT_MAX_VAL) && verifier.VerifyVector(max_val()) && - VerifyField<uint32_t>(verifier, VT_TYPE, 4) && verifier.EndTable(); } }; @@ -1238,9 +1207,6 @@ struct ClampAttributeBuilder { void add_max_val(::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> max_val) { fbb_.AddOffset(ClampAttribute::VT_MAX_VAL, max_val); } - void add_type(tosa::DType type) { - fbb_.AddElement<uint32_t>(ClampAttribute::VT_TYPE, static_cast<uint32_t>(type), 0); - } explicit ClampAttributeBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -1255,10 +1221,8 @@ struct ClampAttributeBuilder { 
inline ::flatbuffers::Offset<ClampAttribute> CreateClampAttribute( ::flatbuffers::FlatBufferBuilder &_fbb, ::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> min_val = 0, - ::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> max_val = 0, - tosa::DType type = tosa::DType_UNKNOWN) { + ::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> max_val = 0) { ClampAttributeBuilder builder_(_fbb); - builder_.add_type(type); builder_.add_max_val(max_val); builder_.add_min_val(min_val); return builder_.Finish(); @@ -1267,8 +1231,7 @@ inline ::flatbuffers::Offset<ClampAttribute> CreateClampAttribute( inline ::flatbuffers::Offset<ClampAttribute> CreateClampAttributeDirect( ::flatbuffers::FlatBufferBuilder &_fbb, const std::vector<uint8_t> *min_val = nullptr, - const std::vector<uint8_t> *max_val = nullptr, - tosa::DType type = tosa::DType_UNKNOWN) { + const std::vector<uint8_t> *max_val = nullptr) { if (min_val) { _fbb.ForceVectorAlignment(min_val->size(), sizeof(uint8_t), 8); } auto min_val__ = min_val ? 
_fbb.CreateVector<uint8_t>(*min_val) : 0; if (max_val) { _fbb.ForceVectorAlignment(max_val->size(), sizeof(uint8_t), 8); } @@ -1276,8 +1239,7 @@ inline ::flatbuffers::Offset<ClampAttribute> CreateClampAttributeDirect( return tosa::CreateClampAttribute( _fbb, min_val__, - max_val__, - type); + max_val__); } struct RescaleAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h index f5f9e58..c09a47d 100644 --- a/include/tosa_serialization_handler.h +++ b/include/tosa_serialization_handler.h @@ -16,9 +16,9 @@ #ifndef _TOSA_SERIALIZATION_HANDLER_H #define _TOSA_SERIALIZATION_HANDLER_H #include "attribute.h" +#include "cfloat.h" #include "flatbuffers/idl.h" #include "flatbuffers/util.h" -#include "float_utils.h" #include "numpy_utils.h" #include "tosa_generated.h" #include <cstdint> @@ -27,8 +27,8 @@ #include <vector> // Keep version number in sync with the version default value with schema/tosa.fbs -#define TOSA_VERSION_MAJOR 0 -#define TOSA_VERSION_MINOR 100 +#define TOSA_VERSION_MAJOR 1 +#define TOSA_VERSION_MINOR 1 #define TOSA_VERSION_PATCH 0 #define TOSA_VERSION_DRAFT true #define TENSOR_BUFFER_FORCE_ALIGNMENT 8 @@ -412,9 +412,9 @@ public: tosa_err_t LoadFileSchema(const char* schema_filename); // data format conversion. little-endian. 
- static tosa_err_t ConvertBF16toU8(const std::vector<float>& in, std::vector<uint8_t>& out); - static tosa_err_t ConvertFP8E4M3toU8(const std::vector<float>& in, std::vector<uint8_t>& out); - static tosa_err_t ConvertFP8E5M2toU8(const std::vector<float>& in, std::vector<uint8_t>& out); + static tosa_err_t ConvertBF16toU8(const std::vector<bf16>& in, std::vector<uint8_t>& out); + static tosa_err_t ConvertFP8E4M3toU8(const std::vector<fp8e4m3>& in, std::vector<uint8_t>& out); + static tosa_err_t ConvertFP8E5M2toU8(const std::vector<fp8e5m2>& in, std::vector<uint8_t>& out); static tosa_err_t ConvertF16toU8(const std::vector<float>& in, std::vector<uint8_t>& out); static tosa_err_t ConvertF32toU8(const std::vector<float>& in, std::vector<uint8_t>& out); static tosa_err_t ConvertI64toU8(const std::vector<int64_t>& in, std::vector<uint8_t>& out); @@ -425,9 +425,9 @@ public: static tosa_err_t ConvertI4toU8(const std::vector<int8_t>& in, std::vector<uint8_t>& out); static tosa_err_t ConvertBooltoU8(const std::vector<bool>& in, std::vector<uint8_t>& out); - static tosa_err_t ConvertU8toBF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out); - static tosa_err_t ConvertU8toFP8E4M3(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out); - static tosa_err_t ConvertU8toFP8E5M2(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out); + static tosa_err_t ConvertU8toBF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<bf16>& out); + static tosa_err_t ConvertU8toFP8E4M3(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<fp8e4m3>& out); + static tosa_err_t ConvertU8toFP8E5M2(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<fp8e5m2>& out); static tosa_err_t ConvertU8toF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<half_float::half>& out); static tosa_err_t ConvertU8toF32(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out); diff --git 
a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py index 298907e..34178c5 100644 --- a/python/serializer/tosa_serializer.py +++ b/python/serializer/tosa_serializer.py @@ -17,7 +17,7 @@ import serializer.tosa_serializer as ts import json import flatbuffers import numpy as np -import struct +from ml_dtypes import bfloat16, float8_e4m3fn, float8_e5m2 from enum import IntEnum, unique from tosa import ( TosaGraph, @@ -31,8 +31,8 @@ import tosa.DType as TosaDType import tosa.Op as TosaOp # Keep version number in sync with the version default value with schema/tosa.fbs -TOSA_VERSION_MAJOR = 0 -TOSA_VERSION_MINOR = 100 +TOSA_VERSION_MAJOR = 1 +TOSA_VERSION_MINOR = 1 TOSA_VERSION_PATCH = 0 TOSA_VERSION_DRAFT = True TOSA_VERSION = [ @@ -190,7 +190,7 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.ints.append((a.AddAccType, acc_type)) def TransposeConvAttribute( - self, outpad, stride, output_shape, input_zp, weight_zp, local_bound, acc_type + self, outpad, stride, input_zp, weight_zp, local_bound, acc_type ): from tosa import TransposeConvAttribute as a, Attribute @@ -199,13 +199,12 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.intvecs.append((a.AddOutPad, outpad)) self.intvecs.append((a.AddStride, stride)) - self.intvecs.append((a.AddOutputShape, output_shape)) self.ints.append((a.AddInputZp, input_zp)) self.ints.append((a.AddWeightZp, weight_zp)) self.bools.append((a.AddLocalBound, local_bound)) self.ints.append((a.AddAccType, acc_type)) - def PadAttribute(self, serializer_builder, pad_const_val_as_bytes, dtype): + def PadAttribute(self, serializer_builder, pad_const_val_as_bytes): from tosa import PadAttribute as a, Attribute self.utype = Attribute.Attribute().PadAttribute @@ -217,7 +216,6 @@ class TosaSerializerAttribute(TosaSerializerUnion): ) self.floats.append((a.AddPadConst, serialized_pad_const_val)) - self.ints.append((a.AddType, dtype)) def AxisAttribute(self, axis): from tosa import AxisAttribute as a, 
Attribute @@ -238,9 +236,7 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.int16vecs.append((a.AddBorder, border)) self.ints.append((a.AddMode, mode)) - def ClampAttribute( - self, serializer_builder, min_val_as_bytes, max_val_as_bytes, dtype - ): + def ClampAttribute(self, serializer_builder, min_val_as_bytes, max_val_as_bytes): from tosa import ClampAttribute as a, Attribute self.utype = Attribute.Attribute().ClampAttribute @@ -256,7 +252,6 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.floats.append((a.AddMinVal, serialized_min_val)) self.floats.append((a.AddMaxVal, serialized_max_val)) - self.ints.append((a.AddType, dtype)) def RescaleAttribute( self, @@ -397,13 +392,14 @@ class TosaSerializerTensor: self.shape = shape self.dtype = dtype - if ( - dtype == DType.FP32 - or dtype == DType.BF16 - or dtype == DType.FP8E4M3 - or dtype == DType.FP8E5M2 - ): + if dtype == DType.FP32: fntype = np.float32 + elif dtype == DType.BF16: + fntype = bfloat16 + elif dtype == DType.FP8E4M3: + fntype = float8_e4m3fn + elif dtype == DType.FP8E5M2: + fntype = float8_e5m2 elif dtype == DType.FP16: fntype = np.float16 else: @@ -948,35 +944,19 @@ class TosaSerializer: np_arr = np.array(data, dtype=np.float16) u8_data.extend(np_arr.view(np.uint8)) elif dtype == DType.FP32: - # for val in data: - # b = struct.pack("!f", val) - # u8_data.extend([b[3], b[2], b[1], b[0]]) np_arr = np.array(data, dtype=np.float32) u8_data.extend(np_arr.view(np.uint8)) elif dtype == DType.BF16: - for val in data: - # convert val to little endian byte arrays b - b = struct.pack("<f", val) - # val => [ b[3], b[2], b[1], b[0] ] - # keep only most significant 2 bytes for bf16 - # in little endian ordering - u8_data.extend([b[2], b[3]]) + np_arr = np.array(data, dtype=bfloat16) + u8_data.extend(np_arr.view(np.uint8)) elif dtype == DType.FP8E4M3: for val in data: - # convert val to fp8_bits then to single byte - f32_as_int = struct.unpack(">L", struct.pack(">f", val))[0] - f32_bits = 
f"{f32_as_int:032b}" - fp8_bits = f32_bits[0] + f32_bits[1:5] + f32_bits[9:12] - fp8_bytes = int(fp8_bits, 2).to_bytes(1, byteorder="little") - u8_data.extend(fp8_bytes) + val_f8 = np.array(val).astype(float8_e4m3fn).view(np.uint8) + u8_data.append(val_f8) elif dtype == DType.FP8E5M2: for val in data: - # convert val to fp8_bits then to single byte - f32_as_int = struct.unpack(">L", struct.pack(">f", val))[0] - f32_bits = f"{f32_as_int:032b}" - fp8_bits = f32_bits[0] + f32_bits[1:6] + f32_bits[9:11] - fp8_bytes = int(fp8_bits, 2).to_bytes(1, byteorder="little") - u8_data.extend(fp8_bytes) + val_f8 = np.array(val).astype(float8_e5m2).view(np.uint8) + u8_data.append(val_f8) elif dtype == TosaDType.DType: # Serialize DType enum data as uint8 bytes for val in data: diff --git a/python/tosa/ClampAttribute.py b/python/tosa/ClampAttribute.py index 1189acb..40254ec 100644 --- a/python/tosa/ClampAttribute.py +++ b/python/tosa/ClampAttribute.py @@ -82,15 +82,8 @@ class ClampAttribute(object): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) return o == 0 - # ClampAttribute - def Type(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) - if o != 0: - return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos) - return 0 - def ClampAttributeStart(builder): - builder.StartObject(3) + builder.StartObject(2) def Start(builder): ClampAttributeStart(builder) @@ -104,7 +97,7 @@ def AddMinVal(builder, minVal): def ClampAttributeStartMinValVector(builder, numElems): return builder.StartVector(1, numElems, 1) -def StartMinValVector(builder, numElems: int) -> int: +def StartMinValVector(builder, numElems): return ClampAttributeStartMinValVector(builder, numElems) def ClampAttributeAddMaxVal(builder, maxVal): @@ -116,15 +109,9 @@ def AddMaxVal(builder, maxVal): def ClampAttributeStartMaxValVector(builder, numElems): return builder.StartVector(1, numElems, 1) -def StartMaxValVector(builder, numElems: int) -> int: +def 
StartMaxValVector(builder, numElems): return ClampAttributeStartMaxValVector(builder, numElems) -def ClampAttributeAddType(builder, type): - builder.PrependUint32Slot(2, type, 0) - -def AddType(builder, type): - ClampAttributeAddType(builder, type) - def ClampAttributeEnd(builder): return builder.EndObject() diff --git a/python/tosa/ConvAttribute.py b/python/tosa/ConvAttribute.py index dfa75dc..1deca59 100644 --- a/python/tosa/ConvAttribute.py +++ b/python/tosa/ConvAttribute.py @@ -152,7 +152,7 @@ def AddPad(builder, pad): def ConvAttributeStartPadVector(builder, numElems): return builder.StartVector(4, numElems, 4) -def StartPadVector(builder, numElems: int) -> int: +def StartPadVector(builder, numElems): return ConvAttributeStartPadVector(builder, numElems) def ConvAttributeAddStride(builder, stride): @@ -164,7 +164,7 @@ def AddStride(builder, stride): def ConvAttributeStartStrideVector(builder, numElems): return builder.StartVector(4, numElems, 4) -def StartStrideVector(builder, numElems: int) -> int: +def StartStrideVector(builder, numElems): return ConvAttributeStartStrideVector(builder, numElems) def ConvAttributeAddDilation(builder, dilation): @@ -176,7 +176,7 @@ def AddDilation(builder, dilation): def ConvAttributeStartDilationVector(builder, numElems): return builder.StartVector(4, numElems, 4) -def StartDilationVector(builder, numElems: int) -> int: +def StartDilationVector(builder, numElems): return ConvAttributeStartDilationVector(builder, numElems) def ConvAttributeAddInputZp(builder, inputZp): diff --git a/python/tosa/CustomAttribute.py b/python/tosa/CustomAttribute.py index db35dca..4c1c477 100644 --- a/python/tosa/CustomAttribute.py +++ b/python/tosa/CustomAttribute.py @@ -96,7 +96,7 @@ def AddImplementationAttrs(builder, implementationAttrs): def CustomAttributeStartImplementationAttrsVector(builder, numElems): return builder.StartVector(1, numElems, 1) -def StartImplementationAttrsVector(builder, numElems: int) -> int: +def 
StartImplementationAttrsVector(builder, numElems): return CustomAttributeStartImplementationAttrsVector(builder, numElems) def CustomAttributeEnd(builder): diff --git a/python/tosa/PadAttribute.py b/python/tosa/PadAttribute.py index c4084dc..8adf9f7 100644 --- a/python/tosa/PadAttribute.py +++ b/python/tosa/PadAttribute.py @@ -55,15 +55,8 @@ class PadAttribute(object): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) return o == 0 - # PadAttribute - def Type(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) - if o != 0: - return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos) - return 0 - def PadAttributeStart(builder): - builder.StartObject(2) + builder.StartObject(1) def Start(builder): PadAttributeStart(builder) @@ -77,15 +70,9 @@ def AddPadConst(builder, padConst): def PadAttributeStartPadConstVector(builder, numElems): return builder.StartVector(1, numElems, 1) -def StartPadConstVector(builder, numElems: int) -> int: +def StartPadConstVector(builder, numElems): return PadAttributeStartPadConstVector(builder, numElems) -def PadAttributeAddType(builder, type): - builder.PrependUint32Slot(1, type, 0) - -def AddType(builder, type): - PadAttributeAddType(builder, type) - def PadAttributeEnd(builder): return builder.EndObject() diff --git a/python/tosa/PoolAttribute.py b/python/tosa/PoolAttribute.py index c13e038..831d43b 100644 --- a/python/tosa/PoolAttribute.py +++ b/python/tosa/PoolAttribute.py @@ -145,7 +145,7 @@ def AddPad(builder, pad): def PoolAttributeStartPadVector(builder, numElems): return builder.StartVector(4, numElems, 4) -def StartPadVector(builder, numElems: int) -> int: +def StartPadVector(builder, numElems): return PoolAttributeStartPadVector(builder, numElems) def PoolAttributeAddKernel(builder, kernel): @@ -157,7 +157,7 @@ def AddKernel(builder, kernel): def PoolAttributeStartKernelVector(builder, numElems): return builder.StartVector(4, numElems, 4) -def 
StartKernelVector(builder, numElems: int) -> int: +def StartKernelVector(builder, numElems): return PoolAttributeStartKernelVector(builder, numElems) def PoolAttributeAddStride(builder, stride): @@ -169,7 +169,7 @@ def AddStride(builder, stride): def PoolAttributeStartStrideVector(builder, numElems): return builder.StartVector(4, numElems, 4) -def StartStrideVector(builder, numElems: int) -> int: +def StartStrideVector(builder, numElems): return PoolAttributeStartStrideVector(builder, numElems) def PoolAttributeAddInputZp(builder, inputZp): diff --git a/python/tosa/ResizeAttribute.py b/python/tosa/ResizeAttribute.py index 96bfa56..44f7d31 100644 --- a/python/tosa/ResizeAttribute.py +++ b/python/tosa/ResizeAttribute.py @@ -131,7 +131,7 @@ def AddScale(builder, scale): def ResizeAttributeStartScaleVector(builder, numElems): return builder.StartVector(2, numElems, 2) -def StartScaleVector(builder, numElems: int) -> int: +def StartScaleVector(builder, numElems): return ResizeAttributeStartScaleVector(builder, numElems) def ResizeAttributeAddOffset(builder, offset): @@ -143,7 +143,7 @@ def AddOffset(builder, offset): def ResizeAttributeStartOffsetVector(builder, numElems): return builder.StartVector(2, numElems, 2) -def StartOffsetVector(builder, numElems: int) -> int: +def StartOffsetVector(builder, numElems): return ResizeAttributeStartOffsetVector(builder, numElems) def ResizeAttributeAddBorder(builder, border): @@ -155,7 +155,7 @@ def AddBorder(builder, border): def ResizeAttributeStartBorderVector(builder, numElems): return builder.StartVector(2, numElems, 2) -def StartBorderVector(builder, numElems: int) -> int: +def StartBorderVector(builder, numElems): return ResizeAttributeStartBorderVector(builder, numElems) def ResizeAttributeAddMode(builder, mode): diff --git a/python/tosa/TableAttribute.py b/python/tosa/TableAttribute.py index 6caa1f2..04193fa 100644 --- a/python/tosa/TableAttribute.py +++ b/python/tosa/TableAttribute.py @@ -70,7 +70,7 @@ def 
AddTable(builder, table): def TableAttributeStartTableVector(builder, numElems): return builder.StartVector(2, numElems, 2) -def StartTableVector(builder, numElems: int) -> int: +def StartTableVector(builder, numElems): return TableAttributeStartTableVector(builder, numElems) def TableAttributeEnd(builder): diff --git a/python/tosa/TosaBasicBlock.py b/python/tosa/TosaBasicBlock.py index b31f455..30ad0ee 100644 --- a/python/tosa/TosaBasicBlock.py +++ b/python/tosa/TosaBasicBlock.py @@ -146,7 +146,7 @@ def AddOperators(builder, operators): def TosaBasicBlockStartOperatorsVector(builder, numElems): return builder.StartVector(4, numElems, 4) -def StartOperatorsVector(builder, numElems: int) -> int: +def StartOperatorsVector(builder, numElems): return TosaBasicBlockStartOperatorsVector(builder, numElems) def TosaBasicBlockAddTensors(builder, tensors): @@ -158,7 +158,7 @@ def AddTensors(builder, tensors): def TosaBasicBlockStartTensorsVector(builder, numElems): return builder.StartVector(4, numElems, 4) -def StartTensorsVector(builder, numElems: int) -> int: +def StartTensorsVector(builder, numElems): return TosaBasicBlockStartTensorsVector(builder, numElems) def TosaBasicBlockAddInputs(builder, inputs): @@ -170,7 +170,7 @@ def AddInputs(builder, inputs): def TosaBasicBlockStartInputsVector(builder, numElems): return builder.StartVector(4, numElems, 4) -def StartInputsVector(builder, numElems: int) -> int: +def StartInputsVector(builder, numElems): return TosaBasicBlockStartInputsVector(builder, numElems) def TosaBasicBlockAddOutputs(builder, outputs): @@ -182,7 +182,7 @@ def AddOutputs(builder, outputs): def TosaBasicBlockStartOutputsVector(builder, numElems): return builder.StartVector(4, numElems, 4) -def StartOutputsVector(builder, numElems: int) -> int: +def StartOutputsVector(builder, numElems): return TosaBasicBlockStartOutputsVector(builder, numElems) def TosaBasicBlockEnd(builder): diff --git a/python/tosa/TosaGraph.py b/python/tosa/TosaGraph.py index 
84b51a7..520372b 100644 --- a/python/tosa/TosaGraph.py +++ b/python/tosa/TosaGraph.py @@ -85,7 +85,7 @@ def AddRegions(builder, regions): def TosaGraphStartRegionsVector(builder, numElems): return builder.StartVector(4, numElems, 4) -def StartRegionsVector(builder, numElems: int) -> int: +def StartRegionsVector(builder, numElems): return TosaGraphStartRegionsVector(builder, numElems) def TosaGraphEnd(builder): diff --git a/python/tosa/TosaOperator.py b/python/tosa/TosaOperator.py index 2b889ad..19f2d2c 100644 --- a/python/tosa/TosaOperator.py +++ b/python/tosa/TosaOperator.py @@ -125,7 +125,7 @@ def AddInputs(builder, inputs): def TosaOperatorStartInputsVector(builder, numElems): return builder.StartVector(4, numElems, 4) -def StartInputsVector(builder, numElems: int) -> int: +def StartInputsVector(builder, numElems): return TosaOperatorStartInputsVector(builder, numElems) def TosaOperatorAddOutputs(builder, outputs): @@ -137,7 +137,7 @@ def AddOutputs(builder, outputs): def TosaOperatorStartOutputsVector(builder, numElems): return builder.StartVector(4, numElems, 4) -def StartOutputsVector(builder, numElems: int) -> int: +def StartOutputsVector(builder, numElems): return TosaOperatorStartOutputsVector(builder, numElems) def TosaOperatorEnd(builder): diff --git a/python/tosa/TosaRegion.py b/python/tosa/TosaRegion.py index 7fd6e3c..80829da 100644 --- a/python/tosa/TosaRegion.py +++ b/python/tosa/TosaRegion.py @@ -81,7 +81,7 @@ def AddBlocks(builder, blocks): def TosaRegionStartBlocksVector(builder, numElems): return builder.StartVector(4, numElems, 4) -def StartBlocksVector(builder, numElems: int) -> int: +def StartBlocksVector(builder, numElems): return TosaRegionStartBlocksVector(builder, numElems) def TosaRegionEnd(builder): diff --git a/python/tosa/TosaTensor.py b/python/tosa/TosaTensor.py index 3fb9f86..1311aac 100644 --- a/python/tosa/TosaTensor.py +++ b/python/tosa/TosaTensor.py @@ -138,7 +138,7 @@ def AddShape(builder, shape): def 
TosaTensorStartShapeVector(builder, numElems): return builder.StartVector(4, numElems, 4) -def StartShapeVector(builder, numElems: int) -> int: +def StartShapeVector(builder, numElems): return TosaTensorStartShapeVector(builder, numElems) def TosaTensorAddType(builder, type): @@ -156,7 +156,7 @@ def AddData(builder, data): def TosaTensorStartDataVector(builder, numElems): return builder.StartVector(1, numElems, 1) -def StartDataVector(builder, numElems: int) -> int: +def StartDataVector(builder, numElems): return TosaTensorStartDataVector(builder, numElems) def TosaTensorAddVariable(builder, variable): diff --git a/python/tosa/TransposeAttribute.py b/python/tosa/TransposeAttribute.py index 71cfdf0..5aa23e2 100644 --- a/python/tosa/TransposeAttribute.py +++ b/python/tosa/TransposeAttribute.py @@ -70,7 +70,7 @@ def AddPerms(builder, perms): def TransposeAttributeStartPermsVector(builder, numElems): return builder.StartVector(4, numElems, 4) -def StartPermsVector(builder, numElems: int) -> int: +def StartPermsVector(builder, numElems): return TransposeAttributeStartPermsVector(builder, numElems) def TransposeAttributeEnd(builder): diff --git a/python/tosa/TransposeConvAttribute.py b/python/tosa/TransposeConvAttribute.py index e5397a8..2f7cdc7 100644 --- a/python/tosa/TransposeConvAttribute.py +++ b/python/tosa/TransposeConvAttribute.py @@ -83,62 +83,35 @@ class TransposeConvAttribute(object): return o == 0 # TransposeConvAttribute - def OutputShape(self, j): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) - if o != 0: - a = self._tab.Vector(o) - return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) - return 0 - - # TransposeConvAttribute - def OutputShapeAsNumpy(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) - if o != 0: - return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o) - return 0 - - # TransposeConvAttribute - def 
OutputShapeLength(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) - if o != 0: - return self._tab.VectorLen(o) - return 0 - - # TransposeConvAttribute - def OutputShapeIsNone(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) - return o == 0 - - # TransposeConvAttribute def InputZp(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) if o != 0: return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) return 0 # TransposeConvAttribute def WeightZp(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) if o != 0: return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) return 0 # TransposeConvAttribute def LocalBound(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) if o != 0: return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) return False # TransposeConvAttribute def AccType(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16)) + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) if o != 0: return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos) return 0 def TransposeConvAttributeStart(builder): - builder.StartObject(7) + builder.StartObject(6) def Start(builder): TransposeConvAttributeStart(builder) @@ -152,7 +125,7 @@ def AddOutPad(builder, outPad): def TransposeConvAttributeStartOutPadVector(builder, numElems): return builder.StartVector(4, numElems, 4) -def StartOutPadVector(builder, numElems: int) -> int: +def StartOutPadVector(builder, numElems): return TransposeConvAttributeStartOutPadVector(builder, numElems) def 
TransposeConvAttributeAddStride(builder, stride): @@ -164,41 +137,29 @@ def AddStride(builder, stride): def TransposeConvAttributeStartStrideVector(builder, numElems): return builder.StartVector(4, numElems, 4) -def StartStrideVector(builder, numElems: int) -> int: +def StartStrideVector(builder, numElems): return TransposeConvAttributeStartStrideVector(builder, numElems) -def TransposeConvAttributeAddOutputShape(builder, outputShape): - builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(outputShape), 0) - -def AddOutputShape(builder, outputShape): - TransposeConvAttributeAddOutputShape(builder, outputShape) - -def TransposeConvAttributeStartOutputShapeVector(builder, numElems): - return builder.StartVector(4, numElems, 4) - -def StartOutputShapeVector(builder, numElems: int) -> int: - return TransposeConvAttributeStartOutputShapeVector(builder, numElems) - def TransposeConvAttributeAddInputZp(builder, inputZp): - builder.PrependInt32Slot(3, inputZp, 0) + builder.PrependInt32Slot(2, inputZp, 0) def AddInputZp(builder, inputZp): TransposeConvAttributeAddInputZp(builder, inputZp) def TransposeConvAttributeAddWeightZp(builder, weightZp): - builder.PrependInt32Slot(4, weightZp, 0) + builder.PrependInt32Slot(3, weightZp, 0) def AddWeightZp(builder, weightZp): TransposeConvAttributeAddWeightZp(builder, weightZp) def TransposeConvAttributeAddLocalBound(builder, localBound): - builder.PrependBoolSlot(5, localBound, 0) + builder.PrependBoolSlot(4, localBound, 0) def AddLocalBound(builder, localBound): TransposeConvAttributeAddLocalBound(builder, localBound) def TransposeConvAttributeAddAccType(builder, accType): - builder.PrependUint32Slot(6, accType, 0) + builder.PrependUint32Slot(5, accType, 0) def AddAccType(builder, accType): TransposeConvAttributeAddAccType(builder, accType) diff --git a/schema/tosa.fbs b/schema/tosa.fbs index 7b5948b..cad6db7 100644 --- a/schema/tosa.fbs +++ b/schema/tosa.fbs @@ -176,7 +176,6 @@ table ConvAttribute { 
table TransposeConvAttribute { out_pad: [int32]; stride: [int32]; - output_shape: [int32]; input_zp: int32; weight_zp: int32; local_bound: bool; @@ -185,7 +184,6 @@ table TransposeConvAttribute { table PadAttribute { pad_const: [ubyte] (force_align: 8); - type: DType; } table AxisAttribute { @@ -202,7 +200,6 @@ table ResizeAttribute { table ClampAttribute { min_val: [ubyte] (force_align: 8); max_val: [ubyte] (force_align: 8); - type: DType; } table RescaleAttribute { diff --git a/src/numpy_utils.cpp b/src/numpy_utils.cpp index e4171d7..7cf5f94 100644 --- a/src/numpy_utils.cpp +++ b/src/numpy_utils.cpp @@ -247,6 +247,14 @@ NumpyUtilities::NPError NumpyUtilities::checkNpyHeader(FILE* infile, const uint3 while (isspace(*ptr)) ptr++; + // ml_dtypes writes '<f1' for 'numpy.dtype' in the header for float8_e5m2, but + // default NumPy does not understand this notation, which causes trouble + // when other code tries to open this file. + // To avoid this, '|u1' notation is used when the file is written, and the uint8 + // data is viewed as float8_e5m2 later when the file is read. + if (!strcmp(dtype_str, "'<f1'")) + dtype_str = "'|u1'"; + if (strcmp(ptr, dtype_str)) { return FILE_TYPE_MISMATCH; @@ -430,6 +438,13 @@ NumpyUtilities::NPError memcpy(header, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1); headerPos += sizeof(NUMPY_HEADER_STR) - 1; + // NumPy does not understand float8_e5m2, so change it to uint8 type, so that + // Python can read .npy files. 
+ if (!strcmp(dtype_str, "'<f1'")) + { + dtype_str = "'|u1'"; + } + // Output the format dictionary // Hard-coded for I32 for now headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, @@ -438,7 +453,19 @@ NumpyUtilities::NPError // Add shape contents (if any - as this will be empty for rank 0) for (i = 0; i < shape.size(); i++) { - headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, " %d,", shape[i]); + // Output NumPy file from tosa_refmodel_sut_run generates the shape information + // without a trailing comma when the rank is greater than 1. + if (i == 0) + { + if (shape.size() == 1) + headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "%d,", shape[i]); + else + headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "%d", shape[i]); + } + else + { + headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, ", %d", shape[i]); + } } // Close off the dictionary diff --git a/src/tosa_serialization_handler.cpp b/src/tosa_serialization_handler.cpp index 85625cd..74f66d8 100644 --- a/src/tosa_serialization_handler.cpp +++ b/src/tosa_serialization_handler.cpp @@ -19,9 +19,6 @@ #include <iostream> using namespace tosa; -using fp8e4m3 = tosa::float_t<int8_t, 4, true, true, false>; -using fp8e5m2 = tosa::float_t<int8_t, 5, true, true, true>; - TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name, const flatbuffers::Vector<int32_t>* shape, DType dtype, @@ -750,45 +747,41 @@ void TosaSerializationHandler::ForceAlignTensorData(std::vector<uint8_t>& buf) } } -tosa_err_t TosaSerializationHandler::ConvertBF16toU8(const std::vector<float>& in, std::vector<uint8_t>& out) +tosa_err_t TosaSerializationHandler::ConvertBF16toU8(const std::vector<bf16>& in, std::vector<uint8_t>& out) { // Note: Converts fp32->bf16 by ignoring the least significant 16 bits out.clear(); for (auto val : in) { - uint32_t* val_u32 = reinterpret_cast<uint32_t*>(&val); - uint8_t f32_byte2 = 
(*val_u32 >> 16) & 0xFF; - uint8_t f32_byte3 = (*val_u32 >> 24) & 0xFF; - // little endian: byte2 followed by byte3 - out.push_back(f32_byte2); - out.push_back(f32_byte3); + uint8_t bf16_byte0 = val.bits() & 0xFF; + uint8_t bf16_byte1 = (val.bits() >> 8) & 0xFF; + out.push_back(bf16_byte0); + out.push_back(bf16_byte1); } ForceAlignTensorData(out); return TOSA_OK; } -tosa_err_t TosaSerializationHandler::ConvertFP8E4M3toU8(const std::vector<float>& in, std::vector<uint8_t>& out) +tosa_err_t TosaSerializationHandler::ConvertFP8E4M3toU8(const std::vector<fp8e4m3>& in, std::vector<uint8_t>& out) { // Note: Converts fp32->FP8E4M3 before converting to unint8_t out.clear(); for (auto val : in) { - auto f8 = static_cast<fp8e4m3>(val); - uint8_t b8 = f8.bits(); + uint8_t b8 = val.bits(); out.push_back(b8); } ForceAlignTensorData(out); return TOSA_OK; } -tosa_err_t TosaSerializationHandler::ConvertFP8E5M2toU8(const std::vector<float>& in, std::vector<uint8_t>& out) +tosa_err_t TosaSerializationHandler::ConvertFP8E5M2toU8(const std::vector<fp8e5m2>& in, std::vector<uint8_t>& out) { // Note: Converts fp32->FP8E5M2 before converting to uint8_t out.clear(); for (auto val : in) { - auto f8 = static_cast<fp8e5m2>(val); - uint8_t b8 = f8.bits(); + uint8_t b8 = val.bits(); out.push_back(b8); } ForceAlignTensorData(out); @@ -944,11 +937,9 @@ tosa_err_t TosaSerializationHandler::ConvertBooltoU8(const std::vector<bool>& in return TOSA_OK; } -tosa_err_t TosaSerializationHandler::ConvertU8toBF16(const std::vector<uint8_t>& in, - uint32_t out_size, - std::vector<float>& out) +tosa_err_t + TosaSerializationHandler::ConvertU8toBF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<bf16>& out) { - // Note: bf16 values returned in fp32 type out.clear(); if (in.size() < out_size * sizeof(int16_t)) { @@ -959,22 +950,21 @@ tosa_err_t TosaSerializationHandler::ConvertU8toBF16(const std::vector<uint8_t>& for (uint32_t i = 0; i < out_size; i++) { - uint32_t f32_byte2 = in[i * 
sizeof(int16_t)]; - uint32_t f32_byte3 = in[i * sizeof(int16_t) + 1]; - uint32_t val_u32 = (f32_byte2 << 16) + (f32_byte3 << 24); + uint8_t bf16_byte0 = in[i * sizeof(int16_t)]; + uint8_t bf16_byte1 = in[i * sizeof(int16_t) + 1]; + uint16_t val_u16 = (bf16_byte0) + (bf16_byte1 << 8); - // Reinterpret u32 bytes as fp32 - float val_f32 = *(float*)&val_u32; - out.push_back(val_f32); + // Reinterpret u16 bytes as bf16 + bf16 val_bf16 = *(bf16*)&val_u16; + out.push_back(val_bf16); } return TOSA_OK; } tosa_err_t TosaSerializationHandler::ConvertU8toFP8E4M3(const std::vector<uint8_t>& in, uint32_t out_size, - std::vector<float>& out) + std::vector<fp8e4m3>& out) { - // Note: FP8E4M3 values returned in fp32 type out.clear(); if (in.size() < out_size * sizeof(int8_t)) { @@ -985,17 +975,16 @@ tosa_err_t TosaSerializationHandler::ConvertU8toFP8E4M3(const std::vector<uint8_ for (uint32_t i = 0; i < out_size; i++) { - int8_t bits = static_cast<int8_t>(in[i * sizeof(int8_t)]); - auto f8 = fp8e4m3::from_bits(bits); - float val_f32 = static_cast<float>(f8); - out.push_back(val_f32); + int8_t bits = static_cast<int8_t>(in[i * sizeof(int8_t)]); + auto f8 = fp8e4m3::from_bits(bits); + out.push_back(f8); } return TOSA_OK; } tosa_err_t TosaSerializationHandler::ConvertU8toFP8E5M2(const std::vector<uint8_t>& in, uint32_t out_size, - std::vector<float>& out) + std::vector<fp8e5m2>& out) { // Note: FP8E5M2 values returned in fp32 type out.clear(); @@ -1008,10 +997,9 @@ tosa_err_t TosaSerializationHandler::ConvertU8toFP8E5M2(const std::vector<uint8_ for (uint32_t i = 0; i < out_size; i++) { - int8_t bits = static_cast<int8_t>(in[i * sizeof(int8_t)]); - auto f8 = fp8e5m2::from_bits(bits); - float val_f32 = static_cast<float>(f8); - out.push_back(val_f32); + int8_t bits = static_cast<int8_t>(in[i * sizeof(int8_t)]); + auto f8 = fp8e5m2::from_bits(bits); + out.push_back(f8); } return TOSA_OK; } @@ -1031,9 +1019,9 @@ tosa_err_t TosaSerializationHandler::ConvertU8toF16(const 
std::vector<uint8_t>& for (uint32_t i = 0; i < out_size; i++) { - uint16_t f16_byte0 = in[i * sizeof(int16_t)]; - uint16_t f16_byte1 = in[i * sizeof(int16_t) + 1]; - uint16_t val_u16 = f16_byte0 + (f16_byte1 << 8); + uint8_t f16_byte0 = in[i * sizeof(int16_t)]; + uint8_t f16_byte1 = in[i * sizeof(int16_t) + 1]; + uint16_t val_u16 = f16_byte0 + (f16_byte1 << 8); // Reinterpret u16 byte as fp16 then convert to fp32 half_float::half val_f16 = *(half_float::half*)&val_u16; diff --git a/third_party/flatbuffers b/third_party/flatbuffers -Subproject 0100f6a5779831fa7a651e4b67ef389a8752bd9 +Subproject 6ff9e90e7e399f3977e99a315856b57c8afe5b4 |