aboutsummaryrefslogtreecommitdiff
path: root/include
diff options
context:
space:
mode:
Diffstat (limited to 'include')
-rw-r--r--include/cfloat.h44
-rw-r--r--include/numpy_utils.h17
-rw-r--r--include/tosa_generated.h6
-rw-r--r--include/tosa_serialization_handler.h12
4 files changed, 60 insertions, 19 deletions
diff --git a/include/cfloat.h b/include/cfloat.h
index 0cf4896..cbbe09a 100644
--- a/include/cfloat.h
+++ b/include/cfloat.h
@@ -211,10 +211,33 @@ public:
if (in.is_nan() || in.is_infinity())
{
+ // The mapping of infinity to the destination type depends upon
+ // the overflow mode and the features of the destination type.
+ // OVERFLOW mode is the "expected" behaviour, in which exception
+ // values (NaN and infinity) map to themselves in the
+ // destination type (assuming they exist). In SATURATION mode,
+ // infinity maps to the largest absolute value of the
+ // destination type _even if_ an infinity encoding is available.
+ // See the FP8 specification document.
+ //
+ // By default, exceptional values are encoded with an all-1
+ // exponent field.
new_exponent_bits = (UINT64_C(1) << out_exp_bits) - 1;
if (in.is_nan())
{
+ // NaN always maps to NaN if it's available.
+ //
+ // NB: if the type has both NaN AND Infinity support, then
+ // the entirety of the significand can be used to encode
+ // different values of NaN (excepting significand = 0,
+ // which is reserved for infinity). This makes it possible
+ // to encode both quiet and signalling varieties.
+ // Generally, the LSB of the significand represents "not
+ // quiet". However, when there is only 1 NaN encoding
+ // (which is generally the case when infinity is not
+ // supported), then there cannot be separate quiet and
+ // signalling varieties of NaN.
if constexpr (out_type::has_inf)
{
// Copy across the `not_quiet bit`; set the LSB.
@@ -228,17 +251,18 @@ public:
new_significand = (UINT64_C(1) << out_type::n_significand_bits) - 1;
}
}
- else if constexpr (out_type::has_inf && overflow_mode == OverflowMode::Saturate)
+ else if constexpr (overflow_mode == OverflowMode::Saturate)
{
- new_exponent_bits -= 1;
- new_significand = (UINT64_C(1) << out_type::n_significand_bits) - 1;
- }
- else if constexpr (!out_type::has_inf && overflow_mode == OverflowMode::Saturate)
- {
- new_significand = (UINT64_C(1) << out_type::n_significand_bits) - (out_type::has_nan ? 2 : 1);
+ // In SATURATE mode, infinity in the input maps to the
+ // largest absolute value in the output type; even if
+ // infinity is available. This is in compliance with Table 3
+ // of the FP8 specification.
+ return out_type::max(sign_bit);
}
else if constexpr (!out_type::has_inf && overflow_mode == OverflowMode::Overflow)
{
+ // In OVERFLOW mode, infinities in the input type map to NaN
+ // in the output type, if infinity is not available.
new_significand = (UINT64_C(1) << out_type::n_significand_bits) - 1;
}
}
@@ -492,20 +516,20 @@ public:
{
// Where we have NaN and Infinity, exponents all `1` corresponds
// to some of these values.
- return from_bits(false, (UINT64_C(1) << n_exponent_bits) - 2, (UINT64_C(1) << n_significand_bits) - 1);
+ return from_bits(sign, (UINT64_C(1) << n_exponent_bits) - 2, (UINT64_C(1) << n_significand_bits) - 1);
}
else if constexpr (has_nan || has_inf)
{
// Where we have either NaN or infinity (but not both),
// exponents all `1` AND significand all `1` corresponds to the
// special value.
- return from_bits(false, (UINT64_C(1) << n_exponent_bits) - 1, (UINT64_C(1) << n_significand_bits) - 2);
+ return from_bits(sign, (UINT64_C(1) << n_exponent_bits) - 1, (UINT64_C(1) << n_significand_bits) - 2);
}
else
{
// With no special values to encode, the maximum value is
// encoded as all `1`s.
- return from_bits(false, (UINT64_C(1) << n_exponent_bits) - 1, (UINT64_C(1) << n_significand_bits) - 1);
+ return from_bits(sign, (UINT64_C(1) << n_exponent_bits) - 1, (UINT64_C(1) << n_significand_bits) - 1);
}
}
diff --git a/include/numpy_utils.h b/include/numpy_utils.h
index 60cf77e..ade2f2d 100644
--- a/include/numpy_utils.h
+++ b/include/numpy_utils.h
@@ -24,8 +24,13 @@
#include <cstring>
#include <vector>
+#include "cfloat.h"
#include "half.hpp"
+using bf16 = ct::cfloat<int16_t, 8, true, true, true>;
+using fp8e4m3 = ct::cfloat<int8_t, 4, true, true, false>;
+using fp8e5m2 = ct::cfloat<int8_t, 5, true, true, true>;
+
class NumpyUtilities
{
public:
@@ -85,6 +90,18 @@ public:
{
return "'<f2'";
}
+ if (std::is_same<T, bf16>::value)
+ {
+ return "'<V2'";
+ }
+ if (std::is_same<T, fp8e4m3>::value)
+ {
+ return "'<V1'";
+ }
+ if (std::is_same<T, fp8e5m2>::value)
+ {
+ return "'<f1'";
+ }
assert(false && "unsupported Dtype");
};
diff --git a/include/tosa_generated.h b/include/tosa_generated.h
index 1b5e164..c907c89 100644
--- a/include/tosa_generated.h
+++ b/include/tosa_generated.h
@@ -8,9 +8,9 @@
// Ensure the included flatbuffers.h is the same version as when this file was
// generated, otherwise it may not be compatible.
-static_assert(FLATBUFFERS_VERSION_MAJOR == 23 &&
- FLATBUFFERS_VERSION_MINOR == 5 &&
- FLATBUFFERS_VERSION_REVISION == 26,
+static_assert(FLATBUFFERS_VERSION_MAJOR == 24 &&
+ FLATBUFFERS_VERSION_MINOR == 3 &&
+ FLATBUFFERS_VERSION_REVISION == 7,
"Non-compatible flatbuffers version included");
namespace tosa {
diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h
index 139a476..c09a47d 100644
--- a/include/tosa_serialization_handler.h
+++ b/include/tosa_serialization_handler.h
@@ -412,9 +412,9 @@ public:
tosa_err_t LoadFileSchema(const char* schema_filename);
// data format conversion. little-endian.
- static tosa_err_t ConvertBF16toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
- static tosa_err_t ConvertFP8E4M3toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
- static tosa_err_t ConvertFP8E5M2toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
+ static tosa_err_t ConvertBF16toU8(const std::vector<bf16>& in, std::vector<uint8_t>& out);
+ static tosa_err_t ConvertFP8E4M3toU8(const std::vector<fp8e4m3>& in, std::vector<uint8_t>& out);
+ static tosa_err_t ConvertFP8E5M2toU8(const std::vector<fp8e5m2>& in, std::vector<uint8_t>& out);
static tosa_err_t ConvertF16toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
static tosa_err_t ConvertF32toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
static tosa_err_t ConvertI64toU8(const std::vector<int64_t>& in, std::vector<uint8_t>& out);
@@ -425,9 +425,9 @@ public:
static tosa_err_t ConvertI4toU8(const std::vector<int8_t>& in, std::vector<uint8_t>& out);
static tosa_err_t ConvertBooltoU8(const std::vector<bool>& in, std::vector<uint8_t>& out);
- static tosa_err_t ConvertU8toBF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out);
- static tosa_err_t ConvertU8toFP8E4M3(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out);
- static tosa_err_t ConvertU8toFP8E5M2(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out);
+ static tosa_err_t ConvertU8toBF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<bf16>& out);
+ static tosa_err_t ConvertU8toFP8E4M3(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<fp8e4m3>& out);
+ static tosa_err_t ConvertU8toFP8E5M2(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<fp8e5m2>& out);
static tosa_err_t
ConvertU8toF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<half_float::half>& out);
static tosa_err_t ConvertU8toF32(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out);