about | summary | refs | log | tree | commit | diff
path: root/include
diff options
context:
space:
mode:
Diffstat (limited to 'include')
-rw-r--r-- include/attribute.def 5
-rw-r--r-- include/cfloat.h 44
-rw-r--r-- include/numpy_utils.h 17
-rw-r--r-- include/tosa_generated.h 69
-rw-r--r-- include/tosa_serialization_handler.h 16
5 files changed, 71 insertions, 80 deletions
diff --git a/include/attribute.def b/include/attribute.def
index 0e97629..52d5179 100644
--- a/include/attribute.def
+++ b/include/attribute.def
@@ -57,10 +57,7 @@ DEF_ATTRIBUTE(Pad, 1,
DEF_ATTRIBUTE(Axis, 1,
int32_t, S, axis)
-DEF_ATTRIBUTE(Resize, 4,
- int16_t, V, scale,
- int16_t, V, offset,
- int16_t, V, border,
+DEF_ATTRIBUTE(Resize, 1,
ResizeMode, S, mode)
DEF_ATTRIBUTE(Clamp, 2,
diff --git a/include/cfloat.h b/include/cfloat.h
index 0cf4896..cbbe09a 100644
--- a/include/cfloat.h
+++ b/include/cfloat.h
@@ -211,10 +211,33 @@ public:
if (in.is_nan() || in.is_infinity())
{
+ // The mapping of infinity to the destination type depends upon
+ // the overflow mode and the features of the destination type.
+ // OVERFLOW mode is the "expected" behaviour, in which exception
+ // values (NaN and infinity) map to themselves in the
+ // destination type (assuming they exist). In SATURATION mode,
+ // infinity maps to the largest absolute value of the
+ // destination type _even if_ an infinity encoding is available.
+ // See the FP8 specification document.
+ //
+ // By default, exceptional values are encoded with an all-1
+ // exponent field.
new_exponent_bits = (UINT64_C(1) << out_exp_bits) - 1;
if (in.is_nan())
{
+ // NaN always maps to NaN if it's available.
+ //
+ // NB: if the type has both NaN AND Infinity support, then
+ // the entirety of the significand can be used to encode
+ // different values of NaN (excepting significand = 0,
+ // which is reserved for infinity). This makes it possible
+ // to encode both quiet and signalling varieties.
+ // Generally, the LSB of the significand represents "not
+ // quiet". However, when there is only 1 NaN encoding
+ // (which is generally the case when infinity is not
+ // supported), then there cannot be separate quiet and
+ // signalling varieties of NaN.
if constexpr (out_type::has_inf)
{
// Copy across the `not_quiet bit`; set the LSB.
@@ -228,17 +251,18 @@ public:
new_significand = (UINT64_C(1) << out_type::n_significand_bits) - 1;
}
}
- else if constexpr (out_type::has_inf && overflow_mode == OverflowMode::Saturate)
+ else if constexpr (overflow_mode == OverflowMode::Saturate)
{
- new_exponent_bits -= 1;
- new_significand = (UINT64_C(1) << out_type::n_significand_bits) - 1;
- }
- else if constexpr (!out_type::has_inf && overflow_mode == OverflowMode::Saturate)
- {
- new_significand = (UINT64_C(1) << out_type::n_significand_bits) - (out_type::has_nan ? 2 : 1);
+ // In SATURATE mode, infinity in the input maps to the
+ // largest absolute value in the output type; even if
+ // infinity is available. This is in compliance with Table 3
+ // of the FP8 specification.
+ return out_type::max(sign_bit);
}
else if constexpr (!out_type::has_inf && overflow_mode == OverflowMode::Overflow)
{
+ // In OVERFLOW mode, infinities in the input type map to NaN
+ // in the output type, if infinity is not available.
new_significand = (UINT64_C(1) << out_type::n_significand_bits) - 1;
}
}
@@ -492,20 +516,20 @@ public:
{
// Where we have NaN and Infinity, exponents all `1` corresponds
// to some of these values.
- return from_bits(false, (UINT64_C(1) << n_exponent_bits) - 2, (UINT64_C(1) << n_significand_bits) - 1);
+ return from_bits(sign, (UINT64_C(1) << n_exponent_bits) - 2, (UINT64_C(1) << n_significand_bits) - 1);
}
else if constexpr (has_nan || has_inf)
{
// Where we have either NaN or infinity (but not both),
// exponents all `1` AND significand all `1` corresponds to the
// special value.
- return from_bits(false, (UINT64_C(1) << n_exponent_bits) - 1, (UINT64_C(1) << n_significand_bits) - 2);
+ return from_bits(sign, (UINT64_C(1) << n_exponent_bits) - 1, (UINT64_C(1) << n_significand_bits) - 2);
}
else
{
// With no special values to encode, the maximum value is
// encoded as all `1`s.
- return from_bits(false, (UINT64_C(1) << n_exponent_bits) - 1, (UINT64_C(1) << n_significand_bits) - 1);
+ return from_bits(sign, (UINT64_C(1) << n_exponent_bits) - 1, (UINT64_C(1) << n_significand_bits) - 1);
}
}
diff --git a/include/numpy_utils.h b/include/numpy_utils.h
index 60cf77e..ade2f2d 100644
--- a/include/numpy_utils.h
+++ b/include/numpy_utils.h
@@ -24,8 +24,13 @@
#include <cstring>
#include <vector>
+#include "cfloat.h"
#include "half.hpp"
+using bf16 = ct::cfloat<int16_t, 8, true, true, true>;
+using fp8e4m3 = ct::cfloat<int8_t, 4, true, true, false>;
+using fp8e5m2 = ct::cfloat<int8_t, 5, true, true, true>;
+
class NumpyUtilities
{
public:
@@ -85,6 +90,18 @@ public:
{
return "'<f2'";
}
+ if (std::is_same<T, bf16>::value)
+ {
+ return "'<V2'";
+ }
+ if (std::is_same<T, fp8e4m3>::value)
+ {
+ return "'<V1'";
+ }
+ if (std::is_same<T, fp8e5m2>::value)
+ {
+ return "'<f1'";
+ }
assert(false && "unsupported Dtype");
};
diff --git a/include/tosa_generated.h b/include/tosa_generated.h
index 1b5e164..994b72c 100644
--- a/include/tosa_generated.h
+++ b/include/tosa_generated.h
@@ -8,9 +8,9 @@
// Ensure the included flatbuffers.h is the same version as when this file was
// generated, otherwise it may not be compatible.
-static_assert(FLATBUFFERS_VERSION_MAJOR == 23 &&
- FLATBUFFERS_VERSION_MINOR == 5 &&
- FLATBUFFERS_VERSION_REVISION == 26,
+static_assert(FLATBUFFERS_VERSION_MAJOR == 24 &&
+ FLATBUFFERS_VERSION_MINOR == 3 &&
+ FLATBUFFERS_VERSION_REVISION == 25,
"Non-compatible flatbuffers version included");
namespace tosa {
@@ -277,11 +277,12 @@ enum Op : uint32_t {
Op_DIV_SHAPE = 78,
Op_COS = 79,
Op_SIN = 80,
+ Op_CAST_STOCHASTIC = 81,
Op_MIN = Op_UNKNOWN,
- Op_MAX = Op_SIN
+ Op_MAX = Op_CAST_STOCHASTIC
};
-inline const Op (&EnumValuesOp())[81] {
+inline const Op (&EnumValuesOp())[82] {
static const Op values[] = {
Op_UNKNOWN,
Op_ARGMAX,
@@ -363,13 +364,14 @@ inline const Op (&EnumValuesOp())[81] {
Op_MUL_SHAPE,
Op_DIV_SHAPE,
Op_COS,
- Op_SIN
+ Op_SIN,
+ Op_CAST_STOCHASTIC
};
return values;
}
inline const char * const *EnumNamesOp() {
- static const char * const names[82] = {
+ static const char * const names[83] = {
"UNKNOWN",
"ARGMAX",
"AVG_POOL2D",
@@ -451,13 +453,14 @@ inline const char * const *EnumNamesOp() {
"DIV_SHAPE",
"COS",
"SIN",
+ "CAST_STOCHASTIC",
nullptr
};
return names;
}
inline const char *EnumNameOp(Op e) {
- if (::flatbuffers::IsOutRange(e, Op_UNKNOWN, Op_SIN)) return "";
+ if (::flatbuffers::IsOutRange(e, Op_UNKNOWN, Op_CAST_STOCHASTIC)) return "";
const size_t index = static_cast<size_t>(e);
return EnumNamesOp()[index];
}
@@ -1087,31 +1090,13 @@ inline ::flatbuffers::Offset<AxisAttribute> CreateAxisAttribute(
struct ResizeAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
typedef ResizeAttributeBuilder Builder;
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
- VT_SCALE = 4,
- VT_OFFSET = 6,
- VT_BORDER = 8,
VT_MODE = 10
};
- const ::flatbuffers::Vector<int16_t> *scale() const {
- return GetPointer<const ::flatbuffers::Vector<int16_t> *>(VT_SCALE);
- }
- const ::flatbuffers::Vector<int16_t> *offset() const {
- return GetPointer<const ::flatbuffers::Vector<int16_t> *>(VT_OFFSET);
- }
- const ::flatbuffers::Vector<int16_t> *border() const {
- return GetPointer<const ::flatbuffers::Vector<int16_t> *>(VT_BORDER);
- }
tosa::ResizeMode mode() const {
return static_cast<tosa::ResizeMode>(GetField<uint32_t>(VT_MODE, 0));
}
bool Verify(::flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
- VerifyOffset(verifier, VT_SCALE) &&
- verifier.VerifyVector(scale()) &&
- VerifyOffset(verifier, VT_OFFSET) &&
- verifier.VerifyVector(offset()) &&
- VerifyOffset(verifier, VT_BORDER) &&
- verifier.VerifyVector(border()) &&
VerifyField<uint32_t>(verifier, VT_MODE, 4) &&
verifier.EndTable();
}
@@ -1121,15 +1106,6 @@ struct ResizeAttributeBuilder {
typedef ResizeAttribute Table;
::flatbuffers::FlatBufferBuilder &fbb_;
::flatbuffers::uoffset_t start_;
- void add_scale(::flatbuffers::Offset<::flatbuffers::Vector<int16_t>> scale) {
- fbb_.AddOffset(ResizeAttribute::VT_SCALE, scale);
- }
- void add_offset(::flatbuffers::Offset<::flatbuffers::Vector<int16_t>> offset) {
- fbb_.AddOffset(ResizeAttribute::VT_OFFSET, offset);
- }
- void add_border(::flatbuffers::Offset<::flatbuffers::Vector<int16_t>> border) {
- fbb_.AddOffset(ResizeAttribute::VT_BORDER, border);
- }
void add_mode(tosa::ResizeMode mode) {
fbb_.AddElement<uint32_t>(ResizeAttribute::VT_MODE, static_cast<uint32_t>(mode), 0);
}
@@ -1146,35 +1122,12 @@ struct ResizeAttributeBuilder {
inline ::flatbuffers::Offset<ResizeAttribute> CreateResizeAttribute(
::flatbuffers::FlatBufferBuilder &_fbb,
- ::flatbuffers::Offset<::flatbuffers::Vector<int16_t>> scale = 0,
- ::flatbuffers::Offset<::flatbuffers::Vector<int16_t>> offset = 0,
- ::flatbuffers::Offset<::flatbuffers::Vector<int16_t>> border = 0,
tosa::ResizeMode mode = tosa::ResizeMode_UNKNOWN) {
ResizeAttributeBuilder builder_(_fbb);
builder_.add_mode(mode);
- builder_.add_border(border);
- builder_.add_offset(offset);
- builder_.add_scale(scale);
return builder_.Finish();
}
-inline ::flatbuffers::Offset<ResizeAttribute> CreateResizeAttributeDirect(
- ::flatbuffers::FlatBufferBuilder &_fbb,
- const std::vector<int16_t> *scale = nullptr,
- const std::vector<int16_t> *offset = nullptr,
- const std::vector<int16_t> *border = nullptr,
- tosa::ResizeMode mode = tosa::ResizeMode_UNKNOWN) {
- auto scale__ = scale ? _fbb.CreateVector<int16_t>(*scale) : 0;
- auto offset__ = offset ? _fbb.CreateVector<int16_t>(*offset) : 0;
- auto border__ = border ? _fbb.CreateVector<int16_t>(*border) : 0;
- return tosa::CreateResizeAttribute(
- _fbb,
- scale__,
- offset__,
- border__,
- mode);
-}
-
struct ClampAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
typedef ClampAttributeBuilder Builder;
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h
index 1f8310e..c09a47d 100644
--- a/include/tosa_serialization_handler.h
+++ b/include/tosa_serialization_handler.h
@@ -27,8 +27,8 @@
#include <vector>
// Keep version number in sync with the version default value with schema/tosa.fbs
-#define TOSA_VERSION_MAJOR 0
-#define TOSA_VERSION_MINOR 100
+#define TOSA_VERSION_MAJOR 1
+#define TOSA_VERSION_MINOR 1
#define TOSA_VERSION_PATCH 0
#define TOSA_VERSION_DRAFT true
#define TENSOR_BUFFER_FORCE_ALIGNMENT 8
@@ -412,9 +412,9 @@ public:
tosa_err_t LoadFileSchema(const char* schema_filename);
// data format conversion. little-endian.
- static tosa_err_t ConvertBF16toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
- static tosa_err_t ConvertFP8E4M3toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
- static tosa_err_t ConvertFP8E5M2toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
+ static tosa_err_t ConvertBF16toU8(const std::vector<bf16>& in, std::vector<uint8_t>& out);
+ static tosa_err_t ConvertFP8E4M3toU8(const std::vector<fp8e4m3>& in, std::vector<uint8_t>& out);
+ static tosa_err_t ConvertFP8E5M2toU8(const std::vector<fp8e5m2>& in, std::vector<uint8_t>& out);
static tosa_err_t ConvertF16toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
static tosa_err_t ConvertF32toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
static tosa_err_t ConvertI64toU8(const std::vector<int64_t>& in, std::vector<uint8_t>& out);
@@ -425,9 +425,9 @@ public:
static tosa_err_t ConvertI4toU8(const std::vector<int8_t>& in, std::vector<uint8_t>& out);
static tosa_err_t ConvertBooltoU8(const std::vector<bool>& in, std::vector<uint8_t>& out);
- static tosa_err_t ConvertU8toBF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out);
- static tosa_err_t ConvertU8toFP8E4M3(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out);
- static tosa_err_t ConvertU8toFP8E5M2(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out);
+ static tosa_err_t ConvertU8toBF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<bf16>& out);
+ static tosa_err_t ConvertU8toFP8E4M3(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<fp8e4m3>& out);
+ static tosa_err_t ConvertU8toFP8E5M2(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<fp8e5m2>& out);
static tosa_err_t
ConvertU8toF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<half_float::half>& out);
static tosa_err_t ConvertU8toF32(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out);