aboutsummaryrefslogtreecommitdiff
path: root/src/tosa_serialization_handler.cpp
diff options
context:
space:
mode:
authorWon Jeon <won.jeon@arm.com>2024-04-29 23:57:27 +0000
committerWon Jeon <won.jeon@arm.com>2024-05-03 13:33:29 +0000
commita814152b68a286f5bb9ddc095bb1897ec0e3d8ff (patch)
treec8aa9a42e3d9fdf978e5d366a301b1f8d9716d83 /src/tosa_serialization_handler.cpp
parent3aebe2bd863d6e0cb82171984cd49e5ad516d0db (diff)
downloadserialization_lib-a814152b68a286f5bb9ddc095bb1897ec0e3d8ff.tar.gz
Use native size of Bfloat16 and Float8 for serialization/deserialization
Signed-off-by: Won Jeon <won.jeon@arm.com> Change-Id: I0d2075f90988d4fd1139a11b5c154bdd600bb2cd
Diffstat (limited to 'src/tosa_serialization_handler.cpp')
-rw-r--r--src/tosa_serialization_handler.cpp54
1 files changed, 22 insertions, 32 deletions
diff --git a/src/tosa_serialization_handler.cpp b/src/tosa_serialization_handler.cpp
index 0ce6211..76b2198 100644
--- a/src/tosa_serialization_handler.cpp
+++ b/src/tosa_serialization_handler.cpp
@@ -19,9 +19,6 @@
#include <iostream>
using namespace tosa;
-using fp8e4m3 = ct::cfloat<int8_t, 4, true, true, false>;
-using fp8e5m2 = ct::cfloat<int8_t, 5, true, true, true>;
-
TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name,
const flatbuffers::Vector<int32_t>* shape,
DType dtype,
@@ -750,45 +747,41 @@ void TosaSerializationHandler::ForceAlignTensorData(std::vector<uint8_t>& buf)
}
}
-tosa_err_t TosaSerializationHandler::ConvertBF16toU8(const std::vector<float>& in, std::vector<uint8_t>& out)
+tosa_err_t TosaSerializationHandler::ConvertBF16toU8(const std::vector<bf16>& in, std::vector<uint8_t>& out)
{
// Note: Serializes each bf16 value as its two raw bytes, least significant byte first
out.clear();
for (auto val : in)
{
- uint32_t* val_u32 = reinterpret_cast<uint32_t*>(&val);
- uint8_t f32_byte2 = (*val_u32 >> 16) & 0xFF;
- uint8_t f32_byte3 = (*val_u32 >> 24) & 0xFF;
- // little endian: byte2 followed by byte3
- out.push_back(f32_byte2);
- out.push_back(f32_byte3);
+ uint8_t bf16_byte0 = val.bits() & 0xFF;
+ uint8_t bf16_byte1 = (val.bits() >> 8) & 0xFF;
+ out.push_back(bf16_byte0);
+ out.push_back(bf16_byte1);
}
ForceAlignTensorData(out);
return TOSA_OK;
}
-tosa_err_t TosaSerializationHandler::ConvertFP8E4M3toU8(const std::vector<float>& in, std::vector<uint8_t>& out)
+tosa_err_t TosaSerializationHandler::ConvertFP8E4M3toU8(const std::vector<fp8e4m3>& in, std::vector<uint8_t>& out)
{
// Note: Stores the raw 8-bit encoding of each FP8E4M3 value as a uint8_t
out.clear();
for (auto val : in)
{
- auto f8 = static_cast<fp8e4m3>(val);
- uint8_t b8 = f8.bits();
+ uint8_t b8 = val.bits();
out.push_back(b8);
}
ForceAlignTensorData(out);
return TOSA_OK;
}
-tosa_err_t TosaSerializationHandler::ConvertFP8E5M2toU8(const std::vector<float>& in, std::vector<uint8_t>& out)
+tosa_err_t TosaSerializationHandler::ConvertFP8E5M2toU8(const std::vector<fp8e5m2>& in, std::vector<uint8_t>& out)
{
// Note: Stores the raw 8-bit encoding of each FP8E5M2 value as a uint8_t
out.clear();
for (auto val : in)
{
- auto f8 = static_cast<fp8e5m2>(val);
- uint8_t b8 = f8.bits();
+ uint8_t b8 = val.bits();
out.push_back(b8);
}
ForceAlignTensorData(out);
@@ -944,9 +937,8 @@ tosa_err_t TosaSerializationHandler::ConvertBooltoU8(const std::vector<bool>& in
return TOSA_OK;
}
-tosa_err_t TosaSerializationHandler::ConvertU8toBF16(const std::vector<uint8_t>& in,
- uint32_t out_size,
- std::vector<float>& out)
+tosa_err_t
+ TosaSerializationHandler::ConvertU8toBF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<bf16>& out)
{
// Note: bf16 values are returned in the native bf16 type
out.clear();
@@ -964,17 +956,17 @@ tosa_err_t TosaSerializationHandler::ConvertU8toBF16(const std::vector<uint8_t>&
uint32_t val_u32 = (f32_byte2 << 16) + (f32_byte3 << 24);
// Reinterpret u32 bytes as fp32
- float val_f32 = *(float*)&val_u32;
- out.push_back(val_f32);
+ float val_f32 = *(float*)&val_u32;
+ float val_bf16 = static_cast<bf16>(val_f32);
+ out.push_back(val_bf16);
}
return TOSA_OK;
}
tosa_err_t TosaSerializationHandler::ConvertU8toFP8E4M3(const std::vector<uint8_t>& in,
uint32_t out_size,
- std::vector<float>& out)
+ std::vector<fp8e4m3>& out)
{
- // Note: FP8E4M3 values returned in fp32 type
out.clear();
if (in.size() < out_size * sizeof(int8_t))
{
@@ -985,17 +977,16 @@ tosa_err_t TosaSerializationHandler::ConvertU8toFP8E4M3(const std::vector<uint8_
for (uint32_t i = 0; i < out_size; i++)
{
- int8_t bits = static_cast<int8_t>(in[i * sizeof(int8_t)]);
- auto f8 = fp8e4m3::from_bits(bits);
- float val_f32 = static_cast<float>(f8);
- out.push_back(val_f32);
+ int8_t bits = static_cast<int8_t>(in[i * sizeof(int8_t)]);
+ auto f8 = fp8e4m3::from_bits(bits);
+ out.push_back(f8);
}
return TOSA_OK;
}
tosa_err_t TosaSerializationHandler::ConvertU8toFP8E5M2(const std::vector<uint8_t>& in,
uint32_t out_size,
- std::vector<float>& out)
+ std::vector<fp8e5m2>& out)
{
// Note: FP8E5M2 values are returned in the native fp8e5m2 type
out.clear();
@@ -1008,10 +999,9 @@ tosa_err_t TosaSerializationHandler::ConvertU8toFP8E5M2(const std::vector<uint8_
for (uint32_t i = 0; i < out_size; i++)
{
- int8_t bits = static_cast<int8_t>(in[i * sizeof(int8_t)]);
- auto f8 = fp8e5m2::from_bits(bits);
- float val_f32 = static_cast<float>(f8);
- out.push_back(val_f32);
+ int8_t bits = static_cast<int8_t>(in[i * sizeof(int8_t)]);
+ auto f8 = fp8e5m2::from_bits(bits);
+ out.push_back(f8);
}
return TOSA_OK;
}