From a814152b68a286f5bb9ddc095bb1897ec0e3d8ff Mon Sep 17 00:00:00 2001 From: Won Jeon Date: Mon, 29 Apr 2024 23:57:27 +0000 Subject: Use native size of Bfloat16 and Float8 for serialization/deserialization Signed-off-by: Won Jeon Change-Id: I0d2075f90988d4fd1139a11b5c154bdd600bb2cd --- include/numpy_utils.h | 17 ++++++++++++ include/tosa_serialization_handler.h | 12 ++++---- python/serializer/tosa_serializer.py | 42 ++++++++++------------------ src/numpy_utils.cpp | 29 ++++++++++++++++++- src/tosa_serialization_handler.cpp | 54 +++++++++++++++--------------------- 5 files changed, 87 insertions(+), 67 deletions(-) diff --git a/include/numpy_utils.h b/include/numpy_utils.h index 60cf77e..ade2f2d 100644 --- a/include/numpy_utils.h +++ b/include/numpy_utils.h @@ -24,8 +24,13 @@ #include #include +#include "cfloat.h" #include "half.hpp" +using bf16 = ct::cfloat; +using fp8e4m3 = ct::cfloat; +using fp8e5m2 = ct::cfloat; + class NumpyUtilities { public: @@ -85,6 +90,18 @@ public: { return "'::value) + { + return "'::value) + { + return "'::value) + { + return "'& in, std::vector& out); - static tosa_err_t ConvertFP8E4M3toU8(const std::vector& in, std::vector& out); - static tosa_err_t ConvertFP8E5M2toU8(const std::vector& in, std::vector& out); + static tosa_err_t ConvertBF16toU8(const std::vector& in, std::vector& out); + static tosa_err_t ConvertFP8E4M3toU8(const std::vector& in, std::vector& out); + static tosa_err_t ConvertFP8E5M2toU8(const std::vector& in, std::vector& out); static tosa_err_t ConvertF16toU8(const std::vector& in, std::vector& out); static tosa_err_t ConvertF32toU8(const std::vector& in, std::vector& out); static tosa_err_t ConvertI64toU8(const std::vector& in, std::vector& out); @@ -425,9 +425,9 @@ public: static tosa_err_t ConvertI4toU8(const std::vector& in, std::vector& out); static tosa_err_t ConvertBooltoU8(const std::vector& in, std::vector& out); - static tosa_err_t ConvertU8toBF16(const std::vector& in, uint32_t out_size, std::vector& out); - static tosa_err_t ConvertU8toFP8E4M3(const std::vector& in, uint32_t out_size, std::vector& out); - static tosa_err_t ConvertU8toFP8E5M2(const std::vector& in, uint32_t out_size, std::vector& out); + static tosa_err_t ConvertU8toBF16(const std::vector& in, uint32_t out_size, std::vector& out); + static tosa_err_t ConvertU8toFP8E4M3(const std::vector& in, uint32_t out_size, std::vector& out); + static tosa_err_t ConvertU8toFP8E5M2(const std::vector& in, uint32_t out_size, std::vector& out); static tosa_err_t ConvertU8toF16(const std::vector& in, uint32_t out_size, std::vector& out); static tosa_err_t ConvertU8toF32(const std::vector& in, uint32_t out_size, std::vector& out); diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py index c328aaf..7122216 100644 --- a/python/serializer/tosa_serializer.py +++ b/python/serializer/tosa_serializer.py @@ -17,7 +17,7 @@ import serializer.tosa_serializer as ts import json import flatbuffers import numpy as np -import struct +from ml_dtypes import bfloat16, float8_e4m3fn, float8_e5m2 from enum import IntEnum, unique from tosa import ( TosaGraph, @@ -392,13 +392,14 @@ class TosaSerializerTensor: self.shape = shape self.dtype = dtype - if ( - dtype == DType.FP32 - or dtype == DType.BF16 - or dtype == DType.FP8E4M3 - or dtype == DType.FP8E5M2 - ): + if dtype == DType.FP32: fntype = np.float32 + elif dtype == DType.BF16: + fntype = bfloat16 + elif dtype == DType.FP8E4M3: + fntype = float8_e4m3fn + elif dtype == DType.FP8E5M2: + fntype = float8_e5m2 elif dtype == DType.FP16: fntype = np.float16 else: @@ -943,35 +944,20 @@ class TosaSerializer: np_arr = np.array(data, dtype=np.float16) u8_data.extend(np_arr.view(np.uint8)) elif dtype == DType.FP32: - # for val in data: - # b = struct.pack("!f", val) - # u8_data.extend([b[3], b[2], b[1], b[0]]) np_arr = np.array(data, dtype=np.float32) u8_data.extend(np_arr.view(np.uint8)) elif dtype == DType.BF16: for val in data: - # convert val to little endian byte arrays b - b = struct.pack(" [ b[3], b[2], b[1], b[0] ] - # keep only most significant 2 bytes for bf16 - # in little endian ordering - u8_data.extend([b[2], b[3]]) + np_arr = np.array(data, dtype=bfloat16) + u8_data.extend(np_arr.view(np.uint8)) elif dtype == DType.FP8E4M3: for val in data: - # convert val to fp8_bits then to single byte - f32_as_int = struct.unpack(">L", struct.pack(">f", val))[0] - f32_bits = f"{f32_as_int:032b}" - fp8_bits = f32_bits[0] + f32_bits[1:5] + f32_bits[9:12] - fp8_bytes = int(fp8_bits, 2).to_bytes(1, byteorder="little") - u8_data.extend(fp8_bytes) + val_f8 = np.array(val).astype(float8_e4m3fn).view(np.uint8) + u8_data.append(val_f8) elif dtype == DType.FP8E5M2: for val in data: - # convert val to fp8_bits then to single byte - f32_as_int = struct.unpack(">L", struct.pack(">f", val))[0] - f32_bits = f"{f32_as_int:032b}" - fp8_bits = f32_bits[0] + f32_bits[1:6] + f32_bits[9:11] - fp8_bytes = int(fp8_bits, 2).to_bytes(1, byteorder="little") - u8_data.extend(fp8_bytes) + val_f8 = np.array(val).astype(float8_e5m2).view(np.uint8) + u8_data.append(val_f8) elif dtype == TosaDType.DType: # Serialize DType enum data as uint8 bytes for val in data: diff --git a/src/numpy_utils.cpp b/src/numpy_utils.cpp index e4171d7..7cf5f94 100644 --- a/src/numpy_utils.cpp +++ b/src/numpy_utils.cpp @@ -247,6 +247,14 @@ NumpyUtilities::NPError NumpyUtilities::checkNpyHeader(FILE* infile, const uint3 while (isspace(*ptr)) ptr++; + // ml_dtypes writes ' using namespace tosa; -using fp8e4m3 = ct::cfloat; -using fp8e5m2 = ct::cfloat; - TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name, const flatbuffers::Vector* shape, DType dtype, @@ -750,45 +747,41 @@ void TosaSerializationHandler::ForceAlignTensorData(std::vector& buf) } } -tosa_err_t TosaSerializationHandler::ConvertBF16toU8(const std::vector& in, std::vector& out) +tosa_err_t TosaSerializationHandler::ConvertBF16toU8(const std::vector& in, std::vector& out) { // Note: Converts fp32->bf16 by ignoring the least significant 16 bits out.clear(); for (auto val : in) { - uint32_t* val_u32 = reinterpret_cast(&val); - uint8_t f32_byte2 = (*val_u32 >> 16) & 0xFF; - uint8_t f32_byte3 = (*val_u32 >> 24) & 0xFF; - // little endian: byte2 followed by byte3 - out.push_back(f32_byte2); - out.push_back(f32_byte3); + uint8_t bf16_byte0 = val.bits() & 0xFF; + uint8_t bf16_byte1 = (val.bits() >> 8) & 0xFF; + out.push_back(bf16_byte0); + out.push_back(bf16_byte1); } ForceAlignTensorData(out); return TOSA_OK; } -tosa_err_t TosaSerializationHandler::ConvertFP8E4M3toU8(const std::vector& in, std::vector& out) +tosa_err_t TosaSerializationHandler::ConvertFP8E4M3toU8(const std::vector& in, std::vector& out) { // Note: Converts fp32->FP8E4M3 before converting to unint8_t out.clear(); for (auto val : in) { - auto f8 = static_cast(val); - uint8_t b8 = f8.bits(); + uint8_t b8 = val.bits(); out.push_back(b8); } ForceAlignTensorData(out); return TOSA_OK; } -tosa_err_t TosaSerializationHandler::ConvertFP8E5M2toU8(const std::vector& in, std::vector& out) +tosa_err_t TosaSerializationHandler::ConvertFP8E5M2toU8(const std::vector& in, std::vector& out) { // Note: Converts fp32->FP8E5M2 before converting to uint8_t out.clear(); for (auto val : in) { - auto f8 = static_cast(val); - uint8_t b8 = f8.bits(); + uint8_t b8 = val.bits(); out.push_back(b8); } ForceAlignTensorData(out); @@ -944,9 +937,8 @@ tosa_err_t TosaSerializationHandler::ConvertBooltoU8(const std::vector& in return TOSA_OK; } -tosa_err_t TosaSerializationHandler::ConvertU8toBF16(const std::vector& in, - uint32_t out_size, - std::vector& out) +tosa_err_t + TosaSerializationHandler::ConvertU8toBF16(const std::vector& in, uint32_t out_size, std::vector& out) { // Note: bf16 values returned in fp32 type out.clear(); @@ -964,17 +956,17 @@ tosa_err_t TosaSerializationHandler::ConvertU8toBF16(const std::vector& uint32_t val_u32 = (f32_byte2 << 16) + (f32_byte3 << 24); // Reinterpret u32 bytes as fp32 - float val_f32 = *(float*)&val_u32; - out.push_back(val_f32); + float val_f32 = *(float*)&val_u32; + float val_bf16 = static_cast(val_f32); + out.push_back(val_bf16); } return TOSA_OK; } tosa_err_t TosaSerializationHandler::ConvertU8toFP8E4M3(const std::vector& in, uint32_t out_size, - std::vector& out) + std::vector& out) { - // Note: FP8E4M3 values returned in fp32 type out.clear(); if (in.size() < out_size * sizeof(int8_t)) { @@ -985,17 +977,16 @@ tosa_err_t TosaSerializationHandler::ConvertU8toFP8E4M3(const std::vector(in[i * sizeof(int8_t)]); - auto f8 = fp8e4m3::from_bits(bits); - float val_f32 = static_cast(f8); - out.push_back(val_f32); + int8_t bits = static_cast(in[i * sizeof(int8_t)]); + auto f8 = fp8e4m3::from_bits(bits); + out.push_back(f8); } return TOSA_OK; } tosa_err_t TosaSerializationHandler::ConvertU8toFP8E5M2(const std::vector& in, uint32_t out_size, - std::vector& out) + std::vector& out) { // Note: FP8E5M2 values returned in fp32 type out.clear(); @@ -1008,10 +999,9 @@ tosa_err_t TosaSerializationHandler::ConvertU8toFP8E5M2(const std::vector(in[i * sizeof(int8_t)]); - auto f8 = fp8e5m2::from_bits(bits); - float val_f32 = static_cast(f8); - out.push_back(val_f32); + int8_t bits = static_cast(in[i * sizeof(int8_t)]); + auto f8 = fp8e5m2::from_bits(bits); + out.push_back(f8); } return TOSA_OK; } -- cgit v1.2.1