diff options
author | Won Jeon <won.jeon@arm.com> | 2024-05-09 06:00:31 +0000 |
---|---|---|
committer | Won Jeon <won.jeon@arm.com> | 2024-05-09 18:21:17 +0000 |
commit | 07098d64498723e889e8f8deaad4952020fa9450 (patch) | |
tree | 94f4bd6a070c73bbd629118fe6bfd8390bdf7213 | |
parent | b386815fcf36092be832281821af7ad9f2119e07 (diff) | |
download | serialization_lib-07098d64498723e889e8f8deaad4952020fa9450.tar.gz |
Fix Bfloat16 data conversion for serialization
Signed-off-by: Won Jeon <won.jeon@arm.com>
Change-Id: I52f6fea3e8b4cd5ff0886ccfa12396a680558670
-rw-r--r-- | python/serializer/tosa_serializer.py | 5 | ||||
-rw-r--r-- | src/tosa_serialization_handler.cpp | 18 |
2 files changed, 10 insertions, 13 deletions
diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py index 7122216..34178c5 100644 --- a/python/serializer/tosa_serializer.py +++ b/python/serializer/tosa_serializer.py @@ -947,9 +947,8 @@ class TosaSerializer: np_arr = np.array(data, dtype=np.float32) u8_data.extend(np_arr.view(np.uint8)) elif dtype == DType.BF16: - for val in data: - np_arr = np.array(data, dtype=bfloat16) - u8_data.extend(np_arr.view(np.uint8)) + np_arr = np.array(data, dtype=bfloat16) + u8_data.extend(np_arr.view(np.uint8)) elif dtype == DType.FP8E4M3: for val in data: val_f8 = np.array(val).astype(float8_e4m3fn).view(np.uint8) diff --git a/src/tosa_serialization_handler.cpp b/src/tosa_serialization_handler.cpp index 76b2198..4516f7d 100644 --- a/src/tosa_serialization_handler.cpp +++ b/src/tosa_serialization_handler.cpp @@ -940,7 +940,6 @@ tosa_err_t TosaSerializationHandler::ConvertBooltoU8(const std::vector<bool>& in tosa_err_t TosaSerializationHandler::ConvertU8toBF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<bf16>& out) { - // Note: bf16 values returned in fp32 type out.clear(); if (in.size() < out_size * sizeof(int16_t)) { @@ -951,13 +950,12 @@ tosa_err_t for (uint32_t i = 0; i < out_size; i++) { - uint32_t f32_byte2 = in[i * sizeof(int16_t)]; - uint32_t f32_byte3 = in[i * sizeof(int16_t) + 1]; - uint32_t val_u32 = (f32_byte2 << 16) + (f32_byte3 << 24); + uint8_t bf16_byte0 = in[i * sizeof(int16_t)]; + uint8_t bf16_byte1 = in[i * sizeof(int16_t) + 1]; + uint16_t val_u16 = (bf16_byte0) + (bf16_byte1 << 8); - // Reinterpret u32 bytes as fp32 - float val_f32 = *(float*)&val_u32; - float val_bf16 = static_cast<bf16>(val_f32); + // Reinterpret u16 bytes as bf16 + bf16 val_bf16 = static_cast<bf16>(val_u16); out.push_back(val_bf16); } return TOSA_OK; @@ -1021,9 +1019,9 @@ tosa_err_t TosaSerializationHandler::ConvertU8toF16(const std::vector<uint8_t>& for (uint32_t i = 0; i < out_size; i++) { - uint16_t f16_byte0 = in[i * sizeof(int16_t)]; - uint16_t f16_byte1 = in[i * sizeof(int16_t) + 1]; - uint16_t val_u16 = f16_byte0 + (f16_byte1 << 8); + uint8_t f16_byte0 = in[i * sizeof(int16_t)]; + uint8_t f16_byte1 = in[i * sizeof(int16_t) + 1]; + uint16_t val_u16 = f16_byte0 + (f16_byte1 << 8); // Reinterpret u16 byte as fp16 then convert to fp32 half_float::half val_f16 = *(half_float::half*)&val_u16; |