aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2024-03-21 17:01:14 +0000
committerTai Ly <tai.ly@arm.com>2024-04-08 22:18:34 +0000
commitce911a2f1d9cd678fb9fe82a40c86ad0c6772f5a (patch)
tree68faf6d7b1c676705a022b32d8aa7950db03ab5e
parent8f9e2842ce7d25645233ad4f6fa406be982346ae (diff)
downloadserialization_lib-ce911a2f1d9cd678fb9fe82a40c86ad0c6772f5a.tar.gz
Add conversions of U8 to/from BF16 and FP8
Adds type to PadAttribute and ClampAttribute so their pad_const and max_val/min_val can be deserialized according to type Adds conversion functions of U8 arrays to/from BF16/FP8 values Also, refactor and expose TosaSerializer.convertDataToUint8Vec for converting dtype/data to uint8 list for serialization And modify convertDataToUint8Vec to serialize bf16 values into 2 bytes each, and serialize fp8 values into single bytes each. Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: I05659e8187c76d359f1cc9f71c8c23cafd0e877f
-rw-r--r--CMakeLists.txt5
-rw-r--r--include/attribute.def10
-rw-r--r--include/float_utils.h533
-rw-r--r--include/tosa_generated.h40
-rw-r--r--include/tosa_serialization_handler.h7
-rw-r--r--python/serializer/tosa_serializer.py193
-rw-r--r--python/tosa/ClampAttribute.py15
-rw-r--r--python/tosa/PadAttribute.py15
-rw-r--r--schema/tosa.fbs2
-rw-r--r--src/tosa_serialization_handler.cpp120
10 files changed, 841 insertions, 99 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index ac34b75..5f4f851 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -19,8 +19,8 @@
cmake_minimum_required(VERSION 3.13.4)
project(TosaSerialization)
-set(CMAKE_CXX_STANDARD 14 CACHE STRING "C++ standard to conform to")
-set(CMAKE_CXX_STANDARD_REQUIRED YES)
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_VERBOSE_MAKEFILE ON)
@@ -76,6 +76,7 @@ set(public_headers)
list(APPEND public_headers
include/attribute.h
include/attribute.def
+ include/float_utils.h
include/numpy_utils.h
include/tosa_generated.h
include/tosa_serialization_handler.h
diff --git a/include/attribute.def b/include/attribute.def
index 723543e..30b432d 100644
--- a/include/attribute.def
+++ b/include/attribute.def
@@ -52,8 +52,9 @@ DEF_ATTRIBUTE(TransposeConv, 7,
bool, S, local_bound,
DType, S, acc_type)
-DEF_ATTRIBUTE(Pad, 1,
- uint8_t, V, pad_const)
+DEF_ATTRIBUTE(Pad, 2,
+ uint8_t, V, pad_const,
+ DType, S, type)
DEF_ATTRIBUTE(Axis, 1,
int32_t, S, axis)
@@ -64,9 +65,10 @@ DEF_ATTRIBUTE(Resize, 4,
int16_t, V, border,
ResizeMode, S, mode)
-DEF_ATTRIBUTE(Clamp, 2,
+DEF_ATTRIBUTE(Clamp, 3,
uint8_t, V, min_val,
- uint8_t, V, max_val)
+ uint8_t, V, max_val,
+ DType, S, type)
DEF_ATTRIBUTE(Rescale, 7,
int32_t, S, input_zp,
diff --git a/include/float_utils.h b/include/float_utils.h
new file mode 100644
index 0000000..831ad74
--- /dev/null
+++ b/include/float_utils.h
@@ -0,0 +1,533 @@
+// Copyright (c) 2024, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TOSA_FLOAT_UTILS_H_
+#define TOSA_FLOAT_UTILS_H_
+
+#include <algorithm>
+#include <cstdint>
+#include <limits>
+#include <type_traits>
+#if defined(__cpp_lib_bit_cast)
+#include <bit>
+#endif // defined(__cpp_lib_bit_cast)
+
+namespace tosa
+{
+
+namespace float_support
+{
+
+struct hidden
+{};
+
+#if defined(__cpp_lib_bit_cast)
+#define BITCAST_CONSTEXPR constexpr inline
+
+constexpr inline int32_t get_bits(const float& f)
+{
+ return std::bit_cast<int32_t>(f);
+}
+constexpr inline float from_bits(const int32_t& i)
+{
+ return std::bit_cast<float>(i);
+}
+
+#else
+#define BITCAST_CONSTEXPR inline
+
+union ufloat32
+{
+ constexpr ufloat32(const float& x)
+ : f(x)
+ {}
+ constexpr ufloat32(const int32_t& x)
+ : i(x)
+ {}
+
+ float f;
+ int32_t i;
+};
+
+inline int32_t get_bits(const float& f)
+{
+ return ufloat32(f).i;
+}
+inline float from_bits(const int32_t& i)
+{
+ return ufloat32(i).f;
+}
+#endif
+
+} // namespace float_support
+
+template <typename storage_t,
+ size_t n_exp_bits,
+ bool has_nan,
+ bool with_denorm,
+ bool with_infinity,
+ std::enable_if_t<(n_exp_bits + 1 < sizeof(storage_t) * 8), bool> = true>
+class float_t
+{
+ storage_t m_data = 0;
+
+public:
+ static constexpr size_t n_exponent_bits = n_exp_bits;
+ static constexpr size_t n_significand_bits = sizeof(storage_t) * 8 - 1 - n_exp_bits;
+ static constexpr int64_t exponent_bias = (1 << (n_exp_bits - 1)) - 1;
+
+ /// \brief Construct a floating point type with the given bit
+ /// representation.
+ static constexpr float_t from_bits(storage_t bits)
+ {
+ return float_t(float_support::hidden(), bits);
+ }
+
+ /// \brief Construct a float from the given sign, exponent and significand
+ /// bits.
+ static constexpr float_t from_bits(bool pm, storage_t e, storage_t s)
+ {
+ storage_t bits = pm ? 1 : 0;
+
+ bits <<= n_exp_bits;
+ bits |= e;
+
+ bits <<= n_significand_bits;
+ if (with_denorm || e)
+ bits |= s;
+
+ return float_t(float_support::hidden(), bits);
+ }
+
+ /// \brief (Hidden) Construct a float type from a given bit pattern
+ constexpr float_t(const float_support::hidden&, storage_t bits)
+ : m_data(bits)
+ {}
+
+ constexpr float_t()
+ : m_data(0)
+ {}
+ constexpr float_t(const float_t& other)
+ : m_data(other.m_data)
+ {}
+
+ /// \brief Cast to a different floating point representation.
+ template <typename other_storage_t,
+ size_t other_n_exp_bits,
+ bool other_has_nan,
+ bool other_has_denorm,
+ bool other_has_infinity>
+ constexpr inline
+ operator float_t<other_storage_t, other_n_exp_bits, other_has_nan, other_has_denorm, other_has_infinity>() const
+ {
+ using other_float_t =
+ float_t<other_storage_t, other_n_exp_bits, other_has_nan, other_has_denorm, other_has_infinity>;
+
+ // Shortcut for types which are fundamentally similar (e.g., bf16 ->
+ // fp32)
+ if constexpr (n_exp_bits == other_n_exp_bits && sizeof(other_storage_t) >= sizeof(storage_t) &&
+ has_nan == other_has_nan)
+ {
+ return other_float_t::from_bits(static_cast<other_storage_t>(m_data)
+ << (sizeof(other_storage_t) - sizeof(storage_t)) * 8);
+ }
+
+ // Get initial values for the new floating point type
+ const bool sign_bit = m_data < 0;
+ int64_t new_exponent_bits = 0;
+ uint64_t new_significand = 0;
+
+ if (is_nan() || is_infinity())
+ {
+ new_exponent_bits = (1 << other_n_exp_bits) - 1;
+
+ if (is_nan())
+ {
+ if constexpr (other_has_infinity)
+ {
+ // Copy across the `not_quiet bit`; set the LSB. Don't
+ // attempt to copy across any of the rest of the payload.
+ new_significand =
+ 0x1 | (((significand() >> (n_significand_bits - 1)) & 1) << other_float_t::n_significand_bits);
+ }
+ else
+ {
+ new_significand = (1ul << other_float_t::n_significand_bits) - 1;
+ }
+ }
+ else if constexpr (!other_has_infinity)
+ {
+ new_significand = (1ul << other_float_t::n_significand_bits) - (other_has_nan ? 2 : 1);
+ }
+ }
+ else if (!is_zero())
+ {
+ const int64_t this_exponent_bits = exponent_bits();
+ {
+ constexpr int64_t exponent_rebias = other_float_t::exponent_bias - exponent_bias;
+ new_exponent_bits = std::max(this_exponent_bits + exponent_rebias, exponent_rebias + 1);
+ }
+ new_significand = this->significand() << (64 - n_significand_bits);
+
+ // Normalise subnormals
+ if (this_exponent_bits == 0)
+ {
+ // Shift the most-significant 1 out of the magnitude to convert
+ // it to a significand. Modify the exponent accordingly.
+ uint8_t shift = __builtin_clzl(new_significand) + 1;
+ new_exponent_bits -= shift;
+ new_significand <<= shift;
+ }
+
+ // Align the significand for the output type
+ uint32_t shift = 64 - other_float_t::n_significand_bits;
+ const bool other_is_subnormal = new_exponent_bits <= 0;
+ if (other_is_subnormal)
+ {
+ shift += 1 - new_exponent_bits;
+ new_exponent_bits = 0;
+ }
+
+ const uint64_t shift_out = shift == 64 ? new_significand : new_significand & ((1ll << shift) - 1);
+ new_significand = shift == 64 ? 0 : new_significand >> shift;
+
+ // Reinsert the most-significant-one if this is a subnormal in the
+ // output type.
+ new_significand |= (other_is_subnormal ? 1ll : 0) << (64 - shift);
+
+ // Apply rounding based on the bits shifted out of the significand
+ const uint64_t shift_half = 1ll << (shift - 1);
+ if (shift_out > shift_half || (shift_out == shift_half && (new_significand & 1)))
+ {
+ new_significand += 1;
+
+ // Handle the case that the significand overflowed due to
+ // rounding
+ constexpr uint64_t max_significand = (1ll << other_float_t::n_significand_bits) - 1;
+ if (new_significand > max_significand)
+ {
+ new_significand = 0;
+ new_exponent_bits++;
+ }
+ }
+
+ // Saturate to infinity if the exponent is larger than can be
+ // represented in the output type. This can only occur if the size
+ // of the exponent of the new type is not greater than the exponent
+ // of the old type.
+ if constexpr (other_n_exp_bits <= n_exp_bits)
+ {
+ constexpr int64_t inf_exp_bits = (1ll << other_n_exp_bits) - 1;
+ if (new_exponent_bits >= inf_exp_bits)
+ {
+ new_exponent_bits = inf_exp_bits;
+ new_significand =
+ other_has_infinity ? 0 : (1ul << other_float_t::n_significand_bits) - (other_has_nan ? 2 : 1);
+ }
+ }
+ }
+
+ return other_float_t::from_bits(sign_bit, new_exponent_bits, new_significand);
+ }
+
+ /// \brief Convert from a 32-bit floating point value
+ BITCAST_CONSTEXPR
+ float_t(const float& f)
+ {
+ // If this format exactly represents the binary32 format then get
+ // the bits from the provided float; otherwise get a binary32
+ // representation and then convert to this format.
+ if constexpr (represents_binary32())
+ m_data = float_support::get_bits(f);
+ else
+ m_data = static_cast<float_t<storage_t, n_exp_bits, has_nan, with_denorm, with_infinity>>(
+ static_cast<float_t<int32_t, 8, true, true, true>>(f))
+ .m_data;
+ }
+
+ /// \brief Cast to a 32-bit floating point value
+ BITCAST_CONSTEXPR operator float() const
+ {
+ // If this format exactly represents the binary32 format then return
+ // a float; otherwise get a binary32 representation and then return
+ // a float.
+ if constexpr (represents_binary32())
+ return float_support::from_bits(m_data);
+ else
+ return static_cast<float>(this->operator float_t<int32_t, 8, true, true, true>());
+ }
+
+ /// \brief Return whether this type represents the IEEE754 binary32
+ /// format
+ constexpr static inline bool represents_binary32()
+ {
+ return std::is_same_v<storage_t, int32_t> && n_exp_bits == 8 && has_nan && with_denorm && with_infinity;
+ }
+
+ constexpr auto operator-() const
+ {
+ return from_bits(m_data ^ (1ll << (sizeof(storage_t) * 8 - 1)));
+ }
+
+ constexpr bool is_subnormal() const
+ {
+ return exponent_bits() == 0 && significand() != 0;
+ }
+
+ constexpr bool is_zero() const
+ {
+ return exponent_bits() == 0 && significand() == 0;
+ }
+
+ constexpr bool is_nan() const
+ {
+ return has_nan && (exponent_bits() == (1ul << n_exponent_bits) - 1) &&
+ ((with_infinity && significand()) ||
+ (!with_infinity && significand() == (1ul << n_significand_bits) - 1));
+ }
+
+ constexpr bool is_infinity() const
+ {
+ return with_infinity && ((exponent_bits() == (1ul << n_exponent_bits) - 1) && !significand());
+ }
+
+ constexpr inline const storage_t& bits() const
+ {
+ return m_data;
+ }
+
+ /// \brief Get the exponent
+ constexpr inline int64_t exponent() const
+ {
+ return std::max<int64_t>(exponent_bits(), 1ul) - exponent_bias;
+ }
+
+ /// \brief Get the bits from the exponent
+ constexpr inline uint64_t exponent_bits() const
+ {
+ constexpr uint64_t mask = (1ul << n_exp_bits) - 1;
+ return (m_data >> n_significand_bits) & mask;
+ }
+
+ constexpr inline uint64_t significand() const
+ {
+ return m_data & ((1ul << n_significand_bits) - 1);
+ }
+
+ constexpr inline bool operator==(const float_t& other) const
+ {
+ return !is_nan() && !other.is_nan() && ((is_zero() && other.is_zero()) || bits() == other.bits());
+ }
+
+ constexpr inline float_t& operator+=(const float_t& rhs)
+ {
+ this->m_data = static_cast<float_t>(static_cast<float>(*this) + static_cast<float>(rhs)).bits();
+ return *this;
+ }
+};
+
+// This should probably be exported so we can use it elsewhere
+#undef BITCAST_CONSTEXPR
+
+namespace float_support
+{
+
+// Pre-C++23 these can't be computed as constexpr, so have to hardcode them
+
+template <int>
+struct digits10; // floor(log10(2) * (digits - 1)
+template <int>
+struct max_digits10; // ceil(log10(2) * digits + 1)
+template <int>
+struct min_exponent10; // floor(log10(2) * min_exponent)
+template <int>
+struct max_exponent10; // floor(log10(2) * max_exponent)
+
+template <>
+struct digits10<8>
+{
+ constexpr static inline int value = 2;
+};
+
+template <>
+struct max_digits10<8>
+{
+ constexpr static inline int value = 4;
+};
+
+template <>
+struct digits10<10>
+{
+ constexpr static inline int value = 2;
+};
+
+template <>
+struct max_digits10<10>
+{
+ constexpr static inline int value = 5;
+};
+
+template <>
+struct digits10<24>
+{
+ constexpr static inline int value = 6;
+};
+
+template <>
+struct max_digits10<24>
+{
+ constexpr static inline int value = 9;
+};
+
+template <>
+struct min_exponent10<-13>
+{
+ constexpr static inline int value = -3;
+};
+
+template <>
+struct max_exponent10<16>
+{
+ constexpr static inline int value = 4;
+};
+
+template <>
+struct min_exponent10<-125>
+{
+ constexpr static inline int value = -37;
+};
+
+template <>
+struct max_exponent10<128>
+{
+ constexpr static inline int value = 38;
+};
+
+template <int d>
+inline constexpr int digits10_v = digits10<d>::value;
+template <int d>
+inline constexpr int max_digits10_v = max_digits10<d>::value;
+
+template <int e>
+inline constexpr int min_exponent10_v = min_exponent10<e>::value;
+
+template <int e>
+inline constexpr int max_exponent10_v = max_exponent10<e>::value;
+
+} // namespace float_support
+
+} // namespace tosa
+
+namespace std
+{
+
+template <typename storage_t, size_t n_exp_bits, bool has_nan, bool has_denorm, bool has_inf>
+struct is_floating_point<tosa::float_t<storage_t, n_exp_bits, has_nan, has_denorm, has_inf>>
+ : std::integral_constant<bool, true>
+{};
+
+template <typename storage_t, size_t n_exp_bits, bool has_nan, bool with_denorm, bool with_inf>
+class numeric_limits<tosa::float_t<storage_t, n_exp_bits, has_nan, with_denorm, with_inf>>
+{
+ using this_float_t = tosa::float_t<storage_t, n_exp_bits, has_nan, with_denorm, with_inf>;
+
+public:
+ static constexpr bool is_specialized = true;
+
+ static constexpr auto min() noexcept
+ {
+ return this_float_t::from_bits(false, 1, 0);
+ }
+
+ static constexpr auto max() noexcept
+ {
+ return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 2,
+ (1 << this_float_t::n_significand_bits) - 1);
+ }
+
+ static constexpr auto lowest() noexcept
+ {
+ return -max();
+ }
+
+ static constexpr int digits = this_float_t::n_significand_bits + 1;
+ static constexpr int digits10 = tosa::float_support::digits10_v<digits>;
+ static constexpr int max_digits10 = tosa::float_support::max_digits10_v<digits>;
+
+ static constexpr bool is_signed = true;
+ static constexpr bool is_integer = false;
+ static constexpr bool is_exact = false;
+ static constexpr int radix = 2;
+
+ static constexpr auto epsilon() noexcept
+ {
+ return this_float_t::from_bits(false, this_float_t::exponent_bias - this_float_t::n_significand_bits, 0);
+ }
+
+ static constexpr auto round_error() noexcept
+ {
+ return this_float_t::from_bits(0, this_float_t::exponent_bias - 1, 0);
+ }
+
+ static constexpr int min_exponent = (1 - this_float_t::exponent_bias) + 1;
+ static constexpr int min_exponent10 = tosa::float_support::min_exponent10_v<min_exponent>;
+ static constexpr int max_exponent = this_float_t::exponent_bias + 1;
+ static constexpr int max_exponent10 = tosa::float_support::max_exponent10_v<max_exponent>;
+
+ static constexpr bool has_infinity = with_inf;
+ static constexpr bool has_quiet_NaN = has_nan;
+ static constexpr bool has_signaling_NaN = true;
+ static constexpr float_denorm_style has_denorm = with_denorm ? denorm_present : denorm_absent;
+ static constexpr bool has_denorm_loss = false;
+
+ static constexpr auto infinity() noexcept
+ {
+ if constexpr (with_inf)
+ {
+ return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 1, 0);
+ }
+ else
+ {
+ return this_float_t::from_bits(false, 0, 0);
+ }
+ }
+
+ static constexpr auto quiet_NaN() noexcept
+ {
+ return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 1,
+ 1 << (this_float_t::n_significand_bits - 1) | 1);
+ }
+
+ static constexpr auto signaling_NaN() noexcept
+ {
+ return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 1, 1);
+ }
+
+ static constexpr auto denorm_min() noexcept
+ {
+ return this_float_t::from_bits(false, 0, 1);
+ }
+
+ static constexpr bool is_iec559 = false;
+ static constexpr bool is_bounded = false;
+ static constexpr bool is_modulo = false;
+
+ static constexpr bool traps = false;
+ static constexpr bool tinyness_before = false;
+ static constexpr float_round_style round_style = round_to_nearest;
+};
+
+} // namespace std
+
+#endif // TOSA_FLOAT_UTILS_H_
diff --git a/include/tosa_generated.h b/include/tosa_generated.h
index 20f6993..0798256 100644
--- a/include/tosa_generated.h
+++ b/include/tosa_generated.h
@@ -1008,15 +1008,20 @@ inline ::flatbuffers::Offset<TransposeConvAttribute> CreateTransposeConvAttribut
struct PadAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
typedef PadAttributeBuilder Builder;
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
- VT_PAD_CONST = 4
+ VT_PAD_CONST = 4,
+ VT_TYPE = 6
};
const ::flatbuffers::Vector<uint8_t> *pad_const() const {
return GetPointer<const ::flatbuffers::Vector<uint8_t> *>(VT_PAD_CONST);
}
+ tosa::DType type() const {
+ return static_cast<tosa::DType>(GetField<uint32_t>(VT_TYPE, 0));
+ }
bool Verify(::flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_PAD_CONST) &&
verifier.VerifyVector(pad_const()) &&
+ VerifyField<uint32_t>(verifier, VT_TYPE, 4) &&
verifier.EndTable();
}
};
@@ -1028,6 +1033,9 @@ struct PadAttributeBuilder {
void add_pad_const(::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> pad_const) {
fbb_.AddOffset(PadAttribute::VT_PAD_CONST, pad_const);
}
+ void add_type(tosa::DType type) {
+ fbb_.AddElement<uint32_t>(PadAttribute::VT_TYPE, static_cast<uint32_t>(type), 0);
+ }
explicit PadAttributeBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -1041,20 +1049,24 @@ struct PadAttributeBuilder {
inline ::flatbuffers::Offset<PadAttribute> CreatePadAttribute(
::flatbuffers::FlatBufferBuilder &_fbb,
- ::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> pad_const = 0) {
+ ::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> pad_const = 0,
+ tosa::DType type = tosa::DType_UNKNOWN) {
PadAttributeBuilder builder_(_fbb);
+ builder_.add_type(type);
builder_.add_pad_const(pad_const);
return builder_.Finish();
}
inline ::flatbuffers::Offset<PadAttribute> CreatePadAttributeDirect(
::flatbuffers::FlatBufferBuilder &_fbb,
- const std::vector<uint8_t> *pad_const = nullptr) {
+ const std::vector<uint8_t> *pad_const = nullptr,
+ tosa::DType type = tosa::DType_UNKNOWN) {
if (pad_const) { _fbb.ForceVectorAlignment(pad_const->size(), sizeof(uint8_t), 8); }
auto pad_const__ = pad_const ? _fbb.CreateVector<uint8_t>(*pad_const) : 0;
return tosa::CreatePadAttribute(
_fbb,
- pad_const__);
+ pad_const__,
+ type);
}
struct AxisAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
@@ -1193,7 +1205,8 @@ struct ClampAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
typedef ClampAttributeBuilder Builder;
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
VT_MIN_VAL = 4,
- VT_MAX_VAL = 6
+ VT_MAX_VAL = 6,
+ VT_TYPE = 8
};
const ::flatbuffers::Vector<uint8_t> *min_val() const {
return GetPointer<const ::flatbuffers::Vector<uint8_t> *>(VT_MIN_VAL);
@@ -1201,12 +1214,16 @@ struct ClampAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
const ::flatbuffers::Vector<uint8_t> *max_val() const {
return GetPointer<const ::flatbuffers::Vector<uint8_t> *>(VT_MAX_VAL);
}
+ tosa::DType type() const {
+ return static_cast<tosa::DType>(GetField<uint32_t>(VT_TYPE, 0));
+ }
bool Verify(::flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_MIN_VAL) &&
verifier.VerifyVector(min_val()) &&
VerifyOffset(verifier, VT_MAX_VAL) &&
verifier.VerifyVector(max_val()) &&
+ VerifyField<uint32_t>(verifier, VT_TYPE, 4) &&
verifier.EndTable();
}
};
@@ -1221,6 +1238,9 @@ struct ClampAttributeBuilder {
void add_max_val(::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> max_val) {
fbb_.AddOffset(ClampAttribute::VT_MAX_VAL, max_val);
}
+ void add_type(tosa::DType type) {
+ fbb_.AddElement<uint32_t>(ClampAttribute::VT_TYPE, static_cast<uint32_t>(type), 0);
+ }
explicit ClampAttributeBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -1235,8 +1255,10 @@ struct ClampAttributeBuilder {
inline ::flatbuffers::Offset<ClampAttribute> CreateClampAttribute(
::flatbuffers::FlatBufferBuilder &_fbb,
::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> min_val = 0,
- ::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> max_val = 0) {
+ ::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> max_val = 0,
+ tosa::DType type = tosa::DType_UNKNOWN) {
ClampAttributeBuilder builder_(_fbb);
+ builder_.add_type(type);
builder_.add_max_val(max_val);
builder_.add_min_val(min_val);
return builder_.Finish();
@@ -1245,7 +1267,8 @@ inline ::flatbuffers::Offset<ClampAttribute> CreateClampAttribute(
inline ::flatbuffers::Offset<ClampAttribute> CreateClampAttributeDirect(
::flatbuffers::FlatBufferBuilder &_fbb,
const std::vector<uint8_t> *min_val = nullptr,
- const std::vector<uint8_t> *max_val = nullptr) {
+ const std::vector<uint8_t> *max_val = nullptr,
+ tosa::DType type = tosa::DType_UNKNOWN) {
if (min_val) { _fbb.ForceVectorAlignment(min_val->size(), sizeof(uint8_t), 8); }
auto min_val__ = min_val ? _fbb.CreateVector<uint8_t>(*min_val) : 0;
if (max_val) { _fbb.ForceVectorAlignment(max_val->size(), sizeof(uint8_t), 8); }
@@ -1253,7 +1276,8 @@ inline ::flatbuffers::Offset<ClampAttribute> CreateClampAttributeDirect(
return tosa::CreateClampAttribute(
_fbb,
min_val__,
- max_val__);
+ max_val__,
+ type);
}
struct RescaleAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h
index 5c53f57..f5f9e58 100644
--- a/include/tosa_serialization_handler.h
+++ b/include/tosa_serialization_handler.h
@@ -18,6 +18,7 @@
#include "attribute.h"
#include "flatbuffers/idl.h"
#include "flatbuffers/util.h"
+#include "float_utils.h"
#include "numpy_utils.h"
#include "tosa_generated.h"
#include <cstdint>
@@ -411,6 +412,9 @@ public:
tosa_err_t LoadFileSchema(const char* schema_filename);
// data format conversion. little-endian.
+ static tosa_err_t ConvertBF16toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
+ static tosa_err_t ConvertFP8E4M3toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
+ static tosa_err_t ConvertFP8E5M2toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
static tosa_err_t ConvertF16toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
static tosa_err_t ConvertF32toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
static tosa_err_t ConvertI64toU8(const std::vector<int64_t>& in, std::vector<uint8_t>& out);
@@ -421,6 +425,9 @@ public:
static tosa_err_t ConvertI4toU8(const std::vector<int8_t>& in, std::vector<uint8_t>& out);
static tosa_err_t ConvertBooltoU8(const std::vector<bool>& in, std::vector<uint8_t>& out);
+ static tosa_err_t ConvertU8toBF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out);
+ static tosa_err_t ConvertU8toFP8E4M3(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out);
+ static tosa_err_t ConvertU8toFP8E5M2(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out);
static tosa_err_t
ConvertU8toF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<half_float::half>& out);
static tosa_err_t ConvertU8toF32(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out);
diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py
index e6ab3d0..298907e 100644
--- a/python/serializer/tosa_serializer.py
+++ b/python/serializer/tosa_serializer.py
@@ -17,6 +17,7 @@ import serializer.tosa_serializer as ts
import json
import flatbuffers
import numpy as np
+import struct
from enum import IntEnum, unique
from tosa import (
TosaGraph,
@@ -204,7 +205,7 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.bools.append((a.AddLocalBound, local_bound))
self.ints.append((a.AddAccType, acc_type))
- def PadAttribute(self, serializer_builder, pad_const_val_as_bytes):
+ def PadAttribute(self, serializer_builder, pad_const_val_as_bytes, dtype):
from tosa import PadAttribute as a, Attribute
self.utype = Attribute.Attribute().PadAttribute
@@ -216,6 +217,7 @@ class TosaSerializerAttribute(TosaSerializerUnion):
)
self.floats.append((a.AddPadConst, serialized_pad_const_val))
+ self.ints.append((a.AddType, dtype))
def AxisAttribute(self, axis):
from tosa import AxisAttribute as a, Attribute
@@ -236,7 +238,9 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.int16vecs.append((a.AddBorder, border))
self.ints.append((a.AddMode, mode))
- def ClampAttribute(self, serializer_builder, min_val_as_bytes, max_val_as_bytes):
+ def ClampAttribute(
+ self, serializer_builder, min_val_as_bytes, max_val_as_bytes, dtype
+ ):
from tosa import ClampAttribute as a, Attribute
self.utype = Attribute.Attribute().ClampAttribute
@@ -252,6 +256,7 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.floats.append((a.AddMinVal, serialized_min_val))
self.floats.append((a.AddMaxVal, serialized_max_val))
+ self.ints.append((a.AddType, dtype))
def RescaleAttribute(
self,
@@ -439,87 +444,7 @@ class TosaSerializerTensor:
fb_name = builder.CreateString(self.name)
fb_shapes = TosaSerializer.serializeInt32Vec(builder, self.shape)
if self.data:
- u8_data = list()
- # little endianess
- if self.dtype == DType.BOOL:
- for val in self.data:
- val_u8 = np.uint8(val)
- u8_data.append(val_u8)
- elif self.dtype == DType.INT4:
- in_size = len(self.data)
- out_size = (in_size + 1) // 2
- for i in range(out_size):
- val_0 = self.data[2 * i]
- if (2 * i + 1) < in_size:
- val_1 = self.data[2 * i + 1]
- else:
- val_1 = 0
- val_i8 = (val_0 & 0xF) | ((val_1 & 0xF) << 4)
- val_u8 = np.uint8(val_i8)
- u8_data.append(val_u8)
- elif self.dtype == DType.INT8:
- for val in self.data:
- val_u8 = np.array(val).astype(dtype=np.uint8)
- u8_data.append(val_u8)
- elif self.dtype == DType.INT16:
- for val in self.data:
- val_u16 = np.array(val).astype(dtype=np.uint16)
- b0 = val_u16 & ByteMask
- b1 = (val_u16 >> np.uint16(8)) & ByteMask
- u8_data.extend([b0, b1])
- elif self.dtype == DType.INT32:
- for val in self.data:
- val_u32 = np.array(val).astype(dtype=np.uint32)
- b0 = val_u32 & ByteMask
- b1 = (val_u32 >> np.uint32(8)) & ByteMask
- b2 = (val_u32 >> np.uint32(16)) & ByteMask
- b3 = (val_u32 >> np.uint32(24)) & ByteMask
- u8_data.extend([b0, b1, b2, b3])
- elif self.dtype == DType.INT48:
- for val in self.data:
- val_u64 = np.uint64(val)
- b0 = val_u64 & ByteMask
- b1 = (val_u64 >> np.uint64(8)) & ByteMask
- b2 = (val_u64 >> np.uint64(16)) & ByteMask
- b3 = (val_u64 >> np.uint64(24)) & ByteMask
- b4 = (val_u64 >> np.uint64(32)) & ByteMask
- b5 = (val_u64 >> np.uint64(40)) & ByteMask
- u8_data.extend([b0, b1, b2, b3, b4, b5])
- elif self.dtype == DType.SHAPE:
- for val in self.data:
- val_u64 = np.uint64(val)
- b0 = val_u64 & ByteMask
- b1 = (val_u64 >> np.uint64(8)) & ByteMask
- b2 = (val_u64 >> np.uint64(16)) & ByteMask
- b3 = (val_u64 >> np.uint64(24)) & ByteMask
- b4 = (val_u64 >> np.uint64(32)) & ByteMask
- b5 = (val_u64 >> np.uint64(40)) & ByteMask
- b6 = (val_u64 >> np.uint64(48)) & ByteMask
- b7 = (val_u64 >> np.uint64(56)) & ByteMask
- u8_data.extend([b0, b1, b2, b3, b4, b5, b6, b7])
- elif self.dtype == DType.FP16:
- np_arr = np.array(self.data, dtype=np.float16)
- u8_data.extend(np_arr.view(np.uint8))
- elif (
- self.dtype == DType.FP32
- or self.dtype == DType.BF16
- or self.dtype == DType.FP8E4M3
- or self.dtype == DType.FP8E5M2
- ):
- # for val in self.data:
- # b = struct.pack("!f", val)
- # u8_data.extend([b[3], b[2], b[1], b[0]])
- np_arr = np.array(self.data, dtype=np.float32)
- u8_data.extend(np_arr.view(np.uint8))
- elif self.dtype == TosaDType.DType:
- # Serialize DType enum data as uint8 bytes
- for val in self.data:
- np_arr = np.array(self.data, dtype=np.uint32)
- u8_data.extend(np_arr.view(np.uint8))
- else:
- raise Exception(
- "unsupported data type {}".format(DTypeNames[self.dtype])
- )
+ u8_data = TosaSerializer.convertDataToUint8Vec(self.dtype, self.data)
fb_data = TosaSerializer.serializeUint8Vec(builder, u8_data)
TosaTensor.Start(builder)
@@ -958,3 +883,105 @@ class TosaSerializer:
return val
else:
return [val]
+
+ @staticmethod
+ def convertDataToUint8Vec(dtype, data):
+ u8_data = list()
+ # little endianess
+ if dtype == DType.BOOL:
+ for val in data:
+ val_u8 = np.uint8(val)
+ u8_data.append(val_u8)
+ elif dtype == DType.INT4:
+ in_size = len(data)
+ out_size = (in_size + 1) // 2
+ for i in range(out_size):
+ val_0 = data[2 * i]
+ if (2 * i + 1) < in_size:
+ val_1 = data[2 * i + 1]
+ else:
+ val_1 = 0
+ val_i8 = (val_0 & 0xF) | ((val_1 & 0xF) << 4)
+ val_u8 = np.uint8(val_i8)
+ u8_data.append(val_u8)
+ elif dtype == DType.INT8:
+ for val in data:
+ val_u8 = np.array(val).astype(dtype=np.uint8)
+ u8_data.append(val_u8)
+ elif dtype == DType.INT16:
+ for val in data:
+ val_u16 = np.array(val).astype(dtype=np.uint16)
+ b0 = val_u16 & ByteMask
+ b1 = (val_u16 >> np.uint16(8)) & ByteMask
+ u8_data.extend([b0, b1])
+ elif dtype == DType.INT32:
+ for val in data:
+ val_u32 = np.array(val).astype(dtype=np.uint32)
+ b0 = val_u32 & ByteMask
+ b1 = (val_u32 >> np.uint32(8)) & ByteMask
+ b2 = (val_u32 >> np.uint32(16)) & ByteMask
+ b3 = (val_u32 >> np.uint32(24)) & ByteMask
+ u8_data.extend([b0, b1, b2, b3])
+ elif dtype == DType.INT48:
+ for val in data:
+ val_u64 = np.uint64(val)
+ b0 = val_u64 & ByteMask
+ b1 = (val_u64 >> np.uint64(8)) & ByteMask
+ b2 = (val_u64 >> np.uint64(16)) & ByteMask
+ b3 = (val_u64 >> np.uint64(24)) & ByteMask
+ b4 = (val_u64 >> np.uint64(32)) & ByteMask
+ b5 = (val_u64 >> np.uint64(40)) & ByteMask
+ u8_data.extend([b0, b1, b2, b3, b4, b5])
+ elif dtype == DType.SHAPE:
+ for val in data:
+ val_u64 = np.uint64(val)
+ b0 = val_u64 & ByteMask
+ b1 = (val_u64 >> np.uint64(8)) & ByteMask
+ b2 = (val_u64 >> np.uint64(16)) & ByteMask
+ b3 = (val_u64 >> np.uint64(24)) & ByteMask
+ b4 = (val_u64 >> np.uint64(32)) & ByteMask
+ b5 = (val_u64 >> np.uint64(40)) & ByteMask
+ b6 = (val_u64 >> np.uint64(48)) & ByteMask
+ b7 = (val_u64 >> np.uint64(56)) & ByteMask
+ u8_data.extend([b0, b1, b2, b3, b4, b5, b6, b7])
+ elif dtype == DType.FP16:
+ np_arr = np.array(data, dtype=np.float16)
+ u8_data.extend(np_arr.view(np.uint8))
+ elif dtype == DType.FP32:
+ # for val in data:
+ # b = struct.pack("!f", val)
+ # u8_data.extend([b[3], b[2], b[1], b[0]])
+ np_arr = np.array(data, dtype=np.float32)
+ u8_data.extend(np_arr.view(np.uint8))
+ elif dtype == DType.BF16:
+ for val in data:
+ # convert val to little endian byte arrays b
+ b = struct.pack("<f", val)
+ # val => [ b[3], b[2], b[1], b[0] ]
+ # keep only most significant 2 bytes for bf16
+ # in little endian ordering
+ u8_data.extend([b[2], b[3]])
+ elif dtype == DType.FP8E4M3:
+ for val in data:
+ # convert val to fp8_bits then to single byte
+ f32_as_int = struct.unpack(">L", struct.pack(">f", val))[0]
+ f32_bits = f"{f32_as_int:032b}"
+ fp8_bits = f32_bits[0] + f32_bits[1:5] + f32_bits[9:12]
+ fp8_bytes = int(fp8_bits, 2).to_bytes(1, byteorder="little")
+ u8_data.extend(fp8_bytes)
+ elif dtype == DType.FP8E5M2:
+ for val in data:
+ # convert val to fp8_bits then to single byte
+ f32_as_int = struct.unpack(">L", struct.pack(">f", val))[0]
+ f32_bits = f"{f32_as_int:032b}"
+ fp8_bits = f32_bits[0] + f32_bits[1:6] + f32_bits[9:11]
+ fp8_bytes = int(fp8_bits, 2).to_bytes(1, byteorder="little")
+ u8_data.extend(fp8_bytes)
+ elif dtype == TosaDType.DType:
+ # Serialize DType enum data as uint8 bytes
+ for val in data:
+ np_arr = np.array(data, dtype=np.uint32)
+ u8_data.extend(np_arr.view(np.uint8))
+ else:
+ raise Exception("unsupported data type {}".format(DTypeNames[dtype]))
+ return u8_data
diff --git a/python/tosa/ClampAttribute.py b/python/tosa/ClampAttribute.py
index 6a41498..1189acb 100644
--- a/python/tosa/ClampAttribute.py
+++ b/python/tosa/ClampAttribute.py
@@ -82,8 +82,15 @@ class ClampAttribute(object):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
return o == 0
+ # ClampAttribute
+ def Type(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
+ if o != 0:
+ return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos)
+ return 0
+
def ClampAttributeStart(builder):
- builder.StartObject(2)
+ builder.StartObject(3)
def Start(builder):
ClampAttributeStart(builder)
@@ -112,6 +119,12 @@ def ClampAttributeStartMaxValVector(builder, numElems):
def StartMaxValVector(builder, numElems: int) -> int:
return ClampAttributeStartMaxValVector(builder, numElems)
+def ClampAttributeAddType(builder, type):
+ builder.PrependUint32Slot(2, type, 0)
+
+def AddType(builder, type):
+ ClampAttributeAddType(builder, type)
+
def ClampAttributeEnd(builder):
return builder.EndObject()
diff --git a/python/tosa/PadAttribute.py b/python/tosa/PadAttribute.py
index 301bf17..c4084dc 100644
--- a/python/tosa/PadAttribute.py
+++ b/python/tosa/PadAttribute.py
@@ -55,8 +55,15 @@ class PadAttribute(object):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
return o == 0
+ # PadAttribute
+ def Type(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+ if o != 0:
+ return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos)
+ return 0
+
def PadAttributeStart(builder):
- builder.StartObject(1)
+ builder.StartObject(2)
def Start(builder):
PadAttributeStart(builder)
@@ -73,6 +80,12 @@ def PadAttributeStartPadConstVector(builder, numElems):
def StartPadConstVector(builder, numElems: int) -> int:
return PadAttributeStartPadConstVector(builder, numElems)
+def PadAttributeAddType(builder, type):
+ builder.PrependUint32Slot(1, type, 0)
+
+def AddType(builder, type):
+ PadAttributeAddType(builder, type)
+
def PadAttributeEnd(builder):
return builder.EndObject()
diff --git a/schema/tosa.fbs b/schema/tosa.fbs
index 79b83b1..7b5948b 100644
--- a/schema/tosa.fbs
+++ b/schema/tosa.fbs
@@ -185,6 +185,7 @@ table TransposeConvAttribute {
table PadAttribute {
pad_const: [ubyte] (force_align: 8);
+ type: DType;
}
table AxisAttribute {
@@ -201,6 +202,7 @@ table ResizeAttribute {
table ClampAttribute {
min_val: [ubyte] (force_align: 8);
max_val: [ubyte] (force_align: 8);
+ type: DType;
}
table RescaleAttribute {
diff --git a/src/tosa_serialization_handler.cpp b/src/tosa_serialization_handler.cpp
index 749a3c8..85625cd 100644
--- a/src/tosa_serialization_handler.cpp
+++ b/src/tosa_serialization_handler.cpp
@@ -19,6 +19,9 @@
#include <iostream>
using namespace tosa;
+using fp8e4m3 = tosa::float_t<int8_t, 4, true, true, false>;
+using fp8e5m2 = tosa::float_t<int8_t, 5, true, true, true>;
+
TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name,
const flatbuffers::Vector<int32_t>* shape,
DType dtype,
@@ -747,6 +750,51 @@ void TosaSerializationHandler::ForceAlignTensorData(std::vector<uint8_t>& buf)
}
}
+tosa_err_t TosaSerializationHandler::ConvertBF16toU8(const std::vector<float>& in, std::vector<uint8_t>& out)
+{
+ // Note: Converts fp32->bf16 by ignoring the least significant 16 bits
+ out.clear();
+ for (auto val : in)
+ {
+ uint32_t* val_u32 = reinterpret_cast<uint32_t*>(&val);
+ uint8_t f32_byte2 = (*val_u32 >> 16) & 0xFF;
+ uint8_t f32_byte3 = (*val_u32 >> 24) & 0xFF;
+ // little endian: byte2 followed by byte3
+ out.push_back(f32_byte2);
+ out.push_back(f32_byte3);
+ }
+ ForceAlignTensorData(out);
+ return TOSA_OK;
+}
+
+tosa_err_t TosaSerializationHandler::ConvertFP8E4M3toU8(const std::vector<float>& in, std::vector<uint8_t>& out)
+{
+ // Note: Converts fp32->FP8E4M3 before converting to unint8_t
+ out.clear();
+ for (auto val : in)
+ {
+ auto f8 = static_cast<fp8e4m3>(val);
+ uint8_t b8 = f8.bits();
+ out.push_back(b8);
+ }
+ ForceAlignTensorData(out);
+ return TOSA_OK;
+}
+
+tosa_err_t TosaSerializationHandler::ConvertFP8E5M2toU8(const std::vector<float>& in, std::vector<uint8_t>& out)
+{
+ // Note: Converts fp32->FP8E5M2 before converting to uint8_t
+ out.clear();
+ for (auto val : in)
+ {
+ auto f8 = static_cast<fp8e5m2>(val);
+ uint8_t b8 = f8.bits();
+ out.push_back(b8);
+ }
+ ForceAlignTensorData(out);
+ return TOSA_OK;
+}
+
tosa_err_t TosaSerializationHandler::ConvertF16toU8(const std::vector<float>& in, std::vector<uint8_t>& out)
{
// Note: Converts fp32->fp16 before converting to uint8_t
@@ -896,6 +944,78 @@ tosa_err_t TosaSerializationHandler::ConvertBooltoU8(const std::vector<bool>& in
return TOSA_OK;
}
+tosa_err_t TosaSerializationHandler::ConvertU8toBF16(const std::vector<uint8_t>& in,
+ uint32_t out_size,
+ std::vector<float>& out)
+{
+ // Note: bf16 values returned in fp32 type
+ out.clear();
+ if (in.size() < out_size * sizeof(int16_t))
+ {
+ printf("TosaSerializationHandler::ConvertU8toBF16(): uint8 buffer size %ld must >= target size %ld\n",
+ in.size(), out_size * sizeof(int16_t));
+ return TOSA_USER_ERROR;
+ }
+
+ for (uint32_t i = 0; i < out_size; i++)
+ {
+ uint32_t f32_byte2 = in[i * sizeof(int16_t)];
+ uint32_t f32_byte3 = in[i * sizeof(int16_t) + 1];
+ uint32_t val_u32 = (f32_byte2 << 16) + (f32_byte3 << 24);
+
+ // Reinterpret u32 bytes as fp32
+ float val_f32 = *(float*)&val_u32;
+ out.push_back(val_f32);
+ }
+ return TOSA_OK;
+}
+
+tosa_err_t TosaSerializationHandler::ConvertU8toFP8E4M3(const std::vector<uint8_t>& in,
+ uint32_t out_size,
+ std::vector<float>& out)
+{
+ // Note: FP8E4M3 values returned in fp32 type
+ out.clear();
+ if (in.size() < out_size * sizeof(int8_t))
+ {
+ printf("TosaSerializationHandler::ConvertU8toF16(): uint8 buffer size %ld must >= target size %ld\n", in.size(),
+ out_size * sizeof(int8_t));
+ return TOSA_USER_ERROR;
+ }
+
+ for (uint32_t i = 0; i < out_size; i++)
+ {
+ int8_t bits = static_cast<int8_t>(in[i * sizeof(int8_t)]);
+ auto f8 = fp8e4m3::from_bits(bits);
+ float val_f32 = static_cast<float>(f8);
+ out.push_back(val_f32);
+ }
+ return TOSA_OK;
+}
+
+tosa_err_t TosaSerializationHandler::ConvertU8toFP8E5M2(const std::vector<uint8_t>& in,
+ uint32_t out_size,
+ std::vector<float>& out)
+{
+ // Note: FP8E5M2 values returned in fp32 type
+ out.clear();
+ if (in.size() < out_size * sizeof(int8_t))
+ {
+ printf("TosaSerializationHandler::ConvertU8toF16(): uint8 buffer size %ld must >= target size %ld\n", in.size(),
+ out_size * sizeof(int8_t));
+ return TOSA_USER_ERROR;
+ }
+
+ for (uint32_t i = 0; i < out_size; i++)
+ {
+ int8_t bits = static_cast<int8_t>(in[i * sizeof(int8_t)]);
+ auto f8 = fp8e5m2::from_bits(bits);
+ float val_f32 = static_cast<float>(f8);
+ out.push_back(val_f32);
+ }
+ return TOSA_OK;
+}
+
tosa_err_t TosaSerializationHandler::ConvertU8toF16(const std::vector<uint8_t>& in,
uint32_t out_size,
std::vector<half_float::half>& out)