From ce911a2f1d9cd678fb9fe82a40c86ad0c6772f5a Mon Sep 17 00:00:00 2001
From: Tai Ly
Date: Thu, 21 Mar 2024 17:01:14 +0000
Subject: Add conversions of U8 to/from BF16 and FP8

Adds type to PadAttribute and ClampAttribute so their pad_const and
max_val/min_val can be deserialized according to type.

Adds conversion functions of U8 arrays to/from BF16/FP8 values.

Also, refactor and expose TosaSerializer.convertDataToUint8Vec for
converting dtype/data to uint8 list for serialization.

And modify convertDataToUint8Vec to serialize bf16 values into 2 bytes
each, and serialize fp8 values into single bytes each.

Signed-off-by: Tai Ly
Change-Id: I05659e8187c76d359f1cc9f71c8c23cafd0e877f
---
 include/tosa_serialization_handler.h | 7 +++++++
 1 file changed, 7 insertions(+)

(limited to 'include/tosa_serialization_handler.h')

diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h
index 5c53f57..f5f9e58 100644
--- a/include/tosa_serialization_handler.h
+++ b/include/tosa_serialization_handler.h
@@ -18,6 +18,7 @@
 #include "attribute.h"
 #include "flatbuffers/idl.h"
 #include "flatbuffers/util.h"
+#include "float_utils.h"
 #include "numpy_utils.h"
 #include "tosa_generated.h"
 #include
@@ -411,6 +412,9 @@ public:
     tosa_err_t LoadFileSchema(const char* schema_filename);
 
     // data format conversion. little-endian.
+    static tosa_err_t ConvertBF16toU8(const std::vector& in, std::vector& out);
+    static tosa_err_t ConvertFP8E4M3toU8(const std::vector& in, std::vector& out);
+    static tosa_err_t ConvertFP8E5M2toU8(const std::vector& in, std::vector& out);
     static tosa_err_t ConvertF16toU8(const std::vector& in, std::vector& out);
     static tosa_err_t ConvertF32toU8(const std::vector& in, std::vector& out);
     static tosa_err_t ConvertI64toU8(const std::vector& in, std::vector& out);
@@ -421,6 +425,9 @@ public:
     static tosa_err_t ConvertI4toU8(const std::vector& in, std::vector& out);
     static tosa_err_t ConvertBooltoU8(const std::vector& in, std::vector& out);
 
+    static tosa_err_t ConvertU8toBF16(const std::vector& in, uint32_t out_size, std::vector& out);
+    static tosa_err_t ConvertU8toFP8E4M3(const std::vector& in, uint32_t out_size, std::vector& out);
+    static tosa_err_t ConvertU8toFP8E5M2(const std::vector& in, uint32_t out_size, std::vector& out);
     static tosa_err_t ConvertU8toF16(const std::vector& in, uint32_t out_size, std::vector& out);
     static tosa_err_t ConvertU8toF32(const std::vector& in, uint32_t out_size, std::vector& out);
-- 
cgit v1.2.1