From a814152b68a286f5bb9ddc095bb1897ec0e3d8ff Mon Sep 17 00:00:00 2001 From: Won Jeon Date: Mon, 29 Apr 2024 23:57:27 +0000 Subject: Use native size of Bfloat16 and Float8 for serialization/deserialization Signed-off-by: Won Jeon Change-Id: I0d2075f90988d4fd1139a11b5c154bdd600bb2cd --- python/serializer/tosa_serializer.py | 42 ++++++++++++------------------------ 1 file changed, 14 insertions(+), 28 deletions(-) (limited to 'python') diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py index c328aaf..7122216 100644 --- a/python/serializer/tosa_serializer.py +++ b/python/serializer/tosa_serializer.py @@ -17,7 +17,7 @@ import serializer.tosa_serializer as ts import json import flatbuffers import numpy as np -import struct +from ml_dtypes import bfloat16, float8_e4m3fn, float8_e5m2 from enum import IntEnum, unique from tosa import ( TosaGraph, @@ -392,13 +392,14 @@ class TosaSerializerTensor: self.shape = shape self.dtype = dtype - if ( - dtype == DType.FP32 - or dtype == DType.BF16 - or dtype == DType.FP8E4M3 - or dtype == DType.FP8E5M2 - ): + if dtype == DType.FP32: fntype = np.float32 + elif dtype == DType.BF16: + fntype = bfloat16 + elif dtype == DType.FP8E4M3: + fntype = float8_e4m3fn + elif dtype == DType.FP8E5M2: + fntype = float8_e5m2 elif dtype == DType.FP16: fntype = np.float16 else: @@ -943,35 +944,20 @@ class TosaSerializer: np_arr = np.array(data, dtype=np.float16) u8_data.extend(np_arr.view(np.uint8)) elif dtype == DType.FP32: - # for val in data: - # b = struct.pack("!f", val) - # u8_data.extend([b[3], b[2], b[1], b[0]]) np_arr = np.array(data, dtype=np.float32) u8_data.extend(np_arr.view(np.uint8)) elif dtype == DType.BF16: for val in data: - # convert val to little endian byte arrays b - b = struct.pack(" [ b[3], b[2], b[1], b[0] ] - # keep only most significant 2 bytes for bf16 - # in little endian ordering - u8_data.extend([b[2], b[3]]) + np_arr = np.array(data, dtype=bfloat16) + u8_data.extend(np_arr.view(np.uint8)) elif dtype == DType.FP8E4M3: for val in data: - # convert val to fp8_bits then to single byte - f32_as_int = struct.unpack(">L", struct.pack(">f", val))[0] - f32_bits = f"{f32_as_int:032b}" - fp8_bits = f32_bits[0] + f32_bits[1:5] + f32_bits[9:12] - fp8_bytes = int(fp8_bits, 2).to_bytes(1, byteorder="little") - u8_data.extend(fp8_bytes) + val_f8 = np.array(val).astype(float8_e4m3fn).view(np.uint8) + u8_data.append(val_f8) elif dtype == DType.FP8E5M2: for val in data: - # convert val to fp8_bits then to single byte - f32_as_int = struct.unpack(">L", struct.pack(">f", val))[0] - f32_bits = f"{f32_as_int:032b}" - fp8_bits = f32_bits[0] + f32_bits[1:6] + f32_bits[9:11] - fp8_bytes = int(fp8_bits, 2).to_bytes(1, byteorder="little") - u8_data.extend(fp8_bytes) + val_f8 = np.array(val).astype(float8_e5m2).view(np.uint8) + u8_data.append(val_f8) elif dtype == TosaDType.DType: # Serialize DType enum data as uint8 bytes for val in data: -- cgit v1.2.1