From 07098d64498723e889e8f8deaad4952020fa9450 Mon Sep 17 00:00:00 2001
From: Won Jeon
Date: Thu, 9 May 2024 06:00:31 +0000
Subject: Fix Bfloat16 data conversion for serialization

Signed-off-by: Won Jeon
Change-Id: I52f6fea3e8b4cd5ff0886ccfa12396a680558670
---
 src/tosa_serialization_handler.cpp | 18 ++++++++----------
 1 file changed, 8 insertions(+), 10 deletions(-)

(limited to 'src/tosa_serialization_handler.cpp')

diff --git a/src/tosa_serialization_handler.cpp b/src/tosa_serialization_handler.cpp
index 76b2198..4516f7d 100644
--- a/src/tosa_serialization_handler.cpp
+++ b/src/tosa_serialization_handler.cpp
@@ -940,7 +940,6 @@ tosa_err_t TosaSerializationHandler::ConvertBooltoU8(const std::vector& in
 tosa_err_t
     TosaSerializationHandler::ConvertU8toBF16(const std::vector& in, uint32_t out_size, std::vector& out)
 {
-    // Note: bf16 values returned in fp32 type
     out.clear();
     if (in.size() < out_size * sizeof(int16_t))
     {
@@ -951,13 +950,12 @@ tosa_err_t
 
     for (uint32_t i = 0; i < out_size; i++)
     {
-        uint32_t f32_byte2 = in[i * sizeof(int16_t)];
-        uint32_t f32_byte3 = in[i * sizeof(int16_t) + 1];
-        uint32_t val_u32 = (f32_byte2 << 16) + (f32_byte3 << 24);
+        uint8_t bf16_byte0 = in[i * sizeof(int16_t)];
+        uint8_t bf16_byte1 = in[i * sizeof(int16_t) + 1];
+        uint16_t val_u16 = (bf16_byte0) + (bf16_byte1 << 8);
 
-        // Reinterpret u32 bytes as fp32
-        float val_f32 = *(float*)&val_u32;
-        float val_bf16 = static_cast(val_f32);
+        // Reinterpret u16 bytes as bf16
+        bf16 val_bf16 = static_cast(val_u16);
         out.push_back(val_bf16);
     }
     return TOSA_OK;
@@ -1021,9 +1019,9 @@ tosa_err_t TosaSerializationHandler::ConvertU8toF16(const std::vector&
 
     for (uint32_t i = 0; i < out_size; i++)
     {
-        uint16_t f16_byte0 = in[i * sizeof(int16_t)];
-        uint16_t f16_byte1 = in[i * sizeof(int16_t) + 1];
-        uint16_t val_u16 = f16_byte0 + (f16_byte1 << 8);
+        uint8_t f16_byte0 = in[i * sizeof(int16_t)];
+        uint8_t f16_byte1 = in[i * sizeof(int16_t) + 1];
+        uint16_t val_u16 = f16_byte0 + (f16_byte1 << 8);
 
         // Reinterpret u16 byte as fp16 then convert to fp32
         half_float::half val_f16 = *(half_float::half*)&val_u16;
-- 
cgit v1.2.1