diff options
-rw-r--r-- | include/attribute.def | 5 | ||||
-rw-r--r-- | include/cfloat.h | 44 | ||||
-rw-r--r-- | include/numpy_utils.h | 17 | ||||
-rw-r--r-- | include/tosa_generated.h | 50 | ||||
-rw-r--r-- | include/tosa_serialization_handler.h | 12 | ||||
-rw-r--r-- | python/serializer/tosa_serializer.py | 48 | ||||
-rw-r--r-- | python/tosa/ResizeAttribute.py | 117 | ||||
-rw-r--r-- | schema/tosa.fbs | 6 | ||||
-rw-r--r-- | src/numpy_utils.cpp | 29 | ||||
-rw-r--r-- | src/tosa_serialization_handler.cpp | 68 |
10 files changed, 132 insertions, 264 deletions
diff --git a/include/attribute.def b/include/attribute.def index 0e97629..52d5179 100644 --- a/include/attribute.def +++ b/include/attribute.def @@ -57,10 +57,7 @@ DEF_ATTRIBUTE(Pad, 1, DEF_ATTRIBUTE(Axis, 1, int32_t, S, axis) -DEF_ATTRIBUTE(Resize, 4, - int16_t, V, scale, - int16_t, V, offset, - int16_t, V, border, +DEF_ATTRIBUTE(Resize, 1, ResizeMode, S, mode) DEF_ATTRIBUTE(Clamp, 2, diff --git a/include/cfloat.h b/include/cfloat.h index 0cf4896..cbbe09a 100644 --- a/include/cfloat.h +++ b/include/cfloat.h @@ -211,10 +211,33 @@ public: if (in.is_nan() || in.is_infinity()) { + // The mapping of infinity to the destination type depends upon + // the overflow mode and the features of the destination type. + // OVERFLOW mode is the "expected" behaviour, in which exception + // values (NaN and infinity) map to themselves in the + // destination type (assuming they exist). In SATURATION mode, + // infinity maps to the largest absolute value of the + // destination type _even if_ an infinity encoding is available. + // See the FP8 specification document. + // + // By default, exceptional values are encoded with an all-1 + // exponent field. new_exponent_bits = (UINT64_C(1) << out_exp_bits) - 1; if (in.is_nan()) { + // NaN always maps to NaN if it's available. + // + // NB: if the type has both NaN AND Infinity support, then + // the entirety of the significand can be used to encode + // different values of NaN (excepting significand = 0, + // which is reserved for infinity). This makes it possible + // to encode both quiet and signalling varieties. + // Generally, the LSB of the significand represents "not + // quiet". However, when there is only 1 NaN encoding + // (which is generally the case when infinity is not + // supported), then there cannot be separate quiet and + // signalling varieties of NaN. if constexpr (out_type::has_inf) { // Copy across the `not_quiet bit`; set the LSB. @@ -228,17 +251,18 @@ public: new_significand = (UINT64_C(1) << out_type::n_significand_bits) - 1; } } - else if constexpr (out_type::has_inf && overflow_mode == OverflowMode::Saturate) + else if constexpr (overflow_mode == OverflowMode::Saturate) { - new_exponent_bits -= 1; - new_significand = (UINT64_C(1) << out_type::n_significand_bits) - 1; - } - else if constexpr (!out_type::has_inf && overflow_mode == OverflowMode::Saturate) - { - new_significand = (UINT64_C(1) << out_type::n_significand_bits) - (out_type::has_nan ? 2 : 1); + // In SATURATE mode, infinity in the input maps to the + // largest absolute value in the output type; even if + // infinity is available. This is in compliance with Table 3 + // of the FP8 specification. + return out_type::max(sign_bit); } else if constexpr (!out_type::has_inf && overflow_mode == OverflowMode::Overflow) { + // In OVERFLOW mode, infinities in the input type map to NaN + // in the output type, if infinity is not available. new_significand = (UINT64_C(1) << out_type::n_significand_bits) - 1; } } @@ -492,20 +516,20 @@ public: { // Where we have NaN and Infinity, exponents all `1` corresponds // to some of these values. - return from_bits(false, (UINT64_C(1) << n_exponent_bits) - 2, (UINT64_C(1) << n_significand_bits) - 1); + return from_bits(sign, (UINT64_C(1) << n_exponent_bits) - 2, (UINT64_C(1) << n_significand_bits) - 1); } else if constexpr (has_nan || has_inf) { // Where we have either NaN or infinity (but not both), // exponents all `1` AND significand all `1` corresponds to the // special value. - return from_bits(false, (UINT64_C(1) << n_exponent_bits) - 1, (UINT64_C(1) << n_significand_bits) - 2); + return from_bits(sign, (UINT64_C(1) << n_exponent_bits) - 1, (UINT64_C(1) << n_significand_bits) - 2); } else { // With no special values to encode, the maximum value is // encoded as all `1`s. - return from_bits(false, (UINT64_C(1) << n_exponent_bits) - 1, (UINT64_C(1) << n_significand_bits) - 1); + return from_bits(sign, (UINT64_C(1) << n_exponent_bits) - 1, (UINT64_C(1) << n_significand_bits) - 1); } } diff --git a/include/numpy_utils.h b/include/numpy_utils.h index 60cf77e..ade2f2d 100644 --- a/include/numpy_utils.h +++ b/include/numpy_utils.h @@ -24,8 +24,13 @@ #include <cstring> #include <vector> +#include "cfloat.h" #include "half.hpp" +using bf16 = ct::cfloat<int16_t, 8, true, true, true>; +using fp8e4m3 = ct::cfloat<int8_t, 4, true, true, false>; +using fp8e5m2 = ct::cfloat<int8_t, 5, true, true, true>; + class NumpyUtilities { public: @@ -85,6 +90,18 @@ public: { return "'<f2'"; } + if (std::is_same<T, bf16>::value) + { + return "'<V2'"; + } + if (std::is_same<T, fp8e4m3>::value) + { + return "'<V1'"; + } + if (std::is_same<T, fp8e5m2>::value) + { + return "'<f1'"; + } assert(false && "unsupported Dtype"); }; diff --git a/include/tosa_generated.h b/include/tosa_generated.h index c907c89..61bc465 100644 --- a/include/tosa_generated.h +++ b/include/tosa_generated.h @@ -1087,31 +1087,13 @@ inline ::flatbuffers::Offset<AxisAttribute> CreateAxisAttribute( struct ResizeAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { typedef ResizeAttributeBuilder Builder; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { - VT_SCALE = 4, - VT_OFFSET = 6, - VT_BORDER = 8, VT_MODE = 10 }; - const ::flatbuffers::Vector<int16_t> *scale() const { - return GetPointer<const ::flatbuffers::Vector<int16_t> *>(VT_SCALE); - } - const ::flatbuffers::Vector<int16_t> *offset() const { - return GetPointer<const ::flatbuffers::Vector<int16_t> *>(VT_OFFSET); - } - const ::flatbuffers::Vector<int16_t> *border() const { - return GetPointer<const ::flatbuffers::Vector<int16_t> *>(VT_BORDER); - } tosa::ResizeMode mode() const { return static_cast<tosa::ResizeMode>(GetField<uint32_t>(VT_MODE, 0)); } bool Verify(::flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_SCALE) && - verifier.VerifyVector(scale()) && - VerifyOffset(verifier, VT_OFFSET) && - verifier.VerifyVector(offset()) && - VerifyOffset(verifier, VT_BORDER) && - verifier.VerifyVector(border()) && VerifyField<uint32_t>(verifier, VT_MODE, 4) && verifier.EndTable(); } @@ -1121,15 +1103,6 @@ struct ResizeAttributeBuilder { typedef ResizeAttribute Table; ::flatbuffers::FlatBufferBuilder &fbb_; ::flatbuffers::uoffset_t start_; - void add_scale(::flatbuffers::Offset<::flatbuffers::Vector<int16_t>> scale) { - fbb_.AddOffset(ResizeAttribute::VT_SCALE, scale); - } - void add_offset(::flatbuffers::Offset<::flatbuffers::Vector<int16_t>> offset) { - fbb_.AddOffset(ResizeAttribute::VT_OFFSET, offset); - } - void add_border(::flatbuffers::Offset<::flatbuffers::Vector<int16_t>> border) { - fbb_.AddOffset(ResizeAttribute::VT_BORDER, border); - } void add_mode(tosa::ResizeMode mode) { fbb_.AddElement<uint32_t>(ResizeAttribute::VT_MODE, static_cast<uint32_t>(mode), 0); } @@ -1146,35 +1119,12 @@ struct ResizeAttributeBuilder { inline ::flatbuffers::Offset<ResizeAttribute> CreateResizeAttribute( ::flatbuffers::FlatBufferBuilder &_fbb, - ::flatbuffers::Offset<::flatbuffers::Vector<int16_t>> scale = 0, - ::flatbuffers::Offset<::flatbuffers::Vector<int16_t>> offset = 0, - ::flatbuffers::Offset<::flatbuffers::Vector<int16_t>> border = 0, tosa::ResizeMode mode = tosa::ResizeMode_UNKNOWN) { ResizeAttributeBuilder builder_(_fbb); builder_.add_mode(mode); - builder_.add_border(border); - builder_.add_offset(offset); - builder_.add_scale(scale); return builder_.Finish(); } -inline ::flatbuffers::Offset<ResizeAttribute> CreateResizeAttributeDirect( - ::flatbuffers::FlatBufferBuilder &_fbb, - const std::vector<int16_t> *scale = nullptr, - const std::vector<int16_t> *offset = nullptr, - const std::vector<int16_t> *border = nullptr, - tosa::ResizeMode mode = tosa::ResizeMode_UNKNOWN) { - auto scale__ = scale ? _fbb.CreateVector<int16_t>(*scale) : 0; - auto offset__ = offset ? _fbb.CreateVector<int16_t>(*offset) : 0; - auto border__ = border ? _fbb.CreateVector<int16_t>(*border) : 0; - return tosa::CreateResizeAttribute( - _fbb, - scale__, - offset__, - border__, - mode); -} - struct ClampAttribute FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { typedef ClampAttributeBuilder Builder; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h index 139a476..c09a47d 100644 --- a/include/tosa_serialization_handler.h +++ b/include/tosa_serialization_handler.h @@ -412,9 +412,9 @@ public: tosa_err_t LoadFileSchema(const char* schema_filename); // data format conversion. little-endian. - static tosa_err_t ConvertBF16toU8(const std::vector<float>& in, std::vector<uint8_t>& out); - static tosa_err_t ConvertFP8E4M3toU8(const std::vector<float>& in, std::vector<uint8_t>& out); - static tosa_err_t ConvertFP8E5M2toU8(const std::vector<float>& in, std::vector<uint8_t>& out); + static tosa_err_t ConvertBF16toU8(const std::vector<bf16>& in, std::vector<uint8_t>& out); + static tosa_err_t ConvertFP8E4M3toU8(const std::vector<fp8e4m3>& in, std::vector<uint8_t>& out); + static tosa_err_t ConvertFP8E5M2toU8(const std::vector<fp8e5m2>& in, std::vector<uint8_t>& out); static tosa_err_t ConvertF16toU8(const std::vector<float>& in, std::vector<uint8_t>& out); static tosa_err_t ConvertF32toU8(const std::vector<float>& in, std::vector<uint8_t>& out); static tosa_err_t ConvertI64toU8(const std::vector<int64_t>& in, std::vector<uint8_t>& out); @@ -425,9 +425,9 @@ public: static tosa_err_t ConvertI4toU8(const std::vector<int8_t>& in, std::vector<uint8_t>& out); static tosa_err_t ConvertBooltoU8(const std::vector<bool>& in, std::vector<uint8_t>& out); - static tosa_err_t ConvertU8toBF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out); - static tosa_err_t ConvertU8toFP8E4M3(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out); - static tosa_err_t ConvertU8toFP8E5M2(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out); + static tosa_err_t ConvertU8toBF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<bf16>& out); + static tosa_err_t ConvertU8toFP8E4M3(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<fp8e4m3>& out); + static tosa_err_t ConvertU8toFP8E5M2(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<fp8e5m2>& out); static tosa_err_t ConvertU8toF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<half_float::half>& out); static tosa_err_t ConvertU8toF32(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out); diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py index c328aaf..c417fce 100644 --- a/python/serializer/tosa_serializer.py +++ b/python/serializer/tosa_serializer.py @@ -17,7 +17,7 @@ import serializer.tosa_serializer as ts import json import flatbuffers import numpy as np -import struct +from ml_dtypes import bfloat16, float8_e4m3fn, float8_e5m2 from enum import IntEnum, unique from tosa import ( TosaGraph, @@ -225,15 +225,12 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.ints.append((a.AddAxis, axis)) - def ResizeAttribute(self, scale, offset, border, mode): + def ResizeAttribute(self, mode): from tosa import ResizeAttribute as a, Attribute self.utype = Attribute.Attribute().ResizeAttribute self.optFcns = (a.Start, a.End) - self.int16vecs.append((a.AddScale, scale)) - self.int16vecs.append((a.AddOffset, offset)) - self.int16vecs.append((a.AddBorder, border)) self.ints.append((a.AddMode, mode)) def ClampAttribute(self, serializer_builder, min_val_as_bytes, max_val_as_bytes): @@ -392,13 +389,14 @@ class TosaSerializerTensor: self.shape = shape self.dtype = dtype - if ( - dtype == DType.FP32 - or dtype == DType.BF16 - or dtype == DType.FP8E4M3 - or dtype == DType.FP8E5M2 - ): + if dtype == DType.FP32: fntype = np.float32 + elif dtype == DType.BF16: + fntype = bfloat16 + elif dtype == DType.FP8E4M3: + fntype = float8_e4m3fn + elif dtype == DType.FP8E5M2: + fntype = float8_e5m2 elif dtype == DType.FP16: fntype = np.float16 else: @@ -943,35 +941,19 @@ class TosaSerializer: np_arr = np.array(data, dtype=np.float16) u8_data.extend(np_arr.view(np.uint8)) elif dtype == DType.FP32: - # for val in data: - # b = struct.pack("!f", val) - # u8_data.extend([b[3], b[2], b[1], b[0]]) np_arr = np.array(data, dtype=np.float32) u8_data.extend(np_arr.view(np.uint8)) elif dtype == DType.BF16: - for val in data: - # convert val to little endian byte arrays b - b = struct.pack("<f", val) - # val => [ b[3], b[2], b[1], b[0] ] - # keep only most significant 2 bytes for bf16 - # in little endian ordering - u8_data.extend([b[2], b[3]]) + np_arr = np.array(data, dtype=bfloat16) + u8_data.extend(np_arr.view(np.uint8)) elif dtype == DType.FP8E4M3: for val in data: - # convert val to fp8_bits then to single byte - f32_as_int = struct.unpack(">L", struct.pack(">f", val))[0] - f32_bits = f"{f32_as_int:032b}" - fp8_bits = f32_bits[0] + f32_bits[1:5] + f32_bits[9:12] - fp8_bytes = int(fp8_bits, 2).to_bytes(1, byteorder="little") - u8_data.extend(fp8_bytes) + val_f8 = np.array(val).astype(float8_e4m3fn).view(np.uint8) + u8_data.append(val_f8) elif dtype == DType.FP8E5M2: for val in data: - # convert val to fp8_bits then to single byte - f32_as_int = struct.unpack(">L", struct.pack(">f", val))[0] - f32_bits = f"{f32_as_int:032b}" - fp8_bits = f32_bits[0] + f32_bits[1:6] + f32_bits[9:11] - fp8_bytes = int(fp8_bits, 2).to_bytes(1, byteorder="little") - u8_data.extend(fp8_bytes) + val_f8 = np.array(val).astype(float8_e5m2).view(np.uint8) + u8_data.append(val_f8) elif dtype == TosaDType.DType: # Serialize DType enum data as uint8 bytes for val in data: diff --git a/python/tosa/ResizeAttribute.py b/python/tosa/ResizeAttribute.py index 44f7d31..f2a6992 100644 --- a/python/tosa/ResizeAttribute.py +++ b/python/tosa/ResizeAttribute.py @@ -29,87 +29,6 @@ class ResizeAttribute(object): self._tab = flatbuffers.table.Table(buf, pos) # ResizeAttribute - def Scale(self, j): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) - if o != 0: - a = self._tab.Vector(o) - return self._tab.Get(flatbuffers.number_types.Int16Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 2)) - return 0 - - # ResizeAttribute - def ScaleAsNumpy(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) - if o != 0: - return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int16Flags, o) - return 0 - - # ResizeAttribute - def ScaleLength(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) - if o != 0: - return self._tab.VectorLen(o) - return 0 - - # ResizeAttribute - def ScaleIsNone(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) - return o == 0 - - # ResizeAttribute - def Offset(self, j): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) - if o != 0: - a = self._tab.Vector(o) - return self._tab.Get(flatbuffers.number_types.Int16Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 2)) - return 0 - - # ResizeAttribute - def OffsetAsNumpy(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) - if o != 0: - return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int16Flags, o) - return 0 - - # ResizeAttribute - def OffsetLength(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) - if o != 0: - return self._tab.VectorLen(o) - return 0 - - # ResizeAttribute - def OffsetIsNone(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) - return o == 0 - - # ResizeAttribute - def Border(self, j): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) - if o != 0: - a = self._tab.Vector(o) - return self._tab.Get(flatbuffers.number_types.Int16Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 2)) - return 0 - - # ResizeAttribute - def BorderAsNumpy(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) - if o != 0: - return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int16Flags, o) - return 0 - - # ResizeAttribute - def BorderLength(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) - if o != 0: - return self._tab.VectorLen(o) - return 0 - - # ResizeAttribute - def BorderIsNone(self): - o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) - return o == 0 - - # ResizeAttribute def Mode(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) if o != 0: @@ -122,42 +41,6 @@ def ResizeAttributeStart(builder): def Start(builder): ResizeAttributeStart(builder) -def ResizeAttributeAddScale(builder, scale): - builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(scale), 0) - -def AddScale(builder, scale): - ResizeAttributeAddScale(builder, scale) - -def ResizeAttributeStartScaleVector(builder, numElems): - return builder.StartVector(2, numElems, 2) - -def StartScaleVector(builder, numElems): - return ResizeAttributeStartScaleVector(builder, numElems) - -def ResizeAttributeAddOffset(builder, offset): - builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(offset), 0) - -def AddOffset(builder, offset): - ResizeAttributeAddOffset(builder, offset) - -def ResizeAttributeStartOffsetVector(builder, numElems): - return builder.StartVector(2, numElems, 2) - -def StartOffsetVector(builder, numElems): - return ResizeAttributeStartOffsetVector(builder, numElems) - -def ResizeAttributeAddBorder(builder, border): - builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(border), 0) - -def AddBorder(builder, border): - ResizeAttributeAddBorder(builder, border) - -def ResizeAttributeStartBorderVector(builder, numElems): - return builder.StartVector(2, numElems, 2) - -def StartBorderVector(builder, numElems): - return ResizeAttributeStartBorderVector(builder, numElems) - def ResizeAttributeAddMode(builder, mode): builder.PrependUint32Slot(3, mode, 0) diff --git a/schema/tosa.fbs b/schema/tosa.fbs index cad6db7..1a2d952 100644 --- a/schema/tosa.fbs +++ b/schema/tosa.fbs @@ -191,9 +191,9 @@ table AxisAttribute { } table ResizeAttribute { - scale: [int16]; - offset: [int16]; - border: [int16]; + scale: [int16] (deprecated); + offset: [int16] (deprecated); + border: [int16] (deprecated); mode: ResizeMode; } diff --git a/src/numpy_utils.cpp b/src/numpy_utils.cpp index e4171d7..7cf5f94 100644 --- a/src/numpy_utils.cpp +++ b/src/numpy_utils.cpp @@ -247,6 +247,14 @@ NumpyUtilities::NPError NumpyUtilities::checkNpyHeader(FILE* infile, const uint3 while (isspace(*ptr)) ptr++; + // ml_dtypes writes '<f1' for 'numpy.dtype' in the header for float8_e5m2, but + // default NumPy does not understand this notation, which causes trouble + // when other code tries to open this file. + // To avoid this, '|u1' notation is used when the file is written, and the uint8 + // data is viewed as float8_e5m2 later when the file is read. + if (!strcmp(dtype_str, "'<f1'")) + dtype_str = "'|u1'"; + if (strcmp(ptr, dtype_str)) { return FILE_TYPE_MISMATCH; @@ -430,6 +438,13 @@ NumpyUtilities::NPError memcpy(header, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1); headerPos += sizeof(NUMPY_HEADER_STR) - 1; + // NumPy does not understand float8_e5m2, so change it to uint8 type, so that + // Python can read .npy files. + if (!strcmp(dtype_str, "'<f1'")) + { + dtype_str = "'|u1'"; + } + // Output the format dictionary // Hard-coded for I32 for now headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, @@ -438,7 +453,19 @@ NumpyUtilities::NPError // Add shape contents (if any - as this will be empty for rank 0) for (i = 0; i < shape.size(); i++) { - headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, " %d,", shape[i]); + // Output NumPy file from tosa_refmodel_sut_run generates the shape information + // without a trailing comma when the rank is greater than 1. + if (i == 0) + { + if (shape.size() == 1) + headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "%d,", shape[i]); + else + headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "%d", shape[i]); + } + else + { + headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, ", %d", shape[i]); + } } // Close off the dictionary diff --git a/src/tosa_serialization_handler.cpp b/src/tosa_serialization_handler.cpp index 0ce6211..74f66d8 100644 --- a/src/tosa_serialization_handler.cpp +++ b/src/tosa_serialization_handler.cpp @@ -19,9 +19,6 @@ #include <iostream> using namespace tosa; -using fp8e4m3 = ct::cfloat<int8_t, 4, true, true, false>; -using fp8e5m2 = ct::cfloat<int8_t, 5, true, true, true>; - TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name, const flatbuffers::Vector<int32_t>* shape, DType dtype, @@ -750,45 +747,41 @@ void TosaSerializationHandler::ForceAlignTensorData(std::vector<uint8_t>& buf) } } -tosa_err_t TosaSerializationHandler::ConvertBF16toU8(const std::vector<float>& in, std::vector<uint8_t>& out) +tosa_err_t TosaSerializationHandler::ConvertBF16toU8(const std::vector<bf16>& in, std::vector<uint8_t>& out) { // Note: Converts fp32->bf16 by ignoring the least significant 16 bits out.clear(); for (auto val : in) { - uint32_t* val_u32 = reinterpret_cast<uint32_t*>(&val); - uint8_t f32_byte2 = (*val_u32 >> 16) & 0xFF; - uint8_t f32_byte3 = (*val_u32 >> 24) & 0xFF; - // little endian: byte2 followed by byte3 - out.push_back(f32_byte2); - out.push_back(f32_byte3); + uint8_t bf16_byte0 = val.bits() & 0xFF; + uint8_t bf16_byte1 = (val.bits() >> 8) & 0xFF; + out.push_back(bf16_byte0); + out.push_back(bf16_byte1); } ForceAlignTensorData(out); return TOSA_OK; } -tosa_err_t TosaSerializationHandler::ConvertFP8E4M3toU8(const std::vector<float>& in, std::vector<uint8_t>& out) +tosa_err_t TosaSerializationHandler::ConvertFP8E4M3toU8(const std::vector<fp8e4m3>& in, std::vector<uint8_t>& out) { // Note: Converts fp32->FP8E4M3 before converting to unint8_t out.clear(); for (auto val : in) { - auto f8 = static_cast<fp8e4m3>(val); - uint8_t b8 = f8.bits(); + uint8_t b8 = val.bits(); out.push_back(b8); } ForceAlignTensorData(out); return TOSA_OK; } -tosa_err_t TosaSerializationHandler::ConvertFP8E5M2toU8(const std::vector<float>& in, std::vector<uint8_t>& out) +tosa_err_t TosaSerializationHandler::ConvertFP8E5M2toU8(const std::vector<fp8e5m2>& in, std::vector<uint8_t>& out) { // Note: Converts fp32->FP8E5M2 before converting to uint8_t out.clear(); for (auto val : in) { - auto f8 = static_cast<fp8e5m2>(val); - uint8_t b8 = f8.bits(); + uint8_t b8 = val.bits(); out.push_back(b8); } ForceAlignTensorData(out); @@ -944,11 +937,9 @@ tosa_err_t TosaSerializationHandler::ConvertBooltoU8(const std::vector<bool>& in return TOSA_OK; } -tosa_err_t TosaSerializationHandler::ConvertU8toBF16(const std::vector<uint8_t>& in, - uint32_t out_size, - std::vector<float>& out) +tosa_err_t + TosaSerializationHandler::ConvertU8toBF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<bf16>& out) { - // Note: bf16 values returned in fp32 type out.clear(); if (in.size() < out_size * sizeof(int16_t)) { @@ -959,22 +950,21 @@ tosa_err_t TosaSerializationHandler::ConvertU8toBF16(const std::vector<uint8_t>& for (uint32_t i = 0; i < out_size; i++) { - uint32_t f32_byte2 = in[i * sizeof(int16_t)]; - uint32_t f32_byte3 = in[i * sizeof(int16_t) + 1]; - uint32_t val_u32 = (f32_byte2 << 16) + (f32_byte3 << 24); + uint8_t bf16_byte0 = in[i * sizeof(int16_t)]; + uint8_t bf16_byte1 = in[i * sizeof(int16_t) + 1]; + uint16_t val_u16 = (bf16_byte0) + (bf16_byte1 << 8); - // Reinterpret u32 bytes as fp32 - float val_f32 = *(float*)&val_u32; - out.push_back(val_f32); + // Reinterpret u16 bytes as bf16 + bf16 val_bf16 = *(bf16*)&val_u16; + out.push_back(val_bf16); } return TOSA_OK; } tosa_err_t TosaSerializationHandler::ConvertU8toFP8E4M3(const std::vector<uint8_t>& in, uint32_t out_size, - std::vector<float>& out) + std::vector<fp8e4m3>& out) { - // Note: FP8E4M3 values returned in fp32 type out.clear(); if (in.size() < out_size * sizeof(int8_t)) { @@ -985,17 +975,16 @@ tosa_err_t TosaSerializationHandler::ConvertU8toFP8E4M3(const std::vector<uint8_ for (uint32_t i = 0; i < out_size; i++) { - int8_t bits = static_cast<int8_t>(in[i * sizeof(int8_t)]); - auto f8 = fp8e4m3::from_bits(bits); - float val_f32 = static_cast<float>(f8); - out.push_back(val_f32); + int8_t bits = static_cast<int8_t>(in[i * sizeof(int8_t)]); + auto f8 = fp8e4m3::from_bits(bits); + out.push_back(f8); } return TOSA_OK; } tosa_err_t TosaSerializationHandler::ConvertU8toFP8E5M2(const std::vector<uint8_t>& in, uint32_t out_size, - std::vector<float>& out) + std::vector<fp8e5m2>& out) { // Note: FP8E5M2 values returned in fp32 type out.clear(); @@ -1008,10 +997,9 @@ tosa_err_t TosaSerializationHandler::ConvertU8toFP8E5M2(const std::vector<uint8_ for (uint32_t i = 0; i < out_size; i++) { - int8_t bits = static_cast<int8_t>(in[i * sizeof(int8_t)]); - auto f8 = fp8e5m2::from_bits(bits); - float val_f32 = static_cast<float>(f8); - out.push_back(val_f32); + int8_t bits = static_cast<int8_t>(in[i * sizeof(int8_t)]); + auto f8 = fp8e5m2::from_bits(bits); + out.push_back(f8); } return TOSA_OK; } @@ -1031,9 +1019,9 @@ tosa_err_t TosaSerializationHandler::ConvertU8toF16(const std::vector<uint8_t>& for (uint32_t i = 0; i < out_size; i++) { - uint16_t f16_byte0 = in[i * sizeof(int16_t)]; - uint16_t f16_byte1 = in[i * sizeof(int16_t) + 1]; - uint16_t val_u16 = f16_byte0 + (f16_byte1 << 8); + uint8_t f16_byte0 = in[i * sizeof(int16_t)]; + uint8_t f16_byte1 = in[i * sizeof(int16_t) + 1]; + uint16_t val_u16 = f16_byte0 + (f16_byte1 << 8); // Reinterpret u16 byte as fp16 then convert to fp32 half_float::half val_f16 = *(half_float::half*)&val_u16; |