From ce911a2f1d9cd678fb9fe82a40c86ad0c6772f5a Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Thu, 21 Mar 2024 17:01:14 +0000 Subject: Add conversions of U8 to/from BF16 and FP8 Adds type to PadAttribute and ClampAttribute so their pad_const and max_val/min_val can be deserialized according to type Adds conversion functions of U8 arrays to/from BF16/FP8 values Also, refactor and expose TosaSerializer.convertDataToUint8Vec for converting dtype/data to uint8 list for serialization And modify convertDataToUint8Vec to serialize bf16 values into 2 bytes each, and serialize fp8 values into single bytes each. Signed-off-by: Tai Ly Change-Id: I05659e8187c76d359f1cc9f71c8c23cafd0e877f --- src/tosa_serialization_handler.cpp | 120 +++++++++++++++++++++++++++++++++++++ 1 file changed, 120 insertions(+) (limited to 'src/tosa_serialization_handler.cpp') diff --git a/src/tosa_serialization_handler.cpp b/src/tosa_serialization_handler.cpp index 749a3c8..85625cd 100644 --- a/src/tosa_serialization_handler.cpp +++ b/src/tosa_serialization_handler.cpp @@ -19,6 +19,9 @@ #include using namespace tosa; +using fp8e4m3 = tosa::float_t; +using fp8e5m2 = tosa::float_t; + TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name, const flatbuffers::Vector* shape, DType dtype, @@ -747,6 +750,51 @@ void TosaSerializationHandler::ForceAlignTensorData(std::vector& buf) } } +tosa_err_t TosaSerializationHandler::ConvertBF16toU8(const std::vector& in, std::vector& out) +{ + // Note: Converts fp32->bf16 by ignoring the least significant 16 bits + out.clear(); + for (auto val : in) + { + uint32_t* val_u32 = reinterpret_cast(&val); + uint8_t f32_byte2 = (*val_u32 >> 16) & 0xFF; + uint8_t f32_byte3 = (*val_u32 >> 24) & 0xFF; + // little endian: byte2 followed by byte3 + out.push_back(f32_byte2); + out.push_back(f32_byte3); + } + ForceAlignTensorData(out); + return TOSA_OK; +} + +tosa_err_t TosaSerializationHandler::ConvertFP8E4M3toU8(const std::vector& in, std::vector& out) +{ + // Note: Converts fp32->FP8E4M3 before converting to unint8_t + out.clear(); + for (auto val : in) + { + auto f8 = static_cast(val); + uint8_t b8 = f8.bits(); + out.push_back(b8); + } + ForceAlignTensorData(out); + return TOSA_OK; +} + +tosa_err_t TosaSerializationHandler::ConvertFP8E5M2toU8(const std::vector& in, std::vector& out) +{ + // Note: Converts fp32->FP8E5M2 before converting to uint8_t + out.clear(); + for (auto val : in) + { + auto f8 = static_cast(val); + uint8_t b8 = f8.bits(); + out.push_back(b8); + } + ForceAlignTensorData(out); + return TOSA_OK; +} + tosa_err_t TosaSerializationHandler::ConvertF16toU8(const std::vector& in, std::vector& out) { // Note: Converts fp32->fp16 before converting to uint8_t @@ -896,6 +944,78 @@ tosa_err_t TosaSerializationHandler::ConvertBooltoU8(const std::vector& in return TOSA_OK; } +tosa_err_t TosaSerializationHandler::ConvertU8toBF16(const std::vector& in, + uint32_t out_size, + std::vector& out) +{ + // Note: bf16 values returned in fp32 type + out.clear(); + if (in.size() < out_size * sizeof(int16_t)) + { + printf("TosaSerializationHandler::ConvertU8toBF16(): uint8 buffer size %ld must >= target size %ld\n", + in.size(), out_size * sizeof(int16_t)); + return TOSA_USER_ERROR; + } + + for (uint32_t i = 0; i < out_size; i++) + { + uint32_t f32_byte2 = in[i * sizeof(int16_t)]; + uint32_t f32_byte3 = in[i * sizeof(int16_t) + 1]; + uint32_t val_u32 = (f32_byte2 << 16) + (f32_byte3 << 24); + + // Reinterpret u32 bytes as fp32 + float val_f32 = *(float*)&val_u32; + out.push_back(val_f32); + } + return TOSA_OK; +} + +tosa_err_t TosaSerializationHandler::ConvertU8toFP8E4M3(const std::vector& in, + uint32_t out_size, + std::vector& out) +{ + // Note: FP8E4M3 values returned in fp32 type + out.clear(); + if (in.size() < out_size * sizeof(int8_t)) + { + printf("TosaSerializationHandler::ConvertU8toF16(): uint8 buffer size %ld must >= target size %ld\n", in.size(), + out_size * sizeof(int8_t)); + return TOSA_USER_ERROR; + } + + for (uint32_t i = 0; i < out_size; i++) + { + int8_t bits = static_cast(in[i * sizeof(int8_t)]); + auto f8 = fp8e4m3::from_bits(bits); + float val_f32 = static_cast(f8); + out.push_back(val_f32); + } + return TOSA_OK; +} + +tosa_err_t TosaSerializationHandler::ConvertU8toFP8E5M2(const std::vector& in, + uint32_t out_size, + std::vector& out) +{ + // Note: FP8E5M2 values returned in fp32 type + out.clear(); + if (in.size() < out_size * sizeof(int8_t)) + { + printf("TosaSerializationHandler::ConvertU8toF16(): uint8 buffer size %ld must >= target size %ld\n", in.size(), + out_size * sizeof(int8_t)); + return TOSA_USER_ERROR; + } + + for (uint32_t i = 0; i < out_size; i++) + { + int8_t bits = static_cast(in[i * sizeof(int8_t)]); + auto f8 = fp8e5m2::from_bits(bits); + float val_f32 = static_cast(f8); + out.push_back(val_f32); + } + return TOSA_OK; +} + tosa_err_t TosaSerializationHandler::ConvertU8toF16(const std::vector& in, uint32_t out_size, std::vector& out) -- cgit v1.2.1