From ce911a2f1d9cd678fb9fe82a40c86ad0c6772f5a Mon Sep 17 00:00:00 2001
From: Tai Ly
Date: Thu, 21 Mar 2024 17:01:14 +0000
Subject: Add conversions of U8 to/from BF16 and FP8

Adds type to PadAttribute and ClampAttribute so their pad_const and
max_val/min_val can be deserialized according to type.

Adds conversion functions of U8 arrays to/from BF16/FP8 values.

Also refactors and exposes TosaSerializer.convertDataToUint8Vec for
converting dtype/data to a uint8 list for serialization, and modifies
convertDataToUint8Vec to serialize bf16 values into 2 bytes each and
fp8 values into single bytes each.

Signed-off-by: Tai Ly
Change-Id: I05659e8187c76d359f1cc9f71c8c23cafd0e877f
---
 python/serializer/tosa_serializer.py | 193 ++++++++++++++++++++---------------
 1 file changed, 110 insertions(+), 83 deletions(-)

diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py
index e6ab3d0..298907e 100644
--- a/python/serializer/tosa_serializer.py
+++ b/python/serializer/tosa_serializer.py
@@ -17,6 +17,7 @@ import serializer.tosa_serializer as ts
 import json
 import flatbuffers
 import numpy as np
+import struct
 from enum import IntEnum, unique
 from tosa import (
     TosaGraph,
@@ -204,7 +205,7 @@ class TosaSerializerAttribute(TosaSerializerUnion):
         self.bools.append((a.AddLocalBound, local_bound))
         self.ints.append((a.AddAccType, acc_type))
 
-    def PadAttribute(self, serializer_builder, pad_const_val_as_bytes):
+    def PadAttribute(self, serializer_builder, pad_const_val_as_bytes, dtype):
         from tosa import PadAttribute as a, Attribute
 
         self.utype = Attribute.Attribute().PadAttribute
@@ -216,6 +217,7 @@ class TosaSerializerAttribute(TosaSerializerUnion):
         )
 
         self.floats.append((a.AddPadConst, serialized_pad_const_val))
+        self.ints.append((a.AddType, dtype))
 
     def AxisAttribute(self, axis):
         from tosa import AxisAttribute as a, Attribute
@@ -236,7 +238,9 @@ class TosaSerializerAttribute(TosaSerializerUnion):
         self.int16vecs.append((a.AddBorder, border))
         self.ints.append((a.AddMode, mode))
 
-    def ClampAttribute(self, serializer_builder, min_val_as_bytes, max_val_as_bytes):
+    def ClampAttribute(
+        self, serializer_builder, min_val_as_bytes, max_val_as_bytes, dtype
+    ):
         from tosa import ClampAttribute as a, Attribute
 
         self.utype = Attribute.Attribute().ClampAttribute
@@ -252,6 +256,7 @@ class TosaSerializerAttribute(TosaSerializerUnion):
 
         self.floats.append((a.AddMinVal, serialized_min_val))
         self.floats.append((a.AddMaxVal, serialized_max_val))
+        self.ints.append((a.AddType, dtype))
 
     def RescaleAttribute(
         self,
@@ -439,87 +444,7 @@ class TosaSerializerTensor:
         fb_name = builder.CreateString(self.name)
         fb_shapes = TosaSerializer.serializeInt32Vec(builder, self.shape)
         if self.data:
-            u8_data = list()
-            # little endianess
-            if self.dtype == DType.BOOL:
-                for val in self.data:
-                    val_u8 = np.uint8(val)
-                    u8_data.append(val_u8)
-            elif self.dtype == DType.INT4:
-                in_size = len(self.data)
-                out_size = (in_size + 1) // 2
-                for i in range(out_size):
-                    val_0 = self.data[2 * i]
-                    if (2 * i + 1) < in_size:
-                        val_1 = self.data[2 * i + 1]
-                    else:
-                        val_1 = 0
-                    val_i8 = (val_0 & 0xF) | ((val_1 & 0xF) << 4)
-                    val_u8 = np.uint8(val_i8)
-                    u8_data.append(val_u8)
-            elif self.dtype == DType.INT8:
-                for val in self.data:
-                    val_u8 = np.array(val).astype(dtype=np.uint8)
-                    u8_data.append(val_u8)
-            elif self.dtype == DType.INT16:
-                for val in self.data:
-                    val_u16 = np.array(val).astype(dtype=np.uint16)
-                    b0 = val_u16 & ByteMask
-                    b1 = (val_u16 >> np.uint16(8)) & ByteMask
-                    u8_data.extend([b0, b1])
-            elif self.dtype == DType.INT32:
-                for val in self.data:
-                    val_u32 = np.array(val).astype(dtype=np.uint32)
-                    b0 = val_u32 & ByteMask
-                    b1 = (val_u32 >> np.uint32(8)) & ByteMask
-                    b2 = (val_u32 >> np.uint32(16)) & ByteMask
-                    b3 = (val_u32 >> np.uint32(24)) & ByteMask
-                    u8_data.extend([b0, b1, b2, b3])
-            elif self.dtype == DType.INT48:
-                for val in self.data:
-                    val_u64 = np.uint64(val)
-                    b0 = val_u64 & ByteMask
-                    b1 = (val_u64 >> np.uint64(8)) & ByteMask
-                    b2 = (val_u64 >> np.uint64(16)) & ByteMask
-                    b3 = (val_u64 >> np.uint64(24)) & ByteMask
-                    b4 = (val_u64 >> np.uint64(32)) & ByteMask
-                    b5 = (val_u64 >> np.uint64(40)) & ByteMask
-                    u8_data.extend([b0, b1, b2, b3, b4, b5])
-            elif self.dtype == DType.SHAPE:
-                for val in self.data:
-                    val_u64 = np.uint64(val)
-                    b0 = val_u64 & ByteMask
-                    b1 = (val_u64 >> np.uint64(8)) & ByteMask
-                    b2 = (val_u64 >> np.uint64(16)) & ByteMask
-                    b3 = (val_u64 >> np.uint64(24)) & ByteMask
-                    b4 = (val_u64 >> np.uint64(32)) & ByteMask
-                    b5 = (val_u64 >> np.uint64(40)) & ByteMask
-                    b6 = (val_u64 >> np.uint64(48)) & ByteMask
-                    b7 = (val_u64 >> np.uint64(56)) & ByteMask
-                    u8_data.extend([b0, b1, b2, b3, b4, b5, b6, b7])
-            elif self.dtype == DType.FP16:
-                np_arr = np.array(self.data, dtype=np.float16)
-                u8_data.extend(np_arr.view(np.uint8))
-            elif (
-                self.dtype == DType.FP32
-                or self.dtype == DType.BF16
-                or self.dtype == DType.FP8E4M3
-                or self.dtype == DType.FP8E5M2
-            ):
-                # for val in self.data:
-                #     b = struct.pack("!f", val)
-                #     u8_data.extend([b[3], b[2], b[1], b[0]])
-                np_arr = np.array(self.data, dtype=np.float32)
-                u8_data.extend(np_arr.view(np.uint8))
-            elif self.dtype == TosaDType.DType:
-                # Serialize DType enum data as uint8 bytes
-                for val in self.data:
-                    np_arr = np.array(self.data, dtype=np.uint32)
-                    u8_data.extend(np_arr.view(np.uint8))
-            else:
-                raise Exception(
-                    "unsupported data type {}".format(DTypeNames[self.dtype])
-                )
+            u8_data = TosaSerializer.convertDataToUint8Vec(self.dtype, self.data)
             fb_data = TosaSerializer.serializeUint8Vec(builder, u8_data)
 
         TosaTensor.Start(builder)
@@ -958,3 +883,105 @@ class TosaSerializer:
             return val
         else:
             return [val]
+
+    @staticmethod
+    def convertDataToUint8Vec(dtype, data):
+        u8_data = list()
+        # little endianess
+        if dtype == DType.BOOL:
+            for val in data:
+                val_u8 = np.uint8(val)
+                u8_data.append(val_u8)
+        elif dtype == DType.INT4:
+            in_size = len(data)
+            out_size = (in_size + 1) // 2
+            for i in range(out_size):
+                val_0 = data[2 * i]
+                if (2 * i + 1) < in_size:
+                    val_1 = data[2 * i + 1]
+                else:
+                    val_1 = 0
+                val_i8 = (val_0 & 0xF) | ((val_1 & 0xF) << 4)
+                val_u8 = np.uint8(val_i8)
+                u8_data.append(val_u8)
+        elif dtype == DType.INT8:
+            for val in data:
+                val_u8 = np.array(val).astype(dtype=np.uint8)
+                u8_data.append(val_u8)
+        elif dtype == DType.INT16:
+            for val in data:
+                val_u16 = np.array(val).astype(dtype=np.uint16)
+                b0 = val_u16 & ByteMask
+                b1 = (val_u16 >> np.uint16(8)) & ByteMask
+                u8_data.extend([b0, b1])
+        elif dtype == DType.INT32:
+            for val in data:
+                val_u32 = np.array(val).astype(dtype=np.uint32)
+                b0 = val_u32 & ByteMask
+                b1 = (val_u32 >> np.uint32(8)) & ByteMask
+                b2 = (val_u32 >> np.uint32(16)) & ByteMask
+                b3 = (val_u32 >> np.uint32(24)) & ByteMask
+                u8_data.extend([b0, b1, b2, b3])
+        elif dtype == DType.INT48:
+            for val in data:
+                val_u64 = np.uint64(val)
+                b0 = val_u64 & ByteMask
+                b1 = (val_u64 >> np.uint64(8)) & ByteMask
+                b2 = (val_u64 >> np.uint64(16)) & ByteMask
+                b3 = (val_u64 >> np.uint64(24)) & ByteMask
+                b4 = (val_u64 >> np.uint64(32)) & ByteMask
+                b5 = (val_u64 >> np.uint64(40)) & ByteMask
+                u8_data.extend([b0, b1, b2, b3, b4, b5])
+        elif dtype == DType.SHAPE:
+            for val in data:
+                val_u64 = np.uint64(val)
+                b0 = val_u64 & ByteMask
+                b1 = (val_u64 >> np.uint64(8)) & ByteMask
+                b2 = (val_u64 >> np.uint64(16)) & ByteMask
+                b3 = (val_u64 >> np.uint64(24)) & ByteMask
+                b4 = (val_u64 >> np.uint64(32)) & ByteMask
+                b5 = (val_u64 >> np.uint64(40)) & ByteMask
+                b6 = (val_u64 >> np.uint64(48)) & ByteMask
+                b7 = (val_u64 >> np.uint64(56)) & ByteMask
+                u8_data.extend([b0, b1, b2, b3, b4, b5, b6, b7])
+        elif dtype == DType.FP16:
+            np_arr = np.array(data, dtype=np.float16)
+            u8_data.extend(np_arr.view(np.uint8))
+        elif dtype == DType.FP32:
+            # for val in data:
+            #     b = struct.pack("!f", val)
+            #     u8_data.extend([b[3], b[2], b[1], b[0]])
+            np_arr = np.array(data, dtype=np.float32)
+            u8_data.extend(np_arr.view(np.uint8))
+        elif dtype == DType.BF16:
+            for val in data:
+                # convert val to little endian byte arrays b
+                b = struct.pack("<f", val)
+                # big endian order would be [ b[3], b[2], b[1], b[0] ]
+                # keep only most significant 2 bytes for bf16
+                # in little endian ordering
+                u8_data.extend([b[2], b[3]])
+        elif dtype == DType.FP8E4M3:
+            for val in data:
+                # convert val to fp8_bits then to single byte
+                f32_as_int = struct.unpack(">L", struct.pack(">f", val))[0]
+                f32_bits = f"{f32_as_int:032b}"
+                fp8_bits = f32_bits[0] + f32_bits[1:5] + f32_bits[9:12]
+                fp8_bytes = int(fp8_bits, 2).to_bytes(1, byteorder="little")
+                u8_data.extend(fp8_bytes)
+        elif dtype == DType.FP8E5M2:
+            for val in data:
+                # convert val to fp8_bits then to single byte
+                f32_as_int = struct.unpack(">L", struct.pack(">f", val))[0]
+                f32_bits = f"{f32_as_int:032b}"
+                fp8_bits = f32_bits[0] + f32_bits[1:6] + f32_bits[9:11]
+                fp8_bytes = int(fp8_bits, 2).to_bytes(1, byteorder="little")
+                u8_data.extend(fp8_bytes)
+        elif dtype == TosaDType.DType:
+            # Serialize DType enum data as uint8 bytes
+            for val in data:
+                np_arr = np.array(data, dtype=np.uint32)
+                u8_data.extend(np_arr.view(np.uint8))
+        else:
+            raise Exception("unsupported data type {}".format(DTypeNames[dtype]))
+        return u8_data
-- 
cgit v1.2.1
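
As an illustration of the byte packing the new convertDataToUint8Vec performs,
here is a minimal standalone sketch, not part of the patch; the helper names
bf16_bytes and fp8e4m3_byte are invented for this example and are not part of
the serializer API:

import struct

def bf16_bytes(val):
    # pack val as a little-endian float32; b[2] and b[3] are the two most
    # significant bytes (sign, exponent, top 7 mantissa bits), which is the
    # bf16 bit pattern kept in little-endian order, as in the patch
    b = struct.pack("<f", val)
    return [b[2], b[3]]

def fp8e4m3_byte(val):
    # mirror the bit slicing used in convertDataToUint8Vec: keep the sign bit,
    # the top 4 exponent bits and the top 3 mantissa bits of the float32
    # bit pattern, packed into a single byte
    f32_as_int = struct.unpack(">L", struct.pack(">f", val))[0]
    f32_bits = f"{f32_as_int:032b}"
    fp8_bits = f32_bits[0] + f32_bits[1:5] + f32_bits[9:12]
    return int(fp8_bits, 2).to_bytes(1, byteorder="little")

print(bf16_bytes(1.5))    # [192, 63] -> bytes 0xC0 0x3F, i.e. bf16 0x3FC0 == 1.5
print(fp8e4m3_byte(1.5))  # b'<' -> 0x3C: sign 0, exponent 0111, mantissa 100

bf16 reuses the upper half of the float32 bit pattern, so truncating to bytes
b[2] and b[3] of the little-endian encoding yields two bytes per value, while
the FP8 paths slice sign, exponent and mantissa bits directly out of the
float32 bit string to produce a single byte per value.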