diff options
Diffstat (limited to 'include')
-rw-r--r-- | include/cfloat.h | 44 | ||||
-rw-r--r-- | include/numpy_utils.h | 17 | ||||
-rw-r--r-- | include/tosa_generated.h | 6 | ||||
-rw-r--r-- | include/tosa_serialization_handler.h | 12 |
4 files changed, 60 insertions, 19 deletions
diff --git a/include/cfloat.h b/include/cfloat.h index 0cf4896..cbbe09a 100644 --- a/include/cfloat.h +++ b/include/cfloat.h @@ -211,10 +211,33 @@ public: if (in.is_nan() || in.is_infinity()) { + // The mapping of infinity to the destination type depends upon + // the overflow mode and the features of the destination type. + // OVERFLOW mode is the "expected" behaviour, in which exception + // values (NaN and infinity) map to themselves in the + // destination type (assuming they exist). In SATURATION mode, + // infinity maps to the largest absolute value of the + // destination type _even if_ an infinity encoding is available. + // See the FP8 specification document. + // + // By default, exceptional values are encoded with an all-1 + // exponent field. new_exponent_bits = (UINT64_C(1) << out_exp_bits) - 1; if (in.is_nan()) { + // NaN always maps to NaN if it's available. + // + // NB: if the type has both NaN AND Infinity support, then + // the entirety of the significand can be used to encode + // different values of NaN (excepting significand = 0, + // which is reserved for infinity). This makes it possible + // to encode both quiet and signalling varieties. + // Generally, the LSB of the significand represents "not + // quiet". However, when there is only 1 NaN encoding + // (which is generally the case when infinity is not + // supported), then there cannot be separate quiet and + // signalling varieties of NaN. if constexpr (out_type::has_inf) { // Copy across the `not_quiet bit`; set the LSB. @@ -228,17 +251,18 @@ public: new_significand = (UINT64_C(1) << out_type::n_significand_bits) - 1; } } - else if constexpr (out_type::has_inf && overflow_mode == OverflowMode::Saturate) + else if constexpr (overflow_mode == OverflowMode::Saturate) { - new_exponent_bits -= 1; - new_significand = (UINT64_C(1) << out_type::n_significand_bits) - 1; - } - else if constexpr (!out_type::has_inf && overflow_mode == OverflowMode::Saturate) - { - new_significand = (UINT64_C(1) << out_type::n_significand_bits) - (out_type::has_nan ? 2 : 1); + // In SATURATE mode, infinity in the input maps to the + // largest absolute value in the output type; even if + // infinity is available. This is in compliance with Table 3 + // of the FP8 specification. + return out_type::max(sign_bit); } else if constexpr (!out_type::has_inf && overflow_mode == OverflowMode::Overflow) { + // In OVERFLOW mode, infinities in the input type map to NaN + // in the output type, if infinity is not available. new_significand = (UINT64_C(1) << out_type::n_significand_bits) - 1; } } @@ -492,20 +516,20 @@ public: { // Where we have NaN and Infinity, exponents all `1` corresponds // to some of these values. - return from_bits(false, (UINT64_C(1) << n_exponent_bits) - 2, (UINT64_C(1) << n_significand_bits) - 1); + return from_bits(sign, (UINT64_C(1) << n_exponent_bits) - 2, (UINT64_C(1) << n_significand_bits) - 1); } else if constexpr (has_nan || has_inf) { // Where we have either NaN or infinity (but not both), // exponents all `1` AND significand all `1` corresponds to the // special value. - return from_bits(false, (UINT64_C(1) << n_exponent_bits) - 1, (UINT64_C(1) << n_significand_bits) - 2); + return from_bits(sign, (UINT64_C(1) << n_exponent_bits) - 1, (UINT64_C(1) << n_significand_bits) - 2); } else { // With no special values to encode, the maximum value is // encoded as all `1`s. - return from_bits(false, (UINT64_C(1) << n_exponent_bits) - 1, (UINT64_C(1) << n_significand_bits) - 1); + return from_bits(sign, (UINT64_C(1) << n_exponent_bits) - 1, (UINT64_C(1) << n_significand_bits) - 1); } } diff --git a/include/numpy_utils.h b/include/numpy_utils.h index 60cf77e..ade2f2d 100644 --- a/include/numpy_utils.h +++ b/include/numpy_utils.h @@ -24,8 +24,13 @@ #include <cstring> #include <vector> +#include "cfloat.h" #include "half.hpp" +using bf16 = ct::cfloat<int16_t, 8, true, true, true>; +using fp8e4m3 = ct::cfloat<int8_t, 4, true, true, false>; +using fp8e5m2 = ct::cfloat<int8_t, 5, true, true, true>; + class NumpyUtilities { public: @@ -85,6 +90,18 @@ public: { return "'<f2'"; } + if (std::is_same<T, bf16>::value) + { + return "'<V2'"; + } + if (std::is_same<T, fp8e4m3>::value) + { + return "'<V1'"; + } + if (std::is_same<T, fp8e5m2>::value) + { + return "'<f1'"; + } assert(false && "unsupported Dtype"); }; diff --git a/include/tosa_generated.h b/include/tosa_generated.h index 1b5e164..c907c89 100644 --- a/include/tosa_generated.h +++ b/include/tosa_generated.h @@ -8,9 +8,9 @@ // Ensure the included flatbuffers.h is the same version as when this file was // generated, otherwise it may not be compatible. -static_assert(FLATBUFFERS_VERSION_MAJOR == 23 && - FLATBUFFERS_VERSION_MINOR == 5 && - FLATBUFFERS_VERSION_REVISION == 26, +static_assert(FLATBUFFERS_VERSION_MAJOR == 24 && + FLATBUFFERS_VERSION_MINOR == 3 && + FLATBUFFERS_VERSION_REVISION == 7, "Non-compatible flatbuffers version included"); namespace tosa { diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h index 139a476..c09a47d 100644 --- a/include/tosa_serialization_handler.h +++ b/include/tosa_serialization_handler.h @@ -412,9 +412,9 @@ public: tosa_err_t LoadFileSchema(const char* schema_filename); // data format conversion. little-endian. - static tosa_err_t ConvertBF16toU8(const std::vector<float>& in, std::vector<uint8_t>& out); - static tosa_err_t ConvertFP8E4M3toU8(const std::vector<float>& in, std::vector<uint8_t>& out); - static tosa_err_t ConvertFP8E5M2toU8(const std::vector<float>& in, std::vector<uint8_t>& out); + static tosa_err_t ConvertBF16toU8(const std::vector<bf16>& in, std::vector<uint8_t>& out); + static tosa_err_t ConvertFP8E4M3toU8(const std::vector<fp8e4m3>& in, std::vector<uint8_t>& out); + static tosa_err_t ConvertFP8E5M2toU8(const std::vector<fp8e5m2>& in, std::vector<uint8_t>& out); static tosa_err_t ConvertF16toU8(const std::vector<float>& in, std::vector<uint8_t>& out); static tosa_err_t ConvertF32toU8(const std::vector<float>& in, std::vector<uint8_t>& out); static tosa_err_t ConvertI64toU8(const std::vector<int64_t>& in, std::vector<uint8_t>& out); @@ -425,9 +425,9 @@ public: static tosa_err_t ConvertI4toU8(const std::vector<int8_t>& in, std::vector<uint8_t>& out); static tosa_err_t ConvertBooltoU8(const std::vector<bool>& in, std::vector<uint8_t>& out); - static tosa_err_t ConvertU8toBF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out); - static tosa_err_t ConvertU8toFP8E4M3(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out); - static tosa_err_t ConvertU8toFP8E5M2(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out); + static tosa_err_t ConvertU8toBF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<bf16>& out); + static tosa_err_t ConvertU8toFP8E4M3(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<fp8e4m3>& out); + static tosa_err_t ConvertU8toFP8E5M2(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<fp8e5m2>& out); static tosa_err_t ConvertU8toF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<half_float::half>& out); static tosa_err_t ConvertU8toF32(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out); |