diff options
author | Won Jeon <won.jeon@arm.com> | 2024-04-29 23:57:27 +0000 |
---|---|---|
committer | Won Jeon <won.jeon@arm.com> | 2024-05-03 13:33:29 +0000 |
commit | a814152b68a286f5bb9ddc095bb1897ec0e3d8ff (patch) | |
tree | c8aa9a42e3d9fdf978e5d366a301b1f8d9716d83 /include | |
parent | 3aebe2bd863d6e0cb82171984cd49e5ad516d0db (diff) | |
download | serialization_lib-a814152b68a286f5bb9ddc095bb1897ec0e3d8ff.tar.gz |
Use native size of Bfloat16 and Float8 for serialization/deserialization
Signed-off-by: Won Jeon <won.jeon@arm.com>
Change-Id: I0d2075f90988d4fd1139a11b5c154bdd600bb2cd
Diffstat (limited to 'include')
-rw-r--r-- | include/numpy_utils.h | 17 | ||||
-rw-r--r-- | include/tosa_serialization_handler.h | 12 |
2 files changed, 23 insertions, 6 deletions
diff --git a/include/numpy_utils.h b/include/numpy_utils.h index 60cf77e..ade2f2d 100644 --- a/include/numpy_utils.h +++ b/include/numpy_utils.h @@ -24,8 +24,13 @@ #include <cstring> #include <vector> +#include "cfloat.h" #include "half.hpp" +using bf16 = ct::cfloat<int16_t, 8, true, true, true>; +using fp8e4m3 = ct::cfloat<int8_t, 4, true, true, false>; +using fp8e5m2 = ct::cfloat<int8_t, 5, true, true, true>; + class NumpyUtilities { public: @@ -85,6 +90,18 @@ public: { return "'<f2'"; } + if (std::is_same<T, bf16>::value) + { + return "'<V2'"; + } + if (std::is_same<T, fp8e4m3>::value) + { + return "'<V1'"; + } + if (std::is_same<T, fp8e5m2>::value) + { + return "'<f1'"; + } assert(false && "unsupported Dtype"); }; diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h index 139a476..c09a47d 100644 --- a/include/tosa_serialization_handler.h +++ b/include/tosa_serialization_handler.h @@ -412,9 +412,9 @@ public: tosa_err_t LoadFileSchema(const char* schema_filename); // data format conversion. little-endian. - static tosa_err_t ConvertBF16toU8(const std::vector<float>& in, std::vector<uint8_t>& out); - static tosa_err_t ConvertFP8E4M3toU8(const std::vector<float>& in, std::vector<uint8_t>& out); - static tosa_err_t ConvertFP8E5M2toU8(const std::vector<float>& in, std::vector<uint8_t>& out); + static tosa_err_t ConvertBF16toU8(const std::vector<bf16>& in, std::vector<uint8_t>& out); + static tosa_err_t ConvertFP8E4M3toU8(const std::vector<fp8e4m3>& in, std::vector<uint8_t>& out); + static tosa_err_t ConvertFP8E5M2toU8(const std::vector<fp8e5m2>& in, std::vector<uint8_t>& out); static tosa_err_t ConvertF16toU8(const std::vector<float>& in, std::vector<uint8_t>& out); static tosa_err_t ConvertF32toU8(const std::vector<float>& in, std::vector<uint8_t>& out); static tosa_err_t ConvertI64toU8(const std::vector<int64_t>& in, std::vector<uint8_t>& out); @@ -425,9 +425,9 @@ public: static tosa_err_t ConvertI4toU8(const std::vector<int8_t>& in, std::vector<uint8_t>& out); static tosa_err_t ConvertBooltoU8(const std::vector<bool>& in, std::vector<uint8_t>& out); - static tosa_err_t ConvertU8toBF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out); - static tosa_err_t ConvertU8toFP8E4M3(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out); - static tosa_err_t ConvertU8toFP8E5M2(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out); + static tosa_err_t ConvertU8toBF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<bf16>& out); + static tosa_err_t ConvertU8toFP8E4M3(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<fp8e4m3>& out); + static tosa_err_t ConvertU8toFP8E5M2(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<fp8e5m2>& out); static tosa_err_t ConvertU8toF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<half_float::half>& out); static tosa_err_t ConvertU8toF32(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out); |