aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CMakeLists.txt2
-rw-r--r--include/attribute.def13
-rw-r--r--include/cfloat.h861
-rw-r--r--include/float_utils.h533
-rw-r--r--include/numpy_utils.h17
-rw-r--r--include/tosa_generated.h68
-rw-r--r--include/tosa_serialization_handler.h18
-rw-r--r--python/serializer/tosa_serializer.py58
-rw-r--r--python/tosa/ClampAttribute.py19
-rw-r--r--python/tosa/ConvAttribute.py6
-rw-r--r--python/tosa/CustomAttribute.py2
-rw-r--r--python/tosa/PadAttribute.py17
-rw-r--r--python/tosa/PoolAttribute.py6
-rw-r--r--python/tosa/ResizeAttribute.py6
-rw-r--r--python/tosa/TableAttribute.py2
-rw-r--r--python/tosa/TosaBasicBlock.py8
-rw-r--r--python/tosa/TosaGraph.py2
-rw-r--r--python/tosa/TosaOperator.py4
-rw-r--r--python/tosa/TosaRegion.py2
-rw-r--r--python/tosa/TosaTensor.py4
-rw-r--r--python/tosa/TransposeAttribute.py2
-rw-r--r--python/tosa/TransposeConvAttribute.py61
-rw-r--r--schema/tosa.fbs3
-rw-r--r--src/numpy_utils.cpp29
-rw-r--r--src/tosa_serialization_handler.cpp68
m---------third_party/flatbuffers0
26 files changed, 1021 insertions, 790 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 5f4f851..679603d 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -76,7 +76,7 @@ set(public_headers)
list(APPEND public_headers
include/attribute.h
include/attribute.def
- include/float_utils.h
+ include/cfloat.h
include/numpy_utils.h
include/tosa_generated.h
include/tosa_serialization_handler.h
diff --git a/include/attribute.def b/include/attribute.def
index 30b432d..0e97629 100644
--- a/include/attribute.def
+++ b/include/attribute.def
@@ -43,18 +43,16 @@ DEF_ATTRIBUTE(Conv, 7,
bool, S, local_bound,
DType, S, acc_type)
-DEF_ATTRIBUTE(TransposeConv, 7,
+DEF_ATTRIBUTE(TransposeConv, 6,
int32_t, V, out_pad,
int32_t, V, stride,
- int32_t, V, output_shape,
int32_t, S, input_zp,
int32_t, S, weight_zp,
bool, S, local_bound,
DType, S, acc_type)
-DEF_ATTRIBUTE(Pad, 2,
- uint8_t, V, pad_const,
- DType, S, type)
+DEF_ATTRIBUTE(Pad, 1,
+ uint8_t, V, pad_const)
DEF_ATTRIBUTE(Axis, 1,
int32_t, S, axis)
@@ -65,10 +63,9 @@ DEF_ATTRIBUTE(Resize, 4,
int16_t, V, border,
ResizeMode, S, mode)
-DEF_ATTRIBUTE(Clamp, 3,
+DEF_ATTRIBUTE(Clamp, 2,
uint8_t, V, min_val,
- uint8_t, V, max_val,
- DType, S, type)
+ uint8_t, V, max_val)
DEF_ATTRIBUTE(Rescale, 7,
int32_t, S, input_zp,
diff --git a/include/cfloat.h b/include/cfloat.h
new file mode 100644
index 0000000..cbbe09a
--- /dev/null
+++ b/include/cfloat.h
@@ -0,0 +1,861 @@
+// Copyright (c) 2022-2024, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef CT_CFLOAT_H
+#define CT_CFLOAT_H
+#include <algorithm>
+#include <cstdint>
+#include <cstring>
+#include <limits>
+#include <type_traits>
+#if defined(__cpp_lib_bit_cast)
+#include <bit>
+#endif // defined(__cpp_lib_bit_cast)
+
+namespace ct
+{
+/// \brief Bitfield specification of the features provided of a specified
+/// floating point type.
+enum class FloatFeatures
+{
+ None = 0x0,
+ HasNaN = 0x1, ///< The type can represent NaN values
+ HasInf = 0x2, ///< The type can represent Infinity
+ HasDenorms = 0x4, ///< The type can represent denormal/subnormal values
+};
+
+constexpr FloatFeatures operator&(const FloatFeatures& a, const FloatFeatures& b)
+{
+ using T = std::underlying_type_t<FloatFeatures>;
+ return static_cast<FloatFeatures>(static_cast<T>(a) & static_cast<T>(b));
+}
+
+constexpr FloatFeatures operator|(const FloatFeatures& a, const FloatFeatures& b)
+{
+ using T = std::underlying_type_t<FloatFeatures>;
+ return static_cast<FloatFeatures>(static_cast<T>(a) | static_cast<T>(b));
+}
+
+constexpr FloatFeatures& operator|=(FloatFeatures& a, const FloatFeatures& b)
+{
+ a = a | b;
+ return a;
+}
+
+namespace float_support
+{
+struct hidden
+{};
+
+/// \brief Get the number of bytes required to store the given number of
+/// bits.
+///
+/// NOTE This is distinct from the number of bytes required to represent
+/// the number of bits - a power of two number of bytes will always be
+/// returned by this method.
+constexpr size_t get_storage_bytes(const size_t n_bits)
+{
+ const size_t n_bytes = (n_bits + 7) / 8;
+ size_t storage_bytes = 1;
+ for (; storage_bytes < n_bytes; storage_bytes <<= 1)
+ ;
+ return storage_bytes;
+}
+
+/// \brief Utility method to convert from an older representation of the
+/// floating-point features to the FloatFeatures bitfield.
+constexpr FloatFeatures get_float_flags(bool has_nan, bool has_denorm, bool has_inf)
+{
+ FloatFeatures r = FloatFeatures::None;
+
+ if (has_nan)
+ r |= FloatFeatures::HasNaN;
+
+ if (has_denorm)
+ r |= FloatFeatures::HasDenorms;
+
+ if (has_inf)
+ r |= FloatFeatures::HasInf;
+
+ return r;
+}
+
+/// \brief Shorthand for all support features
+static constexpr FloatFeatures AllFeats = get_float_flags(true, true, true);
+
+// Map from a number of storage bytes to a suitable storage type
+template <size_t n_bytes>
+struct storage_type;
+
+#define STORAGE_TYPE(T) \
+ template <> \
+ struct storage_type<sizeof(T)> \
+ { \
+ using type = T; \
+ }
+STORAGE_TYPE(int8_t);
+STORAGE_TYPE(int16_t);
+STORAGE_TYPE(int32_t);
+STORAGE_TYPE(int64_t);
+#undef STORAGE_TYPE
+
+template <size_t n_storage_bytes>
+using storage_type_t = typename storage_type<n_storage_bytes>::type;
+
+#if defined(__cpp_lib_bit_cast)
+#define BITCAST_CONSTEXPR constexpr inline
+
+// If bit_cast is available then use it
+
+constexpr inline int32_t get_bits(const float& f)
+{
+ return std::bit_cast<int32_t>(f);
+}
+constexpr inline float from_bits(const int32_t& i)
+{
+ return std::bit_cast<float>(i);
+}
+
+#else
+#define BITCAST_CONSTEXPR inline
+
+// Otherwise `memcpy` is the safe (non-UB) of achieving the same result
+
+inline int32_t get_bits(const float& f)
+{
+ int32_t i;
+ std::memcpy(&i, &f, sizeof(float));
+ return i;
+}
+
+inline float from_bits(const int32_t& i)
+{
+ float f;
+ std::memcpy(&f, &i, sizeof(float));
+ return f;
+}
+#endif
+
+} // namespace float_support
+
+/// \brief Overflow mode for narrowing floating-point casts.
+///
+/// Determine the behaviour for values which cannot be represented by the
+/// destination type.
+enum class OverflowMode
+{
+ Saturate, ///< Map to the largest representable value
+ Overflow ///< Map to infinity (if available) or NaN
+};
+
+/// Functor for casting cfloat_advanced
+///
+/// Specific casting behavior can be specified when constructing the
+/// functor.
+///
+/// By default, OVERFLOW mode is used when the destination type has either
+/// infinity or NaN representations. Otherwise SATURATE mode is used. It is
+/// illegal to specify OVERFLOW mode for a type which has neither infinity
+/// or NaN representations - this will result in a compilation error.
+template <class in_type,
+ class out_type,
+ OverflowMode overflow_mode =
+ (out_type::has_nan || out_type::has_inf) ? OverflowMode::Overflow : OverflowMode::Saturate>
+class cfloat_cast
+{
+ constexpr static FloatFeatures in_feats = in_type::features;
+ constexpr static FloatFeatures out_feats = out_type::features;
+ constexpr static size_t in_bits = in_type::n_bits;
+ constexpr static size_t in_exp_bits = in_type::n_exponent_bits;
+ constexpr static size_t out_bits = out_type::n_bits;
+ constexpr static size_t out_exp_bits = out_type::n_exponent_bits;
+
+public:
+ constexpr cfloat_cast()
+ {
+ // SATURATE mode MUST be specified if the destination type does not
+ // have either NaN or infinity representations.
+ static_assert(overflow_mode == OverflowMode::Saturate || out_type::has_nan || out_type::has_inf);
+ }
+
+ /// \brief Cast from `in` to the given `out_type`
+ //
+ // This code relies on an understanding of the storage format used by
+ // `cfloat_advanced`. See the documentation of that class for further
+ // details.
+ constexpr out_type operator()(const in_type& in) const
+ {
+ // Shortcut for types which differ only in the number of significand
+ // bits, and where the output type is wider than the input type. For
+ // example, bfloat16 and binary32.
+ if constexpr (in_exp_bits == out_exp_bits && out_bits >= in_bits && in_feats == out_feats)
+ {
+ return out_type::from_bits(static_cast<typename out_type::storage_t>(in.bits()) << (out_bits - in_bits));
+ }
+
+ // Get initial values for the new floating point type
+ const bool sign_bit = in.sign();
+ int64_t new_exponent_bits = 0;
+ uint64_t new_significand = 0;
+
+ if (in.is_nan() || in.is_infinity())
+ {
+ // The mapping of infinity to the destination type depends upon
+ // the overflow mode and the features of the destination type.
+ // OVERFLOW mode is the "expected" behaviour, in which exception
+ // values (NaN and infinity) map to themselves in the
+ // destination type (assuming they exist). In SATURATION mode,
+ // infinity maps to the largest absolute value of the
+ // destination type _even if_ an infinity encoding is available.
+ // See the FP8 specification document.
+ //
+ // By default, exceptional values are encoded with an all-1
+ // exponent field.
+ new_exponent_bits = (UINT64_C(1) << out_exp_bits) - 1;
+
+ if (in.is_nan())
+ {
+ // NaN always maps to NaN if it's available.
+ //
+ // NB: if the type has both NaN AND Infinity support, then
+ // the entirety of the significand can be used to encode
+ // different values of NaN (excepting significand = 0,
+ // which is reserved for infinity). This makes it possible
+ // to encode both quiet and signalling varieties.
+ // Generally, the LSB of the significand represents "not
+ // quiet". However, when there is only 1 NaN encoding
+ // (which is generally the case when infinity is not
+ // supported), then there cannot be separate quiet and
+ // signalling varieties of NaN.
+ if constexpr (out_type::has_inf)
+ {
+ // Copy across the `not_quiet bit`; set the LSB.
+ // Don't attempt to copy across any of the rest of
+ // the payload.
+ new_significand = 0x1 | (((in.significand() >> (in_type::n_significand_bits - 1)) & 1)
+ << out_type::n_significand_bits);
+ }
+ else
+ {
+ new_significand = (UINT64_C(1) << out_type::n_significand_bits) - 1;
+ }
+ }
+ else if constexpr (overflow_mode == OverflowMode::Saturate)
+ {
+ // In SATURATE mode, infinity in the input maps to the
+ // largest absolute value in the output type; even if
+ // infinity is available. This is in compliance with Table 3
+ // of the FP8 specification.
+ return out_type::max(sign_bit);
+ }
+ else if constexpr (!out_type::has_inf && overflow_mode == OverflowMode::Overflow)
+ {
+ // In OVERFLOW mode, infinities in the input type map to NaN
+ // in the output type, if infinity is not available.
+ new_significand = (UINT64_C(1) << out_type::n_significand_bits) - 1;
+ }
+ }
+ else if (!in.is_zero())
+ {
+ const int64_t this_exponent_bits = in.exponent_bits();
+ {
+ constexpr int64_t exponent_rebias = out_type::exponent_bias - in_type::exponent_bias;
+ new_exponent_bits = std::max(this_exponent_bits + exponent_rebias, exponent_rebias + 1);
+ }
+ new_significand = in.significand() << (64 - in_type::n_significand_bits);
+
+ // Normalise subnormals
+ if (this_exponent_bits == 0)
+ {
+ // Shift the most-significant 1 out of the magnitude to
+ // convert it to a significand. Modify the exponent
+ // accordingly.
+ uint8_t shift = __builtin_clzl(new_significand) + 1;
+ new_exponent_bits -= shift;
+ new_significand <<= shift;
+ }
+
+ // Apply overflow to out-of-range values; this must occur before
+ // rounding, as out-of-range values could be rounded down to the
+ // largest representable value.
+ if constexpr (overflow_mode == OverflowMode::Overflow)
+ {
+ // Determine the maximum value of exponent, and unrounded
+ // significand.
+ constexpr bool inf_and_nan = out_type::has_nan && out_type::has_inf;
+ constexpr int64_t max_exp_bits = (INT64_C(1) << out_exp_bits) - (inf_and_nan ? 2 : 1);
+ constexpr uint64_t max_significand =
+ ((UINT64_C(1) << out_type::n_significand_bits) - (inf_and_nan ? 1 : 2))
+ << (64 - out_type::n_significand_bits);
+
+ // If the exponent is strictly larger than the largest
+ // possible, or the exponent is equal to the largest
+ // possible AND the (unrounded) significand is strictly
+ // larger than the largest possible then return an
+ // appropriate overflow value.
+ if (new_exponent_bits > max_exp_bits ||
+ (new_exponent_bits == max_exp_bits && new_significand > max_significand))
+ {
+ if constexpr (out_type::has_inf)
+ return out_type::infinity(sign_bit);
+ else
+ return out_type::NaN();
+ }
+ }
+
+ // Align the significand for the output type
+ uint32_t shift = 64 - out_type::n_significand_bits;
+ const bool other_is_subnormal = new_exponent_bits <= 0;
+ if (other_is_subnormal)
+ {
+ shift += 1 - new_exponent_bits;
+ new_exponent_bits = 0;
+ }
+
+ const uint64_t shift_out = shift == 64 ? new_significand : new_significand & ((UINT64_C(1) << shift) - 1);
+ new_significand = shift == 64 ? 0 : new_significand >> shift;
+
+ // Reinsert the most-significant-one if this is a subnormal
+ // in the output type.
+ new_significand |= (other_is_subnormal ? UINT64_C(1) : 0) << (64 - shift);
+
+ // Apply rounding based on the bits shifted out of the
+ // significand
+ const uint64_t shift_half = UINT64_C(1) << (shift - 1);
+ if (shift_out > shift_half || (shift_out == shift_half && (new_significand & 1)))
+ {
+ new_significand += 1;
+
+ // Handle the case that the significand overflowed due
+ // to rounding
+ constexpr uint64_t max_significand = (UINT64_C(1) << out_type::n_significand_bits) - 1;
+ if (new_significand > max_significand)
+ {
+ new_significand = 0;
+ new_exponent_bits++;
+ }
+ }
+
+ // Saturate or overflow if the value is larger than can be
+ // represented in the output type. This can only occur if the
+ // size of the exponent of the new type is not greater than the
+ // exponent of the old type.
+ if constexpr (out_exp_bits <= in_exp_bits)
+ {
+ constexpr int64_t inf_exp_bits = (INT64_C(1) << out_exp_bits) - 1;
+ if (new_exponent_bits >= inf_exp_bits)
+ {
+ if constexpr (out_type::has_inf && overflow_mode == OverflowMode::Overflow)
+ {
+ // If the output type has a representation of
+ // infinity, and we are in OVERFLOW Mode, then
+ // return infinity.
+ new_exponent_bits = inf_exp_bits;
+ new_significand = 0;
+ }
+ else if constexpr (out_type::has_inf)
+ {
+ // If the output type has a representation of
+ // infinity, and we are in SATURATE mode, then
+ // return the largest representable real number.
+ new_exponent_bits = inf_exp_bits - 1;
+ new_significand = (UINT64_C(1) << out_type::n_significand_bits) - 1;
+ }
+ else if (new_exponent_bits > inf_exp_bits)
+ {
+ if constexpr (overflow_mode == OverflowMode::Overflow)
+ return out_type::NaN();
+ else
+ return out_type::max(sign_bit);
+ }
+ else
+ {
+ constexpr uint64_t max_significand =
+ (UINT64_C(1) << out_type::n_significand_bits) - (out_type::has_nan ? 2 : 1);
+ if (new_significand > max_significand)
+ {
+ if constexpr (overflow_mode == OverflowMode::Saturate)
+ new_significand = max_significand;
+ else
+ return out_type::NaN();
+ }
+ }
+ }
+ }
+ }
+
+ return out_type::from_bits(sign_bit, new_exponent_bits, new_significand);
+ }
+};
+
+/// \brief Bit-accurate representation storage of IEEE754 compliant and
+/// derived floating point types.
+///
+/// Template parameters allow for specification of the number of bits, the
+/// number of exponent bits, and the features of the floating point types.
+/// The number of significand bits is `n_bits - n_exponent_bits - 1`. It is
+/// not possible to represent a signless type, such as FP8 E8M0.
+///
+/// For an imaginary 7-bit type, FP7 E4M2; the storage for various values
+/// given different floating point features is given below:
+///
+/// Value All features No infinity No features
+/// -------------------------- ------------ ----------- -----------
+/// Positive zero +0 00 0000 00 As before As before
+/// Negative zero -0 11 0000 00 As before As before
+/// Positive/negative infinity SS 1111 00 N/A N/A
+/// Signalling NaN SS 1111 01 SS 1111 11 N/A
+/// Quiet NaN SS 1111 11 N/A N/A
+/// Largest normal SS 1110 11 SS 1111 10 SS 1111 11
+/// Smallest normal SS 0001 00 As before SS 0000 01
+/// Largest denormal SS 0000 11 SS 0000 11 N/A
+///
+/// Note that the sign bit is extended to fill the storage type.
+template <size_t _n_bits, size_t n_exp_bits, FloatFeatures Feats = float_support::AllFeats>
+class cfloat_advanced
+{
+public:
+ using storage_t = float_support::storage_type_t<float_support::get_storage_bytes(_n_bits)>;
+
+ static constexpr size_t n_bits = _n_bits;
+ static constexpr size_t n_exponent_bits = n_exp_bits;
+ static constexpr size_t n_significand_bits = n_bits - (1 + n_exp_bits);
+ static constexpr int64_t exponent_bias = (INT64_C(1) << (n_exp_bits - 1)) - 1;
+
+ static constexpr FloatFeatures features = Feats;
+ static constexpr bool has_nan = (Feats & FloatFeatures::HasNaN) != FloatFeatures::None;
+ static constexpr bool has_inf = (Feats & FloatFeatures::HasInf) != FloatFeatures::None;
+ static constexpr bool has_denorms = (Feats & FloatFeatures::HasDenorms) != FloatFeatures::None;
+
+ /// \brief Construct a floating point type with the given bit
+ /// representation.
+ static constexpr cfloat_advanced from_bits(storage_t bits)
+ {
+ return cfloat_advanced(float_support::hidden(), bits);
+ }
+
+ /// \brief Construct a float from the given sign, exponent and
+ /// significand bits.
+ static constexpr cfloat_advanced from_bits(bool pm, storage_t e, storage_t s)
+ {
+ storage_t bits = pm ? -1 : 0;
+
+ bits <<= n_exp_bits;
+ bits |= e;
+
+ bits <<= n_significand_bits;
+ if (has_denorms || e)
+ bits |= s;
+
+ return cfloat_advanced(float_support::hidden(), bits);
+ }
+
+ /// \brief (Hidden) Construct a float type from a given bit pattern
+ constexpr cfloat_advanced(const float_support::hidden&, storage_t bits)
+ : m_data(bits)
+ {}
+
+ constexpr cfloat_advanced()
+ : m_data(0)
+ {}
+ constexpr cfloat_advanced(const cfloat_advanced& other)
+ : m_data(other.m_data)
+ {}
+
+ constexpr cfloat_advanced& operator=(const cfloat_advanced& other)
+ {
+ this->m_data = other.m_data;
+ return *this;
+ }
+
+ constexpr cfloat_advanced& operator=(cfloat_advanced&& other)
+ {
+ this->m_data = other.m_data;
+ return *this;
+ }
+
+ /// \brief Get a NaN representation
+ static constexpr cfloat_advanced NaN()
+ {
+ static_assert(has_nan);
+
+ // NaN is always encoded with all 1s in the exponent.
+ // If Inf exists, then NaN is encoded as a non-zero significand; if
+ // Inf doesn't exist then NaN is encoded as all ones in the
+ // significand.
+ constexpr uint64_t exp_bits = (UINT64_C(1) << n_exponent_bits) - 1;
+ constexpr uint64_t sig_bits = has_inf ? 1 : (UINT64_C(1) << n_significand_bits) - 1;
+ return cfloat_advanced::from_bits(false, exp_bits, sig_bits);
+ }
+
+ /// \brief Get a representation of infinity
+ static constexpr cfloat_advanced infinity(const bool& sign)
+ {
+ static_assert(has_inf);
+
+ // Inf is always encoded with all 1s in the exponent, and all zeros
+ // in the significand.
+ return cfloat_advanced::from_bits(sign, (UINT64_C(1) << n_exponent_bits) - 1, 0);
+ }
+
+ /// \brief Get the largest representable value
+ static constexpr cfloat_advanced max(const bool& sign)
+ {
+ if constexpr (has_nan && has_inf)
+ {
+ // Where we have NaN and Infinity, exponents all `1` corresponds
+ // to some of these values.
+ return from_bits(sign, (UINT64_C(1) << n_exponent_bits) - 2, (UINT64_C(1) << n_significand_bits) - 1);
+ }
+ else if constexpr (has_nan || has_inf)
+ {
+ // Where we have either NaN or infinity (but not both),
+ // exponents all `1` AND significand all `1` corresponds to the
+ // special value.
+ return from_bits(sign, (UINT64_C(1) << n_exponent_bits) - 1, (UINT64_C(1) << n_significand_bits) - 2);
+ }
+ else
+ {
+ // With no special values to encode, the maximum value is
+ // encoded as all `1`s.
+ return from_bits(sign, (UINT64_C(1) << n_exponent_bits) - 1, (UINT64_C(1) << n_significand_bits) - 1);
+ }
+ }
+
+ /// \brief Cast to a different floating point representation.
+ template <size_t out_n_bits, size_t out_n_exp_bits, FloatFeatures OutFeats>
+ constexpr inline operator cfloat_advanced<out_n_bits, out_n_exp_bits, OutFeats>() const
+ {
+ using out_type = cfloat_advanced<out_n_bits, out_n_exp_bits, OutFeats>;
+ return cfloat_cast<cfloat_advanced, out_type>().operator()(*this);
+ }
+
+ /// \brief Convert from a 32-bit floating point value
+ BITCAST_CONSTEXPR
+ cfloat_advanced(const float& f)
+ {
+ // If this format exactly represents the binary32 format then get
+ // the bits from the provided float; otherwise get a binary32
+ // representation and then convert to this format.
+ if constexpr (represents_binary32())
+ m_data = float_support::get_bits(f);
+ else
+ m_data =
+ static_cast<cfloat_advanced<n_bits, n_exp_bits, Feats>>(static_cast<cfloat_advanced<32, 8>>(f)).m_data;
+ }
+
+ /// \brief Cast to a 32-bit floating point value
+ BITCAST_CONSTEXPR operator float() const
+ {
+ // If this format exactly represents the binary32 format then return
+ // a float; otherwise get a binary32 representation and then return
+ // a float.
+ if constexpr (represents_binary32())
+ return float_support::from_bits(m_data);
+ else
+ return static_cast<float>(this->operator cfloat_advanced<32, 8>());
+ }
+
+ /// \brief Return whether this type represents the IEEE754 binary32
+ /// format
+ constexpr static inline bool represents_binary32()
+ {
+ return std::is_same_v<storage_t, int32_t> && n_exp_bits == 8 && Feats == float_support::AllFeats;
+ }
+
+ constexpr auto operator-() const
+ {
+ constexpr storage_t sign_bits =
+ static_cast<storage_t>(std::numeric_limits<std::make_unsigned_t<storage_t>>::max() << (n_bits - 1));
+ return from_bits(m_data ^ sign_bits);
+ }
+
+ constexpr bool is_subnormal() const
+ {
+ return exponent_bits() == 0 && significand() != 0;
+ }
+
+ constexpr bool is_zero() const
+ {
+ return exponent_bits() == 0 && significand() == 0;
+ }
+
+ constexpr bool is_nan() const
+ {
+ return has_nan && (exponent_bits() == (UINT64_C(1) << n_exponent_bits) - 1) &&
+ ((has_inf && significand()) || (!has_inf && significand() == (UINT64_C(1) << n_significand_bits) - 1));
+ }
+
+ constexpr bool is_infinity() const
+ {
+ return has_inf && ((exponent_bits() == (UINT64_C(1) << n_exponent_bits) - 1) && (significand() == 0));
+ }
+
+ constexpr inline const storage_t& bits() const
+ {
+ return m_data;
+ }
+
+ /// \brief Get the exponent
+ constexpr inline int64_t exponent() const
+ {
+ return std::max<int64_t>(exponent_bits(), INT64_C(1)) - exponent_bias;
+ }
+
+ /// \brief Get the sign bit
+ constexpr inline bool sign() const
+ {
+ return (m_data >> (n_bits - 1)) & 0x1;
+ }
+
+ /// \brief Get the bits from the exponent
+ constexpr inline uint64_t exponent_bits() const
+ {
+ constexpr uint64_t mask = (UINT64_C(1) << n_exp_bits) - 1;
+ return (m_data >> n_significand_bits) & mask;
+ }
+
+ constexpr inline uint64_t significand() const
+ {
+ return m_data & ((UINT64_C(1) << n_significand_bits) - 1);
+ }
+
+ constexpr inline bool operator==(const cfloat_advanced& other) const
+ {
+ return !is_nan() && !other.is_nan() && // Neither operand is NaN
+ ((is_zero() && other.is_zero()) || (m_data == other.m_data));
+ }
+
+ constexpr inline bool operator!=(const cfloat_advanced& other) const
+ {
+ return !(*this == other);
+ }
+
+ constexpr inline cfloat_advanced& operator+=(const cfloat_advanced& rhs)
+ {
+ this->m_data = static_cast<cfloat_advanced>(static_cast<float>(*this) + static_cast<float>(rhs)).bits();
+ return *this;
+ }
+
+private:
+ storage_t m_data = 0;
+};
+
+// This should probably be exported so we can use it elsewhere
+#undef BITCAST_CONSTEXPR
+
+/// \brief Wrapper to maintain API compatibility with older code, which was
+/// limited to power-of-two sizes of floats.
+template <typename storage_t,
+ size_t n_exp_bits,
+ bool has_nan,
+ bool with_denorm,
+ bool with_infinity,
+ std::enable_if_t<(n_exp_bits + 1 < sizeof(storage_t) * 8), bool> = true>
+using cfloat = cfloat_advanced<sizeof(storage_t) * 8,
+ n_exp_bits,
+ float_support::get_float_flags(has_nan, with_denorm, with_infinity)>;
+
+namespace float_support
+{
+// Pre-C++23 these can't be computed as constexpr, so have to hardcode
+// them
+
+template <int>
+struct digits10; // floor(log10(2) * (digits - 1)
+template <int>
+struct max_digits10; // ceil(log10(2) * digits + 1)
+template <int>
+struct min_exponent10; // floor(log10(2) * min_exponent)
+template <int>
+struct max_exponent10; // floor(log10(2) * max_exponent)
+
+template <>
+struct digits10<8>
+{
+ constexpr static inline int value = 2;
+};
+
+template <>
+struct max_digits10<8>
+{
+ constexpr static inline int value = 4;
+};
+
+template <>
+struct digits10<10>
+{
+ constexpr static inline int value = 2;
+};
+
+template <>
+struct max_digits10<10>
+{
+ constexpr static inline int value = 5;
+};
+
+template <>
+struct digits10<24>
+{
+ constexpr static inline int value = 6;
+};
+
+template <>
+struct max_digits10<24>
+{
+ constexpr static inline int value = 9;
+};
+
+template <>
+struct min_exponent10<-13>
+{
+ constexpr static inline int value = -3;
+};
+
+template <>
+struct max_exponent10<16>
+{
+ constexpr static inline int value = 4;
+};
+
+template <>
+struct min_exponent10<-125>
+{
+ constexpr static inline int value = -37;
+};
+
+template <>
+struct max_exponent10<128>
+{
+ constexpr static inline int value = 38;
+};
+
+template <int d>
+inline constexpr int digits10_v = digits10<d>::value;
+template <int d>
+inline constexpr int max_digits10_v = max_digits10<d>::value;
+
+template <int e>
+inline constexpr int min_exponent10_v = min_exponent10<e>::value;
+
+template <int e>
+inline constexpr int max_exponent10_v = max_exponent10<e>::value;
+
+} // namespace float_support
+
+} // namespace ct
+
+namespace std
+{
+
+template <size_t n_bits, size_t n_exp_bits, ct::FloatFeatures Feats>
+struct is_floating_point<ct::cfloat_advanced<n_bits, n_exp_bits, Feats>> : std::integral_constant<bool, true>
+{};
+
+template <size_t n_bits, size_t n_exp_bits, ct::FloatFeatures Feats>
+class numeric_limits<ct::cfloat_advanced<n_bits, n_exp_bits, Feats>>
+{
+ using this_cfloat = ct::cfloat_advanced<n_bits, n_exp_bits, Feats>;
+
+public:
+ static constexpr bool is_specialized = true;
+
+ static constexpr auto min() noexcept
+ {
+ return this_cfloat::from_bits(false, 1, 0);
+ }
+
+ static constexpr auto max() noexcept
+ {
+ return this_cfloat::max(false);
+ }
+ static constexpr auto lowest() noexcept
+ {
+ return -max();
+ }
+
+ static constexpr int digits = this_cfloat::n_significand_bits + 1;
+ static constexpr int digits10 = ct::float_support::digits10_v<digits>;
+ static constexpr int max_digits10 = ct::float_support::max_digits10_v<digits>;
+
+ static constexpr bool is_signed = true;
+ static constexpr bool is_integer = false;
+ static constexpr bool is_exact = false;
+ static constexpr int radix = 2;
+
+ static constexpr auto epsilon() noexcept
+ {
+ return this_cfloat::from_bits(false, this_cfloat::exponent_bias - this_cfloat::n_significand_bits, 0);
+ }
+
+ static constexpr auto round_error() noexcept
+ {
+ return this_cfloat::from_bits(0, this_cfloat::exponent_bias - 1, 0);
+ }
+
+ static constexpr int min_exponent = (1 - this_cfloat::exponent_bias) + 1;
+ static constexpr int min_exponent10 = ct::float_support::min_exponent10_v<min_exponent>;
+ static constexpr int max_exponent = this_cfloat::exponent_bias + 1;
+ static constexpr int max_exponent10 = ct::float_support::max_exponent10_v<max_exponent>;
+
+ static constexpr bool has_infinity = this_cfloat::has_inf;
+ static constexpr bool has_quiet_NaN = this_cfloat::has_nan && this_cfloat::has_inf;
+ static constexpr bool has_signaling_NaN = this_cfloat::has_nan;
+ static constexpr float_denorm_style has_denorm = this_cfloat::has_denorms ? denorm_present : denorm_absent;
+ static constexpr bool has_denorm_loss = false;
+
+ static constexpr auto infinity() noexcept
+ {
+ if constexpr (this_cfloat::has_inf)
+ {
+ return this_cfloat::infinity(false);
+ }
+ else
+ {
+ return this_cfloat::from_bits(false, 0, 0);
+ }
+ }
+
+ static constexpr auto quiet_NaN() noexcept
+ {
+ const uint64_t exp_bits = (UINT64_C(1) << this_cfloat::n_exponent_bits) - 1;
+ const uint64_t sig_bits = this_cfloat::has_inf ? (UINT64_C(1) << (this_cfloat::n_significand_bits - 1)) | 1
+ : (UINT64_C(1) << this_cfloat::n_significand_bits) - 1;
+ return this_cfloat::from_bits(false, exp_bits, sig_bits);
+ }
+
+ static constexpr auto signaling_NaN() noexcept
+ {
+ const uint64_t exp_bits = (UINT64_C(1) << this_cfloat::n_exponent_bits) - 1;
+ const uint64_t sig_bits = this_cfloat::has_inf ? 1 : (UINT64_C(1) << this_cfloat::n_significand_bits) - 1;
+ return this_cfloat::from_bits(false, exp_bits, sig_bits);
+ }
+
+ static constexpr auto denorm_min() noexcept
+ {
+ return this_cfloat::from_bits(false, 0, 1);
+ }
+
+ static constexpr bool is_iec559 = false;
+ static constexpr bool is_bounded = false;
+ static constexpr bool is_modulo = false;
+
+ static constexpr bool traps = false;
+ static constexpr bool tinyness_before = false;
+ static constexpr float_round_style round_style = round_to_nearest;
+};
+
+} // namespace std
+
+#endif // CT_CFLOAT_H
diff --git a/include/float_utils.h b/include/float_utils.h
deleted file mode 100644
index 831ad74..0000000
--- a/include/float_utils.h
+++ /dev/null
@@ -1,533 +0,0 @@
-// Copyright (c) 2024, ARM Limited.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TOSA_FLOAT_UTILS_H_
-#define TOSA_FLOAT_UTILS_H_
-
-#include <algorithm>
-#include <cstdint>
-#include <limits>
-#include <type_traits>
-#if defined(__cpp_lib_bit_cast)
-#include <bit>
-#endif // defined(__cpp_lib_bit_cast)
-
-namespace tosa
-{
-
-namespace float_support
-{
-
-struct hidden
-{};
-
-#if defined(__cpp_lib_bit_cast)
-#define BITCAST_CONSTEXPR constexpr inline
-
-constexpr inline int32_t get_bits(const float& f)
-{
- return std::bit_cast<int32_t>(f);
-}
-constexpr inline float from_bits(const int32_t& i)
-{
- return std::bit_cast<float>(i);
-}
-
-#else
-#define BITCAST_CONSTEXPR inline
-
-union ufloat32
-{
- constexpr ufloat32(const float& x)
- : f(x)
- {}
- constexpr ufloat32(const int32_t& x)
- : i(x)
- {}
-
- float f;
- int32_t i;
-};
-
-inline int32_t get_bits(const float& f)
-{
- return ufloat32(f).i;
-}
-inline float from_bits(const int32_t& i)
-{
- return ufloat32(i).f;
-}
-#endif
-
-} // namespace float_support
-
-template <typename storage_t,
- size_t n_exp_bits,
- bool has_nan,
- bool with_denorm,
- bool with_infinity,
- std::enable_if_t<(n_exp_bits + 1 < sizeof(storage_t) * 8), bool> = true>
-class float_t
-{
- storage_t m_data = 0;
-
-public:
- static constexpr size_t n_exponent_bits = n_exp_bits;
- static constexpr size_t n_significand_bits = sizeof(storage_t) * 8 - 1 - n_exp_bits;
- static constexpr int64_t exponent_bias = (1 << (n_exp_bits - 1)) - 1;
-
- /// \brief Construct a floating point type with the given bit
- /// representation.
- static constexpr float_t from_bits(storage_t bits)
- {
- return float_t(float_support::hidden(), bits);
- }
-
- /// \brief Construct a float from the given sign, exponent and significand
- /// bits.
- static constexpr float_t from_bits(bool pm, storage_t e, storage_t s)
- {
- storage_t bits = pm ? 1 : 0;
-
- bits <<= n_exp_bits;
- bits |= e;
-
- bits <<= n_significand_bits;
- if (with_denorm || e)
- bits |= s;
-
- return float_t(float_support::hidden(), bits);
- }
-
- /// \brief (Hidden) Construct a float type from a given bit pattern
- constexpr float_t(const float_support::hidden&, storage_t bits)
- : m_data(bits)
- {}
-
- constexpr float_t()
- : m_data(0)
- {}
- constexpr float_t(const float_t& other)
- : m_data(other.m_data)
- {}
-
- /// \brief Cast to a different floating point representation.
- template <typename other_storage_t,
- size_t other_n_exp_bits,
- bool other_has_nan,
- bool other_has_denorm,
- bool other_has_infinity>
- constexpr inline
- operator float_t<other_storage_t, other_n_exp_bits, other_has_nan, other_has_denorm, other_has_infinity>() const
- {
- using other_float_t =
- float_t<other_storage_t, other_n_exp_bits, other_has_nan, other_has_denorm, other_has_infinity>;
-
- // Shortcut for types which are fundamentally similar (e.g., bf16 ->
- // fp32)
- if constexpr (n_exp_bits == other_n_exp_bits && sizeof(other_storage_t) >= sizeof(storage_t) &&
- has_nan == other_has_nan)
- {
- return other_float_t::from_bits(static_cast<other_storage_t>(m_data)
- << (sizeof(other_storage_t) - sizeof(storage_t)) * 8);
- }
-
- // Get initial values for the new floating point type
- const bool sign_bit = m_data < 0;
- int64_t new_exponent_bits = 0;
- uint64_t new_significand = 0;
-
- if (is_nan() || is_infinity())
- {
- new_exponent_bits = (1 << other_n_exp_bits) - 1;
-
- if (is_nan())
- {
- if constexpr (other_has_infinity)
- {
- // Copy across the `not_quiet bit`; set the LSB. Don't
- // attempt to copy across any of the rest of the payload.
- new_significand =
- 0x1 | (((significand() >> (n_significand_bits - 1)) & 1) << other_float_t::n_significand_bits);
- }
- else
- {
- new_significand = (1ul << other_float_t::n_significand_bits) - 1;
- }
- }
- else if constexpr (!other_has_infinity)
- {
- new_significand = (1ul << other_float_t::n_significand_bits) - (other_has_nan ? 2 : 1);
- }
- }
- else if (!is_zero())
- {
- const int64_t this_exponent_bits = exponent_bits();
- {
- constexpr int64_t exponent_rebias = other_float_t::exponent_bias - exponent_bias;
- new_exponent_bits = std::max(this_exponent_bits + exponent_rebias, exponent_rebias + 1);
- }
- new_significand = this->significand() << (64 - n_significand_bits);
-
- // Normalise subnormals
- if (this_exponent_bits == 0)
- {
- // Shift the most-significant 1 out of the magnitude to convert
- // it to a significand. Modify the exponent accordingly.
- uint8_t shift = __builtin_clzl(new_significand) + 1;
- new_exponent_bits -= shift;
- new_significand <<= shift;
- }
-
- // Align the significand for the output type
- uint32_t shift = 64 - other_float_t::n_significand_bits;
- const bool other_is_subnormal = new_exponent_bits <= 0;
- if (other_is_subnormal)
- {
- shift += 1 - new_exponent_bits;
- new_exponent_bits = 0;
- }
-
- const uint64_t shift_out = shift == 64 ? new_significand : new_significand & ((1ll << shift) - 1);
- new_significand = shift == 64 ? 0 : new_significand >> shift;
-
- // Reinsert the most-significant-one if this is a subnormal in the
- // output type.
- new_significand |= (other_is_subnormal ? 1ll : 0) << (64 - shift);
-
- // Apply rounding based on the bits shifted out of the significand
- const uint64_t shift_half = 1ll << (shift - 1);
- if (shift_out > shift_half || (shift_out == shift_half && (new_significand & 1)))
- {
- new_significand += 1;
-
- // Handle the case that the significand overflowed due to
- // rounding
- constexpr uint64_t max_significand = (1ll << other_float_t::n_significand_bits) - 1;
- if (new_significand > max_significand)
- {
- new_significand = 0;
- new_exponent_bits++;
- }
- }
-
- // Saturate to infinity if the exponent is larger than can be
- // represented in the output type. This can only occur if the size
- // of the exponent of the new type is not greater than the exponent
- // of the old type.
- if constexpr (other_n_exp_bits <= n_exp_bits)
- {
- constexpr int64_t inf_exp_bits = (1ll << other_n_exp_bits) - 1;
- if (new_exponent_bits >= inf_exp_bits)
- {
- new_exponent_bits = inf_exp_bits;
- new_significand =
- other_has_infinity ? 0 : (1ul << other_float_t::n_significand_bits) - (other_has_nan ? 2 : 1);
- }
- }
- }
-
- return other_float_t::from_bits(sign_bit, new_exponent_bits, new_significand);
- }
-
- /// \brief Convert from a 32-bit floating point value
- BITCAST_CONSTEXPR
- float_t(const float& f)
- {
- // If this format exactly represents the binary32 format then get
- // the bits from the provided float; otherwise get a binary32
- // representation and then convert to this format.
- if constexpr (represents_binary32())
- m_data = float_support::get_bits(f);
- else
- m_data = static_cast<float_t<storage_t, n_exp_bits, has_nan, with_denorm, with_infinity>>(
- static_cast<float_t<int32_t, 8, true, true, true>>(f))
- .m_data;
- }
-
- /// \brief Cast to a 32-bit floating point value
- BITCAST_CONSTEXPR operator float() const
- {
- // If this format exactly represents the binary32 format then return
- // a float; otherwise get a binary32 representation and then return
- // a float.
- if constexpr (represents_binary32())
- return float_support::from_bits(m_data);
- else
- return static_cast<float>(this->operator float_t<int32_t, 8, true, true, true>());
- }
-
- /// \brief Return whether this type represents the IEEE754 binary32
- /// format
- constexpr static inline bool represents_binary32()
- {
- return std::is_same_v<storage_t, int32_t> && n_exp_bits == 8 && has_nan && with_denorm && with_infinity;
- }
-
- constexpr auto operator-() const
- {
- return from_bits(m_data ^ (1ll << (sizeof(storage_t) * 8 - 1)));
- }
-
- constexpr bool is_subnormal() const
- {
- return exponent_bits() == 0 && significand() != 0;
- }
-
- constexpr bool is_zero() const
- {
- return exponent_bits() == 0 && significand() == 0;
- }
-
- constexpr bool is_nan() const
- {
- return has_nan && (exponent_bits() == (1ul << n_exponent_bits) - 1) &&
- ((with_infinity && significand()) ||
- (!with_infinity && significand() == (1ul << n_significand_bits) - 1));
- }
-
- constexpr bool is_infinity() const
- {
- return with_infinity && ((exponent_bits() == (1ul << n_exponent_bits) - 1) && !significand());
- }
-
- constexpr inline const storage_t& bits() const
- {
- return m_data;
- }
-
- /// \brief Get the exponent
- constexpr inline int64_t exponent() const
- {
- return std::max<int64_t>(exponent_bits(), 1ul) - exponent_bias;
- }
-
- /// \brief Get the bits from the exponent
- constexpr inline uint64_t exponent_bits() const
- {
- constexpr uint64_t mask = (1ul << n_exp_bits) - 1;
- return (m_data >> n_significand_bits) & mask;
- }
-
- constexpr inline uint64_t significand() const
- {
- return m_data & ((1ul << n_significand_bits) - 1);
- }
-
- constexpr inline bool operator==(const float_t& other) const
- {
- return !is_nan() && !other.is_nan() && ((is_zero() && other.is_zero()) || bits() == other.bits());
- }
-
- constexpr inline float_t& operator+=(const float_t& rhs)
- {
- this->m_data = static_cast<float_t>(static_cast<float>(*this) + static_cast<float>(rhs)).bits();
- return *this;
- }
-};
-
-// This should probably be exported so we can use it elsewhere
-#undef BITCAST_CONSTEXPR
-
-namespace float_support
-{
-
-// Pre-C++23 these can't be computed as constexpr, so have to hardcode them
-
-template <int>
-struct digits10; // floor(log10(2) * (digits - 1)
-template <int>
-struct max_digits10; // ceil(log10(2) * digits + 1)
-template <int>
-struct min_exponent10; // floor(log10(2) * min_exponent)
-template <int>
-struct max_exponent10; // floor(log10(2) * max_exponent)
-
-template <>
-struct digits10<8>
-{
- constexpr static inline int value = 2;
-};
-
-template <>
-struct max_digits10<8>
-{
- constexpr static inline int value = 4;
-};
-
-template <>
-struct digits10<10>
-{
- constexpr static inline int value = 2;
-};
-
-template <>
-struct max_digits10<10>
-{
- constexpr static inline int value = 5;
-};
-
-template <>
-struct digits10<24>
-{
- constexpr static inline int value = 6;
-};
-
-template <>
-struct max_digits10<24>
-{
- constexpr static inline int value = 9;
-};
-
-template <>
-struct min_exponent10<-13>
-{
- constexpr static inline int value = -3;
-};
-
-template <>
-struct max_exponent10<16>
-{
- constexpr static inline int value = 4;
-};
-
-template <>
-struct min_exponent10<-125>
-{
- constexpr static inline int value = -37;
-};
-
-template <>
-struct max_exponent10<128>
-{
- constexpr static inline int value = 38;
-};
-
-template <int d>
-inline constexpr int digits10_v = digits10<d>::value;
-template <int d>
-inline constexpr int max_digits10_v = max_digits10<d>::value;
-
-template <int e>
-inline constexpr int min_exponent10_v = min_exponent10<e>::value;
-
-template <int e>
-inline constexpr int max_exponent10_v = max_exponent10<e>::value;
-
-} // namespace float_support
-
-} // namespace tosa
-
-namespace std
-{
-
-template <typename storage_t, size_t n_exp_bits, bool has_nan, bool has_denorm, bool has_inf>
-struct is_floating_point<tosa::float_t<storage_t, n_exp_bits, has_nan, has_denorm, has_inf>>
- : std::integral_constant<bool, true>
-{};
-
-template <typename storage_t, size_t n_exp_bits, bool has_nan, bool with_denorm, bool with_inf>
-class numeric_limits<tosa::float_t<storage_t, n_exp_bits, has_nan, with_denorm, with_inf>>
-{
- using this_float_t = tosa::float_t<storage_t, n_exp_bits, has_nan, with_denorm, with_inf>;
-
-public:
- static constexpr bool is_specialized = true;
-
- static constexpr auto min() noexcept
- {
- return this_float_t::from_bits(false, 1, 0);
- }
-
- static constexpr auto max() noexcept
- {
- return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 2,
- (1 << this_float_t::n_significand_bits) - 1);
- }
-
- static constexpr auto lowest() noexcept
- {
- return -max();
- }
-
- static constexpr int digits = this_float_t::n_significand_bits + 1;
- static constexpr int digits10 = tosa::float_support::digits10_v<digits>;
- static constexpr int max_digits10 = tosa::float_support::max_digits10_v<digits>;
-
- static constexpr bool is_signed = true;
- static constexpr bool is_integer = false;
- static constexpr bool is_exact = false;
- static constexpr int radix = 2;
-
- static constexpr auto epsilon() noexcept
- {
- return this_float_t::from_bits(false, this_float_t::exponent_bias - this_float_t::n_significand_bits, 0);
- }
-
- static constexpr auto round_error() noexcept
- {
- return this_float_t::from_bits(0, this_float_t::exponent_bias - 1, 0);
- }
-
- static constexpr int min_exponent = (1 - this_float_t::exponent_bias) + 1;
- static constexpr int min_exponent10 = tosa::float_support::min_exponent10_v<min_exponent>;
- static constexpr int max_exponent = this_float_t::exponent_bias + 1;
- static constexpr int max_exponent10 = tosa::float_support::max_exponent10_v<max_exponent>;
-
- static constexpr bool has_infinity = with_inf;
- static constexpr bool has_quiet_NaN = has_nan;
- static constexpr bool has_signaling_NaN = true;
- static constexpr float_denorm_style has_denorm = with_denorm ? denorm_present : denorm_absent;
- static constexpr bool has_denorm_loss = false;
-
- static constexpr auto infinity() noexcept
- {
- if constexpr (with_inf)
- {
- return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 1, 0);
- }
- else
- {
- return this_float_t::from_bits(false, 0, 0);
- }
- }
-
- static constexpr auto quiet_NaN() noexcept
- {
- return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 1,
- 1 << (this_float_t::n_significand_bits - 1) | 1);
- }
-
- static constexpr auto signaling_NaN() noexcept
- {
- return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 1, 1);
- }
-
- static constexpr auto denorm_min() noexcept
- {
- return this_float_t::from_bits(false, 0, 1);
- }
-
- static constexpr bool is_iec559 = false;
- static constexpr bool is_bounded = false;
- static constexpr bool is_modulo = false;
-
- static constexpr bool traps = false;
- static constexpr bool tinyness_before = false;
- static constexpr float_round_style round_style = round_to_nearest;
-};
-
-} // namespace std
-
-#endif // TOSA_FLOAT_UTILS_H_
diff --git a/include/numpy_utils.h b/include/numpy_utils.h
index 60cf77e..ade2f2d 100644
--- a/include/numpy_utils.h
+++ b/include/numpy_utils.h
@@ -24,8 +24,13 @@
#include <cstring>
#include <vector>
+#include "cfloat.h"
#include "half.hpp"
+using bf16 = ct::cfloat<int16_t, 8, true, true, true>;
+using fp8e4m3 = ct::cfloat<int8_t, 4, true, true, false>;
+using fp8e5m2 = ct::cfloat<int8_t, 5, true, true, true>;
+
class NumpyUtilities
{
public:
@@ -85,6 +90,18 @@ public:
{
return "'<f2'";
}
+ if (std::is_same<T, bf16>::value)
+ {
+ return "'<V2'";
+ }
+ if (std::is_same<T, fp8e4m3>::value)
+ {
+ return "'<V1'";
+ }
+ if (std::is_same<T, fp8e5m2>::value)
+ {
+ return "'<f1'";
+ }
assert(false && "unsupported Dtype");
};
diff --git a/include/tosa_generated.h b/include/tosa_generated.h
index 0798256..c907c89 100644
--- a/include/tosa_generated.h
+++ b/include/tosa_generated.h
@@ -8,9 +8,9 @@
// Ensure the included flatbuffers.h is the same version as when this file was
// generated, otherwise it may not be compatible.
-static_assert(FLATBUFFERS_VERSION_MAJOR == 23 &&
- FLATBUFFERS_VERSION_MINOR == 5 &&
- FLATBUFFERS_VERSION_REVISION == 26,
+static_assert(FLATBUFFERS_VERSION_MAJOR == 24 &&
+ FLATBUFFERS_VERSION_MINOR == 3 &&
+ FLATBUFFERS_VERSION_REVISION == 7,
"Non-compatible flatbuffers version included");
namespace tosa {
@@ -883,11 +883,10 @@ struct TransposeConvAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::T
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
VT_OUT_PAD = 4,
VT_STRIDE = 6,
- VT_OUTPUT_SHAPE = 8,
- VT_INPUT_ZP = 10,
- VT_WEIGHT_ZP = 12,
- VT_LOCAL_BOUND = 14,
- VT_ACC_TYPE = 16
+ VT_INPUT_ZP = 8,
+ VT_WEIGHT_ZP = 10,
+ VT_LOCAL_BOUND = 12,
+ VT_ACC_TYPE = 14
};
const ::flatbuffers::Vector<int32_t> *out_pad() const {
return GetPointer<const ::flatbuffers::Vector<int32_t> *>(VT_OUT_PAD);
@@ -895,9 +894,6 @@ struct TransposeConvAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::T
const ::flatbuffers::Vector<int32_t> *stride() const {
return GetPointer<const ::flatbuffers::Vector<int32_t> *>(VT_STRIDE);
}
- const ::flatbuffers::Vector<int32_t> *output_shape() const {
- return GetPointer<const ::flatbuffers::Vector<int32_t> *>(VT_OUTPUT_SHAPE);
- }
int32_t input_zp() const {
return GetField<int32_t>(VT_INPUT_ZP, 0);
}
@@ -916,8 +912,6 @@ struct TransposeConvAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::T
verifier.VerifyVector(out_pad()) &&
VerifyOffset(verifier, VT_STRIDE) &&
verifier.VerifyVector(stride()) &&
- VerifyOffset(verifier, VT_OUTPUT_SHAPE) &&
- verifier.VerifyVector(output_shape()) &&
VerifyField<int32_t>(verifier, VT_INPUT_ZP, 4) &&
VerifyField<int32_t>(verifier, VT_WEIGHT_ZP, 4) &&
VerifyField<uint8_t>(verifier, VT_LOCAL_BOUND, 1) &&
@@ -936,9 +930,6 @@ struct TransposeConvAttributeBuilder {
void add_stride(::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> stride) {
fbb_.AddOffset(TransposeConvAttribute::VT_STRIDE, stride);
}
- void add_output_shape(::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> output_shape) {
- fbb_.AddOffset(TransposeConvAttribute::VT_OUTPUT_SHAPE, output_shape);
- }
void add_input_zp(int32_t input_zp) {
fbb_.AddElement<int32_t>(TransposeConvAttribute::VT_INPUT_ZP, input_zp, 0);
}
@@ -966,7 +957,6 @@ inline ::flatbuffers::Offset<TransposeConvAttribute> CreateTransposeConvAttribut
::flatbuffers::FlatBufferBuilder &_fbb,
::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> out_pad = 0,
::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> stride = 0,
- ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> output_shape = 0,
int32_t input_zp = 0,
int32_t weight_zp = 0,
bool local_bound = false,
@@ -975,7 +965,6 @@ inline ::flatbuffers::Offset<TransposeConvAttribute> CreateTransposeConvAttribut
builder_.add_acc_type(acc_type);
builder_.add_weight_zp(weight_zp);
builder_.add_input_zp(input_zp);
- builder_.add_output_shape(output_shape);
builder_.add_stride(stride);
builder_.add_out_pad(out_pad);
builder_.add_local_bound(local_bound);
@@ -986,19 +975,16 @@ inline ::flatbuffers::Offset<TransposeConvAttribute> CreateTransposeConvAttribut
::flatbuffers::FlatBufferBuilder &_fbb,
const std::vector<int32_t> *out_pad = nullptr,
const std::vector<int32_t> *stride = nullptr,
- const std::vector<int32_t> *output_shape = nullptr,
int32_t input_zp = 0,
int32_t weight_zp = 0,
bool local_bound = false,
tosa::DType acc_type = tosa::DType_UNKNOWN) {
auto out_pad__ = out_pad ? _fbb.CreateVector<int32_t>(*out_pad) : 0;
auto stride__ = stride ? _fbb.CreateVector<int32_t>(*stride) : 0;
- auto output_shape__ = output_shape ? _fbb.CreateVector<int32_t>(*output_shape) : 0;
return tosa::CreateTransposeConvAttribute(
_fbb,
out_pad__,
stride__,
- output_shape__,
input_zp,
weight_zp,
local_bound,
@@ -1008,20 +994,15 @@ inline ::flatbuffers::Offset<TransposeConvAttribute> CreateTransposeConvAttribut
struct PadAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
typedef PadAttributeBuilder Builder;
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
- VT_PAD_CONST = 4,
- VT_TYPE = 6
+ VT_PAD_CONST = 4
};
const ::flatbuffers::Vector<uint8_t> *pad_const() const {
return GetPointer<const ::flatbuffers::Vector<uint8_t> *>(VT_PAD_CONST);
}
- tosa::DType type() const {
- return static_cast<tosa::DType>(GetField<uint32_t>(VT_TYPE, 0));
- }
bool Verify(::flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_PAD_CONST) &&
verifier.VerifyVector(pad_const()) &&
- VerifyField<uint32_t>(verifier, VT_TYPE, 4) &&
verifier.EndTable();
}
};
@@ -1033,9 +1014,6 @@ struct PadAttributeBuilder {
void add_pad_const(::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> pad_const) {
fbb_.AddOffset(PadAttribute::VT_PAD_CONST, pad_const);
}
- void add_type(tosa::DType type) {
- fbb_.AddElement<uint32_t>(PadAttribute::VT_TYPE, static_cast<uint32_t>(type), 0);
- }
explicit PadAttributeBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -1049,24 +1027,20 @@ struct PadAttributeBuilder {
inline ::flatbuffers::Offset<PadAttribute> CreatePadAttribute(
::flatbuffers::FlatBufferBuilder &_fbb,
- ::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> pad_const = 0,
- tosa::DType type = tosa::DType_UNKNOWN) {
+ ::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> pad_const = 0) {
PadAttributeBuilder builder_(_fbb);
- builder_.add_type(type);
builder_.add_pad_const(pad_const);
return builder_.Finish();
}
inline ::flatbuffers::Offset<PadAttribute> CreatePadAttributeDirect(
::flatbuffers::FlatBufferBuilder &_fbb,
- const std::vector<uint8_t> *pad_const = nullptr,
- tosa::DType type = tosa::DType_UNKNOWN) {
+ const std::vector<uint8_t> *pad_const = nullptr) {
if (pad_const) { _fbb.ForceVectorAlignment(pad_const->size(), sizeof(uint8_t), 8); }
auto pad_const__ = pad_const ? _fbb.CreateVector<uint8_t>(*pad_const) : 0;
return tosa::CreatePadAttribute(
_fbb,
- pad_const__,
- type);
+ pad_const__);
}
struct AxisAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
@@ -1205,8 +1179,7 @@ struct ClampAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
typedef ClampAttributeBuilder Builder;
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
VT_MIN_VAL = 4,
- VT_MAX_VAL = 6,
- VT_TYPE = 8
+ VT_MAX_VAL = 6
};
const ::flatbuffers::Vector<uint8_t> *min_val() const {
return GetPointer<const ::flatbuffers::Vector<uint8_t> *>(VT_MIN_VAL);
@@ -1214,16 +1187,12 @@ struct ClampAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
const ::flatbuffers::Vector<uint8_t> *max_val() const {
return GetPointer<const ::flatbuffers::Vector<uint8_t> *>(VT_MAX_VAL);
}
- tosa::DType type() const {
- return static_cast<tosa::DType>(GetField<uint32_t>(VT_TYPE, 0));
- }
bool Verify(::flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_MIN_VAL) &&
verifier.VerifyVector(min_val()) &&
VerifyOffset(verifier, VT_MAX_VAL) &&
verifier.VerifyVector(max_val()) &&
- VerifyField<uint32_t>(verifier, VT_TYPE, 4) &&
verifier.EndTable();
}
};
@@ -1238,9 +1207,6 @@ struct ClampAttributeBuilder {
void add_max_val(::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> max_val) {
fbb_.AddOffset(ClampAttribute::VT_MAX_VAL, max_val);
}
- void add_type(tosa::DType type) {
- fbb_.AddElement<uint32_t>(ClampAttribute::VT_TYPE, static_cast<uint32_t>(type), 0);
- }
explicit ClampAttributeBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -1255,10 +1221,8 @@ struct ClampAttributeBuilder {
inline ::flatbuffers::Offset<ClampAttribute> CreateClampAttribute(
::flatbuffers::FlatBufferBuilder &_fbb,
::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> min_val = 0,
- ::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> max_val = 0,
- tosa::DType type = tosa::DType_UNKNOWN) {
+ ::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> max_val = 0) {
ClampAttributeBuilder builder_(_fbb);
- builder_.add_type(type);
builder_.add_max_val(max_val);
builder_.add_min_val(min_val);
return builder_.Finish();
@@ -1267,8 +1231,7 @@ inline ::flatbuffers::Offset<ClampAttribute> CreateClampAttribute(
inline ::flatbuffers::Offset<ClampAttribute> CreateClampAttributeDirect(
::flatbuffers::FlatBufferBuilder &_fbb,
const std::vector<uint8_t> *min_val = nullptr,
- const std::vector<uint8_t> *max_val = nullptr,
- tosa::DType type = tosa::DType_UNKNOWN) {
+ const std::vector<uint8_t> *max_val = nullptr) {
if (min_val) { _fbb.ForceVectorAlignment(min_val->size(), sizeof(uint8_t), 8); }
auto min_val__ = min_val ? _fbb.CreateVector<uint8_t>(*min_val) : 0;
if (max_val) { _fbb.ForceVectorAlignment(max_val->size(), sizeof(uint8_t), 8); }
@@ -1276,8 +1239,7 @@ inline ::flatbuffers::Offset<ClampAttribute> CreateClampAttributeDirect(
return tosa::CreateClampAttribute(
_fbb,
min_val__,
- max_val__,
- type);
+ max_val__);
}
struct RescaleAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h
index f5f9e58..c09a47d 100644
--- a/include/tosa_serialization_handler.h
+++ b/include/tosa_serialization_handler.h
@@ -16,9 +16,9 @@
#ifndef _TOSA_SERIALIZATION_HANDLER_H
#define _TOSA_SERIALIZATION_HANDLER_H
#include "attribute.h"
+#include "cfloat.h"
#include "flatbuffers/idl.h"
#include "flatbuffers/util.h"
-#include "float_utils.h"
#include "numpy_utils.h"
#include "tosa_generated.h"
#include <cstdint>
@@ -27,8 +27,8 @@
#include <vector>
// Keep version number in sync with the version default value with schema/tosa.fbs
-#define TOSA_VERSION_MAJOR 0
-#define TOSA_VERSION_MINOR 100
+#define TOSA_VERSION_MAJOR 1
+#define TOSA_VERSION_MINOR 1
#define TOSA_VERSION_PATCH 0
#define TOSA_VERSION_DRAFT true
#define TENSOR_BUFFER_FORCE_ALIGNMENT 8
@@ -412,9 +412,9 @@ public:
tosa_err_t LoadFileSchema(const char* schema_filename);
// data format conversion. little-endian.
- static tosa_err_t ConvertBF16toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
- static tosa_err_t ConvertFP8E4M3toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
- static tosa_err_t ConvertFP8E5M2toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
+ static tosa_err_t ConvertBF16toU8(const std::vector<bf16>& in, std::vector<uint8_t>& out);
+ static tosa_err_t ConvertFP8E4M3toU8(const std::vector<fp8e4m3>& in, std::vector<uint8_t>& out);
+ static tosa_err_t ConvertFP8E5M2toU8(const std::vector<fp8e5m2>& in, std::vector<uint8_t>& out);
static tosa_err_t ConvertF16toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
static tosa_err_t ConvertF32toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
static tosa_err_t ConvertI64toU8(const std::vector<int64_t>& in, std::vector<uint8_t>& out);
@@ -425,9 +425,9 @@ public:
static tosa_err_t ConvertI4toU8(const std::vector<int8_t>& in, std::vector<uint8_t>& out);
static tosa_err_t ConvertBooltoU8(const std::vector<bool>& in, std::vector<uint8_t>& out);
- static tosa_err_t ConvertU8toBF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out);
- static tosa_err_t ConvertU8toFP8E4M3(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out);
- static tosa_err_t ConvertU8toFP8E5M2(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out);
+ static tosa_err_t ConvertU8toBF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<bf16>& out);
+ static tosa_err_t ConvertU8toFP8E4M3(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<fp8e4m3>& out);
+ static tosa_err_t ConvertU8toFP8E5M2(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<fp8e5m2>& out);
static tosa_err_t
ConvertU8toF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<half_float::half>& out);
static tosa_err_t ConvertU8toF32(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out);
diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py
index 298907e..34178c5 100644
--- a/python/serializer/tosa_serializer.py
+++ b/python/serializer/tosa_serializer.py
@@ -17,7 +17,7 @@ import serializer.tosa_serializer as ts
import json
import flatbuffers
import numpy as np
-import struct
+from ml_dtypes import bfloat16, float8_e4m3fn, float8_e5m2
from enum import IntEnum, unique
from tosa import (
TosaGraph,
@@ -31,8 +31,8 @@ import tosa.DType as TosaDType
import tosa.Op as TosaOp
# Keep version number in sync with the version default value with schema/tosa.fbs
-TOSA_VERSION_MAJOR = 0
-TOSA_VERSION_MINOR = 100
+TOSA_VERSION_MAJOR = 1
+TOSA_VERSION_MINOR = 1
TOSA_VERSION_PATCH = 0
TOSA_VERSION_DRAFT = True
TOSA_VERSION = [
@@ -190,7 +190,7 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.ints.append((a.AddAccType, acc_type))
def TransposeConvAttribute(
- self, outpad, stride, output_shape, input_zp, weight_zp, local_bound, acc_type
+ self, outpad, stride, input_zp, weight_zp, local_bound, acc_type
):
from tosa import TransposeConvAttribute as a, Attribute
@@ -199,13 +199,12 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.intvecs.append((a.AddOutPad, outpad))
self.intvecs.append((a.AddStride, stride))
- self.intvecs.append((a.AddOutputShape, output_shape))
self.ints.append((a.AddInputZp, input_zp))
self.ints.append((a.AddWeightZp, weight_zp))
self.bools.append((a.AddLocalBound, local_bound))
self.ints.append((a.AddAccType, acc_type))
- def PadAttribute(self, serializer_builder, pad_const_val_as_bytes, dtype):
+ def PadAttribute(self, serializer_builder, pad_const_val_as_bytes):
from tosa import PadAttribute as a, Attribute
self.utype = Attribute.Attribute().PadAttribute
@@ -217,7 +216,6 @@ class TosaSerializerAttribute(TosaSerializerUnion):
)
self.floats.append((a.AddPadConst, serialized_pad_const_val))
- self.ints.append((a.AddType, dtype))
def AxisAttribute(self, axis):
from tosa import AxisAttribute as a, Attribute
@@ -238,9 +236,7 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.int16vecs.append((a.AddBorder, border))
self.ints.append((a.AddMode, mode))
- def ClampAttribute(
- self, serializer_builder, min_val_as_bytes, max_val_as_bytes, dtype
- ):
+ def ClampAttribute(self, serializer_builder, min_val_as_bytes, max_val_as_bytes):
from tosa import ClampAttribute as a, Attribute
self.utype = Attribute.Attribute().ClampAttribute
@@ -256,7 +252,6 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.floats.append((a.AddMinVal, serialized_min_val))
self.floats.append((a.AddMaxVal, serialized_max_val))
- self.ints.append((a.AddType, dtype))
def RescaleAttribute(
self,
@@ -397,13 +392,14 @@ class TosaSerializerTensor:
self.shape = shape
self.dtype = dtype
- if (
- dtype == DType.FP32
- or dtype == DType.BF16
- or dtype == DType.FP8E4M3
- or dtype == DType.FP8E5M2
- ):
+ if dtype == DType.FP32:
fntype = np.float32
+ elif dtype == DType.BF16:
+ fntype = bfloat16
+ elif dtype == DType.FP8E4M3:
+ fntype = float8_e4m3fn
+ elif dtype == DType.FP8E5M2:
+ fntype = float8_e5m2
elif dtype == DType.FP16:
fntype = np.float16
else:
@@ -948,35 +944,19 @@ class TosaSerializer:
np_arr = np.array(data, dtype=np.float16)
u8_data.extend(np_arr.view(np.uint8))
elif dtype == DType.FP32:
- # for val in data:
- # b = struct.pack("!f", val)
- # u8_data.extend([b[3], b[2], b[1], b[0]])
np_arr = np.array(data, dtype=np.float32)
u8_data.extend(np_arr.view(np.uint8))
elif dtype == DType.BF16:
- for val in data:
- # convert val to little endian byte arrays b
- b = struct.pack("<f", val)
- # val => [ b[3], b[2], b[1], b[0] ]
- # keep only most significant 2 bytes for bf16
- # in little endian ordering
- u8_data.extend([b[2], b[3]])
+ np_arr = np.array(data, dtype=bfloat16)
+ u8_data.extend(np_arr.view(np.uint8))
elif dtype == DType.FP8E4M3:
for val in data:
- # convert val to fp8_bits then to single byte
- f32_as_int = struct.unpack(">L", struct.pack(">f", val))[0]
- f32_bits = f"{f32_as_int:032b}"
- fp8_bits = f32_bits[0] + f32_bits[1:5] + f32_bits[9:12]
- fp8_bytes = int(fp8_bits, 2).to_bytes(1, byteorder="little")
- u8_data.extend(fp8_bytes)
+ val_f8 = np.array(val).astype(float8_e4m3fn).view(np.uint8)
+ u8_data.append(val_f8)
elif dtype == DType.FP8E5M2:
for val in data:
- # convert val to fp8_bits then to single byte
- f32_as_int = struct.unpack(">L", struct.pack(">f", val))[0]
- f32_bits = f"{f32_as_int:032b}"
- fp8_bits = f32_bits[0] + f32_bits[1:6] + f32_bits[9:11]
- fp8_bytes = int(fp8_bits, 2).to_bytes(1, byteorder="little")
- u8_data.extend(fp8_bytes)
+ val_f8 = np.array(val).astype(float8_e5m2).view(np.uint8)
+ u8_data.append(val_f8)
elif dtype == TosaDType.DType:
# Serialize DType enum data as uint8 bytes
for val in data:
diff --git a/python/tosa/ClampAttribute.py b/python/tosa/ClampAttribute.py
index 1189acb..40254ec 100644
--- a/python/tosa/ClampAttribute.py
+++ b/python/tosa/ClampAttribute.py
@@ -82,15 +82,8 @@ class ClampAttribute(object):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
return o == 0
- # ClampAttribute
- def Type(self):
- o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
- if o != 0:
- return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos)
- return 0
-
def ClampAttributeStart(builder):
- builder.StartObject(3)
+ builder.StartObject(2)
def Start(builder):
ClampAttributeStart(builder)
@@ -104,7 +97,7 @@ def AddMinVal(builder, minVal):
def ClampAttributeStartMinValVector(builder, numElems):
return builder.StartVector(1, numElems, 1)
-def StartMinValVector(builder, numElems: int) -> int:
+def StartMinValVector(builder, numElems):
return ClampAttributeStartMinValVector(builder, numElems)
def ClampAttributeAddMaxVal(builder, maxVal):
@@ -116,15 +109,9 @@ def AddMaxVal(builder, maxVal):
def ClampAttributeStartMaxValVector(builder, numElems):
return builder.StartVector(1, numElems, 1)
-def StartMaxValVector(builder, numElems: int) -> int:
+def StartMaxValVector(builder, numElems):
return ClampAttributeStartMaxValVector(builder, numElems)
-def ClampAttributeAddType(builder, type):
- builder.PrependUint32Slot(2, type, 0)
-
-def AddType(builder, type):
- ClampAttributeAddType(builder, type)
-
def ClampAttributeEnd(builder):
return builder.EndObject()
diff --git a/python/tosa/ConvAttribute.py b/python/tosa/ConvAttribute.py
index dfa75dc..1deca59 100644
--- a/python/tosa/ConvAttribute.py
+++ b/python/tosa/ConvAttribute.py
@@ -152,7 +152,7 @@ def AddPad(builder, pad):
def ConvAttributeStartPadVector(builder, numElems):
return builder.StartVector(4, numElems, 4)
-def StartPadVector(builder, numElems: int) -> int:
+def StartPadVector(builder, numElems):
return ConvAttributeStartPadVector(builder, numElems)
def ConvAttributeAddStride(builder, stride):
@@ -164,7 +164,7 @@ def AddStride(builder, stride):
def ConvAttributeStartStrideVector(builder, numElems):
return builder.StartVector(4, numElems, 4)
-def StartStrideVector(builder, numElems: int) -> int:
+def StartStrideVector(builder, numElems):
return ConvAttributeStartStrideVector(builder, numElems)
def ConvAttributeAddDilation(builder, dilation):
@@ -176,7 +176,7 @@ def AddDilation(builder, dilation):
def ConvAttributeStartDilationVector(builder, numElems):
return builder.StartVector(4, numElems, 4)
-def StartDilationVector(builder, numElems: int) -> int:
+def StartDilationVector(builder, numElems):
return ConvAttributeStartDilationVector(builder, numElems)
def ConvAttributeAddInputZp(builder, inputZp):
diff --git a/python/tosa/CustomAttribute.py b/python/tosa/CustomAttribute.py
index db35dca..4c1c477 100644
--- a/python/tosa/CustomAttribute.py
+++ b/python/tosa/CustomAttribute.py
@@ -96,7 +96,7 @@ def AddImplementationAttrs(builder, implementationAttrs):
def CustomAttributeStartImplementationAttrsVector(builder, numElems):
return builder.StartVector(1, numElems, 1)
-def StartImplementationAttrsVector(builder, numElems: int) -> int:
+def StartImplementationAttrsVector(builder, numElems):
return CustomAttributeStartImplementationAttrsVector(builder, numElems)
def CustomAttributeEnd(builder):
diff --git a/python/tosa/PadAttribute.py b/python/tosa/PadAttribute.py
index c4084dc..8adf9f7 100644
--- a/python/tosa/PadAttribute.py
+++ b/python/tosa/PadAttribute.py
@@ -55,15 +55,8 @@ class PadAttribute(object):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
return o == 0
- # PadAttribute
- def Type(self):
- o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
- if o != 0:
- return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos)
- return 0
-
def PadAttributeStart(builder):
- builder.StartObject(2)
+ builder.StartObject(1)
def Start(builder):
PadAttributeStart(builder)
@@ -77,15 +70,9 @@ def AddPadConst(builder, padConst):
def PadAttributeStartPadConstVector(builder, numElems):
return builder.StartVector(1, numElems, 1)
-def StartPadConstVector(builder, numElems: int) -> int:
+def StartPadConstVector(builder, numElems):
return PadAttributeStartPadConstVector(builder, numElems)
-def PadAttributeAddType(builder, type):
- builder.PrependUint32Slot(1, type, 0)
-
-def AddType(builder, type):
- PadAttributeAddType(builder, type)
-
def PadAttributeEnd(builder):
return builder.EndObject()
diff --git a/python/tosa/PoolAttribute.py b/python/tosa/PoolAttribute.py
index c13e038..831d43b 100644
--- a/python/tosa/PoolAttribute.py
+++ b/python/tosa/PoolAttribute.py
@@ -145,7 +145,7 @@ def AddPad(builder, pad):
def PoolAttributeStartPadVector(builder, numElems):
return builder.StartVector(4, numElems, 4)
-def StartPadVector(builder, numElems: int) -> int:
+def StartPadVector(builder, numElems):
return PoolAttributeStartPadVector(builder, numElems)
def PoolAttributeAddKernel(builder, kernel):
@@ -157,7 +157,7 @@ def AddKernel(builder, kernel):
def PoolAttributeStartKernelVector(builder, numElems):
return builder.StartVector(4, numElems, 4)
-def StartKernelVector(builder, numElems: int) -> int:
+def StartKernelVector(builder, numElems):
return PoolAttributeStartKernelVector(builder, numElems)
def PoolAttributeAddStride(builder, stride):
@@ -169,7 +169,7 @@ def AddStride(builder, stride):
def PoolAttributeStartStrideVector(builder, numElems):
return builder.StartVector(4, numElems, 4)
-def StartStrideVector(builder, numElems: int) -> int:
+def StartStrideVector(builder, numElems):
return PoolAttributeStartStrideVector(builder, numElems)
def PoolAttributeAddInputZp(builder, inputZp):
diff --git a/python/tosa/ResizeAttribute.py b/python/tosa/ResizeAttribute.py
index 96bfa56..44f7d31 100644
--- a/python/tosa/ResizeAttribute.py
+++ b/python/tosa/ResizeAttribute.py
@@ -131,7 +131,7 @@ def AddScale(builder, scale):
def ResizeAttributeStartScaleVector(builder, numElems):
return builder.StartVector(2, numElems, 2)
-def StartScaleVector(builder, numElems: int) -> int:
+def StartScaleVector(builder, numElems):
return ResizeAttributeStartScaleVector(builder, numElems)
def ResizeAttributeAddOffset(builder, offset):
@@ -143,7 +143,7 @@ def AddOffset(builder, offset):
def ResizeAttributeStartOffsetVector(builder, numElems):
return builder.StartVector(2, numElems, 2)
-def StartOffsetVector(builder, numElems: int) -> int:
+def StartOffsetVector(builder, numElems):
return ResizeAttributeStartOffsetVector(builder, numElems)
def ResizeAttributeAddBorder(builder, border):
@@ -155,7 +155,7 @@ def AddBorder(builder, border):
def ResizeAttributeStartBorderVector(builder, numElems):
return builder.StartVector(2, numElems, 2)
-def StartBorderVector(builder, numElems: int) -> int:
+def StartBorderVector(builder, numElems):
return ResizeAttributeStartBorderVector(builder, numElems)
def ResizeAttributeAddMode(builder, mode):
diff --git a/python/tosa/TableAttribute.py b/python/tosa/TableAttribute.py
index 6caa1f2..04193fa 100644
--- a/python/tosa/TableAttribute.py
+++ b/python/tosa/TableAttribute.py
@@ -70,7 +70,7 @@ def AddTable(builder, table):
def TableAttributeStartTableVector(builder, numElems):
return builder.StartVector(2, numElems, 2)
-def StartTableVector(builder, numElems: int) -> int:
+def StartTableVector(builder, numElems):
return TableAttributeStartTableVector(builder, numElems)
def TableAttributeEnd(builder):
diff --git a/python/tosa/TosaBasicBlock.py b/python/tosa/TosaBasicBlock.py
index b31f455..30ad0ee 100644
--- a/python/tosa/TosaBasicBlock.py
+++ b/python/tosa/TosaBasicBlock.py
@@ -146,7 +146,7 @@ def AddOperators(builder, operators):
def TosaBasicBlockStartOperatorsVector(builder, numElems):
return builder.StartVector(4, numElems, 4)
-def StartOperatorsVector(builder, numElems: int) -> int:
+def StartOperatorsVector(builder, numElems):
return TosaBasicBlockStartOperatorsVector(builder, numElems)
def TosaBasicBlockAddTensors(builder, tensors):
@@ -158,7 +158,7 @@ def AddTensors(builder, tensors):
def TosaBasicBlockStartTensorsVector(builder, numElems):
return builder.StartVector(4, numElems, 4)
-def StartTensorsVector(builder, numElems: int) -> int:
+def StartTensorsVector(builder, numElems):
return TosaBasicBlockStartTensorsVector(builder, numElems)
def TosaBasicBlockAddInputs(builder, inputs):
@@ -170,7 +170,7 @@ def AddInputs(builder, inputs):
def TosaBasicBlockStartInputsVector(builder, numElems):
return builder.StartVector(4, numElems, 4)
-def StartInputsVector(builder, numElems: int) -> int:
+def StartInputsVector(builder, numElems):
return TosaBasicBlockStartInputsVector(builder, numElems)
def TosaBasicBlockAddOutputs(builder, outputs):
@@ -182,7 +182,7 @@ def AddOutputs(builder, outputs):
def TosaBasicBlockStartOutputsVector(builder, numElems):
return builder.StartVector(4, numElems, 4)
-def StartOutputsVector(builder, numElems: int) -> int:
+def StartOutputsVector(builder, numElems):
return TosaBasicBlockStartOutputsVector(builder, numElems)
def TosaBasicBlockEnd(builder):
diff --git a/python/tosa/TosaGraph.py b/python/tosa/TosaGraph.py
index 84b51a7..520372b 100644
--- a/python/tosa/TosaGraph.py
+++ b/python/tosa/TosaGraph.py
@@ -85,7 +85,7 @@ def AddRegions(builder, regions):
def TosaGraphStartRegionsVector(builder, numElems):
return builder.StartVector(4, numElems, 4)
-def StartRegionsVector(builder, numElems: int) -> int:
+def StartRegionsVector(builder, numElems):
return TosaGraphStartRegionsVector(builder, numElems)
def TosaGraphEnd(builder):
diff --git a/python/tosa/TosaOperator.py b/python/tosa/TosaOperator.py
index 2b889ad..19f2d2c 100644
--- a/python/tosa/TosaOperator.py
+++ b/python/tosa/TosaOperator.py
@@ -125,7 +125,7 @@ def AddInputs(builder, inputs):
def TosaOperatorStartInputsVector(builder, numElems):
return builder.StartVector(4, numElems, 4)
-def StartInputsVector(builder, numElems: int) -> int:
+def StartInputsVector(builder, numElems):
return TosaOperatorStartInputsVector(builder, numElems)
def TosaOperatorAddOutputs(builder, outputs):
@@ -137,7 +137,7 @@ def AddOutputs(builder, outputs):
def TosaOperatorStartOutputsVector(builder, numElems):
return builder.StartVector(4, numElems, 4)
-def StartOutputsVector(builder, numElems: int) -> int:
+def StartOutputsVector(builder, numElems):
return TosaOperatorStartOutputsVector(builder, numElems)
def TosaOperatorEnd(builder):
diff --git a/python/tosa/TosaRegion.py b/python/tosa/TosaRegion.py
index 7fd6e3c..80829da 100644
--- a/python/tosa/TosaRegion.py
+++ b/python/tosa/TosaRegion.py
@@ -81,7 +81,7 @@ def AddBlocks(builder, blocks):
def TosaRegionStartBlocksVector(builder, numElems):
return builder.StartVector(4, numElems, 4)
-def StartBlocksVector(builder, numElems: int) -> int:
+def StartBlocksVector(builder, numElems):
return TosaRegionStartBlocksVector(builder, numElems)
def TosaRegionEnd(builder):
diff --git a/python/tosa/TosaTensor.py b/python/tosa/TosaTensor.py
index 3fb9f86..1311aac 100644
--- a/python/tosa/TosaTensor.py
+++ b/python/tosa/TosaTensor.py
@@ -138,7 +138,7 @@ def AddShape(builder, shape):
def TosaTensorStartShapeVector(builder, numElems):
return builder.StartVector(4, numElems, 4)
-def StartShapeVector(builder, numElems: int) -> int:
+def StartShapeVector(builder, numElems):
return TosaTensorStartShapeVector(builder, numElems)
def TosaTensorAddType(builder, type):
@@ -156,7 +156,7 @@ def AddData(builder, data):
def TosaTensorStartDataVector(builder, numElems):
return builder.StartVector(1, numElems, 1)
-def StartDataVector(builder, numElems: int) -> int:
+def StartDataVector(builder, numElems):
return TosaTensorStartDataVector(builder, numElems)
def TosaTensorAddVariable(builder, variable):
diff --git a/python/tosa/TransposeAttribute.py b/python/tosa/TransposeAttribute.py
index 71cfdf0..5aa23e2 100644
--- a/python/tosa/TransposeAttribute.py
+++ b/python/tosa/TransposeAttribute.py
@@ -70,7 +70,7 @@ def AddPerms(builder, perms):
def TransposeAttributeStartPermsVector(builder, numElems):
return builder.StartVector(4, numElems, 4)
-def StartPermsVector(builder, numElems: int) -> int:
+def StartPermsVector(builder, numElems):
return TransposeAttributeStartPermsVector(builder, numElems)
def TransposeAttributeEnd(builder):
diff --git a/python/tosa/TransposeConvAttribute.py b/python/tosa/TransposeConvAttribute.py
index e5397a8..2f7cdc7 100644
--- a/python/tosa/TransposeConvAttribute.py
+++ b/python/tosa/TransposeConvAttribute.py
@@ -83,62 +83,35 @@ class TransposeConvAttribute(object):
return o == 0
# TransposeConvAttribute
- def OutputShape(self, j):
- o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
- if o != 0:
- a = self._tab.Vector(o)
- return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
- return 0
-
- # TransposeConvAttribute
- def OutputShapeAsNumpy(self):
- o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
- if o != 0:
- return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
- return 0
-
- # TransposeConvAttribute
- def OutputShapeLength(self):
- o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
- if o != 0:
- return self._tab.VectorLen(o)
- return 0
-
- # TransposeConvAttribute
- def OutputShapeIsNone(self):
- o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
- return o == 0
-
- # TransposeConvAttribute
def InputZp(self):
- o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
if o != 0:
return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
return 0
# TransposeConvAttribute
def WeightZp(self):
- o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
if o != 0:
return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
return 0
# TransposeConvAttribute
def LocalBound(self):
- o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
if o != 0:
return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
return False
# TransposeConvAttribute
def AccType(self):
- o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16))
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
if o != 0:
return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos)
return 0
def TransposeConvAttributeStart(builder):
- builder.StartObject(7)
+ builder.StartObject(6)
def Start(builder):
TransposeConvAttributeStart(builder)
@@ -152,7 +125,7 @@ def AddOutPad(builder, outPad):
def TransposeConvAttributeStartOutPadVector(builder, numElems):
return builder.StartVector(4, numElems, 4)
-def StartOutPadVector(builder, numElems: int) -> int:
+def StartOutPadVector(builder, numElems):
return TransposeConvAttributeStartOutPadVector(builder, numElems)
def TransposeConvAttributeAddStride(builder, stride):
@@ -164,41 +137,29 @@ def AddStride(builder, stride):
def TransposeConvAttributeStartStrideVector(builder, numElems):
return builder.StartVector(4, numElems, 4)
-def StartStrideVector(builder, numElems: int) -> int:
+def StartStrideVector(builder, numElems):
return TransposeConvAttributeStartStrideVector(builder, numElems)
-def TransposeConvAttributeAddOutputShape(builder, outputShape):
- builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(outputShape), 0)
-
-def AddOutputShape(builder, outputShape):
- TransposeConvAttributeAddOutputShape(builder, outputShape)
-
-def TransposeConvAttributeStartOutputShapeVector(builder, numElems):
- return builder.StartVector(4, numElems, 4)
-
-def StartOutputShapeVector(builder, numElems: int) -> int:
- return TransposeConvAttributeStartOutputShapeVector(builder, numElems)
-
def TransposeConvAttributeAddInputZp(builder, inputZp):
- builder.PrependInt32Slot(3, inputZp, 0)
+ builder.PrependInt32Slot(2, inputZp, 0)
def AddInputZp(builder, inputZp):
TransposeConvAttributeAddInputZp(builder, inputZp)
def TransposeConvAttributeAddWeightZp(builder, weightZp):
- builder.PrependInt32Slot(4, weightZp, 0)
+ builder.PrependInt32Slot(3, weightZp, 0)
def AddWeightZp(builder, weightZp):
TransposeConvAttributeAddWeightZp(builder, weightZp)
def TransposeConvAttributeAddLocalBound(builder, localBound):
- builder.PrependBoolSlot(5, localBound, 0)
+ builder.PrependBoolSlot(4, localBound, 0)
def AddLocalBound(builder, localBound):
TransposeConvAttributeAddLocalBound(builder, localBound)
def TransposeConvAttributeAddAccType(builder, accType):
- builder.PrependUint32Slot(6, accType, 0)
+ builder.PrependUint32Slot(5, accType, 0)
def AddAccType(builder, accType):
TransposeConvAttributeAddAccType(builder, accType)
diff --git a/schema/tosa.fbs b/schema/tosa.fbs
index 7b5948b..cad6db7 100644
--- a/schema/tosa.fbs
+++ b/schema/tosa.fbs
@@ -176,7 +176,6 @@ table ConvAttribute {
table TransposeConvAttribute {
out_pad: [int32];
stride: [int32];
- output_shape: [int32];
input_zp: int32;
weight_zp: int32;
local_bound: bool;
@@ -185,7 +184,6 @@ table TransposeConvAttribute {
table PadAttribute {
pad_const: [ubyte] (force_align: 8);
- type: DType;
}
table AxisAttribute {
@@ -202,7 +200,6 @@ table ResizeAttribute {
table ClampAttribute {
min_val: [ubyte] (force_align: 8);
max_val: [ubyte] (force_align: 8);
- type: DType;
}
table RescaleAttribute {
diff --git a/src/numpy_utils.cpp b/src/numpy_utils.cpp
index e4171d7..7cf5f94 100644
--- a/src/numpy_utils.cpp
+++ b/src/numpy_utils.cpp
@@ -247,6 +247,14 @@ NumpyUtilities::NPError NumpyUtilities::checkNpyHeader(FILE* infile, const uint3
while (isspace(*ptr))
ptr++;
+ // ml_dtypes writes '<f1' for 'numpy.dtype' in the header for float8_e5m2, but
+ // default NumPy does not understand this notation, which causes trouble
+ // when other code tries to open this file.
+ // To avoid this, '|u1' notation is used when the file is written, and the uint8
+ // data is viewed as float8_e5m2 later when the file is read.
+ if (!strcmp(dtype_str, "'<f1'"))
+ dtype_str = "'|u1'";
+
if (strcmp(ptr, dtype_str))
{
return FILE_TYPE_MISMATCH;
@@ -430,6 +438,13 @@ NumpyUtilities::NPError
memcpy(header, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1);
headerPos += sizeof(NUMPY_HEADER_STR) - 1;
+ // NumPy does not understand float8_e5m2, so change it to uint8 type, so that
+ // Python can read .npy files.
+ if (!strcmp(dtype_str, "'<f1'"))
+ {
+ dtype_str = "'|u1'";
+ }
+
// Output the format dictionary
// Hard-coded for I32 for now
headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos,
@@ -438,7 +453,19 @@ NumpyUtilities::NPError
// Add shape contents (if any - as this will be empty for rank 0)
for (i = 0; i < shape.size(); i++)
{
- headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, " %d,", shape[i]);
+ // Output NumPy file from tosa_refmodel_sut_run generates the shape information
+ // without a trailing comma when the rank is greater than 1.
+ if (i == 0)
+ {
+ if (shape.size() == 1)
+ headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "%d,", shape[i]);
+ else
+ headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "%d", shape[i]);
+ }
+ else
+ {
+ headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, ", %d", shape[i]);
+ }
}
// Close off the dictionary
diff --git a/src/tosa_serialization_handler.cpp b/src/tosa_serialization_handler.cpp
index 85625cd..74f66d8 100644
--- a/src/tosa_serialization_handler.cpp
+++ b/src/tosa_serialization_handler.cpp
@@ -19,9 +19,6 @@
#include <iostream>
using namespace tosa;
-using fp8e4m3 = tosa::float_t<int8_t, 4, true, true, false>;
-using fp8e5m2 = tosa::float_t<int8_t, 5, true, true, true>;
-
TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name,
const flatbuffers::Vector<int32_t>* shape,
DType dtype,
@@ -750,45 +747,41 @@ void TosaSerializationHandler::ForceAlignTensorData(std::vector<uint8_t>& buf)
}
}
-tosa_err_t TosaSerializationHandler::ConvertBF16toU8(const std::vector<float>& in, std::vector<uint8_t>& out)
+tosa_err_t TosaSerializationHandler::ConvertBF16toU8(const std::vector<bf16>& in, std::vector<uint8_t>& out)
{
// Note: Converts fp32->bf16 by ignoring the least significant 16 bits
out.clear();
for (auto val : in)
{
- uint32_t* val_u32 = reinterpret_cast<uint32_t*>(&val);
- uint8_t f32_byte2 = (*val_u32 >> 16) & 0xFF;
- uint8_t f32_byte3 = (*val_u32 >> 24) & 0xFF;
- // little endian: byte2 followed by byte3
- out.push_back(f32_byte2);
- out.push_back(f32_byte3);
+ uint8_t bf16_byte0 = val.bits() & 0xFF;
+ uint8_t bf16_byte1 = (val.bits() >> 8) & 0xFF;
+ out.push_back(bf16_byte0);
+ out.push_back(bf16_byte1);
}
ForceAlignTensorData(out);
return TOSA_OK;
}
-tosa_err_t TosaSerializationHandler::ConvertFP8E4M3toU8(const std::vector<float>& in, std::vector<uint8_t>& out)
+tosa_err_t TosaSerializationHandler::ConvertFP8E4M3toU8(const std::vector<fp8e4m3>& in, std::vector<uint8_t>& out)
{
// Note: Converts fp32->FP8E4M3 before converting to unint8_t
out.clear();
for (auto val : in)
{
- auto f8 = static_cast<fp8e4m3>(val);
- uint8_t b8 = f8.bits();
+ uint8_t b8 = val.bits();
out.push_back(b8);
}
ForceAlignTensorData(out);
return TOSA_OK;
}
-tosa_err_t TosaSerializationHandler::ConvertFP8E5M2toU8(const std::vector<float>& in, std::vector<uint8_t>& out)
+tosa_err_t TosaSerializationHandler::ConvertFP8E5M2toU8(const std::vector<fp8e5m2>& in, std::vector<uint8_t>& out)
{
// Note: Converts fp32->FP8E5M2 before converting to uint8_t
out.clear();
for (auto val : in)
{
- auto f8 = static_cast<fp8e5m2>(val);
- uint8_t b8 = f8.bits();
+ uint8_t b8 = val.bits();
out.push_back(b8);
}
ForceAlignTensorData(out);
@@ -944,11 +937,9 @@ tosa_err_t TosaSerializationHandler::ConvertBooltoU8(const std::vector<bool>& in
return TOSA_OK;
}
-tosa_err_t TosaSerializationHandler::ConvertU8toBF16(const std::vector<uint8_t>& in,
- uint32_t out_size,
- std::vector<float>& out)
+tosa_err_t
+ TosaSerializationHandler::ConvertU8toBF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<bf16>& out)
{
- // Note: bf16 values returned in fp32 type
out.clear();
if (in.size() < out_size * sizeof(int16_t))
{
@@ -959,22 +950,21 @@ tosa_err_t TosaSerializationHandler::ConvertU8toBF16(const std::vector<uint8_t>&
for (uint32_t i = 0; i < out_size; i++)
{
- uint32_t f32_byte2 = in[i * sizeof(int16_t)];
- uint32_t f32_byte3 = in[i * sizeof(int16_t) + 1];
- uint32_t val_u32 = (f32_byte2 << 16) + (f32_byte3 << 24);
+ uint8_t bf16_byte0 = in[i * sizeof(int16_t)];
+ uint8_t bf16_byte1 = in[i * sizeof(int16_t) + 1];
+ uint16_t val_u16 = (bf16_byte0) + (bf16_byte1 << 8);
- // Reinterpret u32 bytes as fp32
- float val_f32 = *(float*)&val_u32;
- out.push_back(val_f32);
+ // Reinterpret u16 bytes as bf16
+ bf16 val_bf16 = *(bf16*)&val_u16;
+ out.push_back(val_bf16);
}
return TOSA_OK;
}
tosa_err_t TosaSerializationHandler::ConvertU8toFP8E4M3(const std::vector<uint8_t>& in,
uint32_t out_size,
- std::vector<float>& out)
+ std::vector<fp8e4m3>& out)
{
- // Note: FP8E4M3 values returned in fp32 type
out.clear();
if (in.size() < out_size * sizeof(int8_t))
{
@@ -985,17 +975,16 @@ tosa_err_t TosaSerializationHandler::ConvertU8toFP8E4M3(const std::vector<uint8_
for (uint32_t i = 0; i < out_size; i++)
{
- int8_t bits = static_cast<int8_t>(in[i * sizeof(int8_t)]);
- auto f8 = fp8e4m3::from_bits(bits);
- float val_f32 = static_cast<float>(f8);
- out.push_back(val_f32);
+ int8_t bits = static_cast<int8_t>(in[i * sizeof(int8_t)]);
+ auto f8 = fp8e4m3::from_bits(bits);
+ out.push_back(f8);
}
return TOSA_OK;
}
tosa_err_t TosaSerializationHandler::ConvertU8toFP8E5M2(const std::vector<uint8_t>& in,
uint32_t out_size,
- std::vector<float>& out)
+ std::vector<fp8e5m2>& out)
{
// Note: FP8E5M2 values returned in fp32 type
out.clear();
@@ -1008,10 +997,9 @@ tosa_err_t TosaSerializationHandler::ConvertU8toFP8E5M2(const std::vector<uint8_
for (uint32_t i = 0; i < out_size; i++)
{
- int8_t bits = static_cast<int8_t>(in[i * sizeof(int8_t)]);
- auto f8 = fp8e5m2::from_bits(bits);
- float val_f32 = static_cast<float>(f8);
- out.push_back(val_f32);
+ int8_t bits = static_cast<int8_t>(in[i * sizeof(int8_t)]);
+ auto f8 = fp8e5m2::from_bits(bits);
+ out.push_back(f8);
}
return TOSA_OK;
}
@@ -1031,9 +1019,9 @@ tosa_err_t TosaSerializationHandler::ConvertU8toF16(const std::vector<uint8_t>&
for (uint32_t i = 0; i < out_size; i++)
{
- uint16_t f16_byte0 = in[i * sizeof(int16_t)];
- uint16_t f16_byte1 = in[i * sizeof(int16_t) + 1];
- uint16_t val_u16 = f16_byte0 + (f16_byte1 << 8);
+ uint8_t f16_byte0 = in[i * sizeof(int16_t)];
+ uint8_t f16_byte1 = in[i * sizeof(int16_t) + 1];
+ uint16_t val_u16 = f16_byte0 + (f16_byte1 << 8);
// Reinterpret u16 byte as fp16 then convert to fp32
half_float::half val_f16 = *(half_float::half*)&val_u16;
diff --git a/third_party/flatbuffers b/third_party/flatbuffers
-Subproject 0100f6a5779831fa7a651e4b67ef389a8752bd9
+Subproject 6ff9e90e7e399f3977e99a315856b57c8afe5b4