aboutsummaryrefslogtreecommitdiff
path: root/python/serializer/tosa_serializer.py
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2024-03-21 17:01:14 +0000
committerTai Ly <tai.ly@arm.com>2024-04-08 22:18:34 +0000
commitce911a2f1d9cd678fb9fe82a40c86ad0c6772f5a (patch)
tree68faf6d7b1c676705a022b32d8aa7950db03ab5e /python/serializer/tosa_serializer.py
parent8f9e2842ce7d25645233ad4f6fa406be982346ae (diff)
downloadserialization_lib-ce911a2f1d9cd678fb9fe82a40c86ad0c6772f5a.tar.gz
Add conversions of U8 to/from BF16 and FP8
Adds type to PadAttribute and ClampAttribute so that their pad_const and max_val/min_val can be deserialized according to type. Adds conversion functions of U8 arrays to/from BF16/FP8 values. Also refactors and exposes TosaSerializer.convertDataToUint8Vec for converting dtype/data to a uint8 list for serialization, and modifies convertDataToUint8Vec to serialize bf16 values into 2 bytes each and fp8 values into a single byte each. Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: I05659e8187c76d359f1cc9f71c8c23cafd0e877f
Diffstat (limited to 'python/serializer/tosa_serializer.py')
-rw-r--r--python/serializer/tosa_serializer.py193
1 files changed, 110 insertions, 83 deletions
diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py
index e6ab3d0..298907e 100644
--- a/python/serializer/tosa_serializer.py
+++ b/python/serializer/tosa_serializer.py
@@ -17,6 +17,7 @@ import serializer.tosa_serializer as ts
import json
import flatbuffers
import numpy as np
+import struct
from enum import IntEnum, unique
from tosa import (
TosaGraph,
@@ -204,7 +205,7 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.bools.append((a.AddLocalBound, local_bound))
self.ints.append((a.AddAccType, acc_type))
- def PadAttribute(self, serializer_builder, pad_const_val_as_bytes):
+ def PadAttribute(self, serializer_builder, pad_const_val_as_bytes, dtype):
from tosa import PadAttribute as a, Attribute
self.utype = Attribute.Attribute().PadAttribute
@@ -216,6 +217,7 @@ class TosaSerializerAttribute(TosaSerializerUnion):
)
self.floats.append((a.AddPadConst, serialized_pad_const_val))
+ self.ints.append((a.AddType, dtype))
def AxisAttribute(self, axis):
from tosa import AxisAttribute as a, Attribute
@@ -236,7 +238,9 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.int16vecs.append((a.AddBorder, border))
self.ints.append((a.AddMode, mode))
- def ClampAttribute(self, serializer_builder, min_val_as_bytes, max_val_as_bytes):
+ def ClampAttribute(
+ self, serializer_builder, min_val_as_bytes, max_val_as_bytes, dtype
+ ):
from tosa import ClampAttribute as a, Attribute
self.utype = Attribute.Attribute().ClampAttribute
@@ -252,6 +256,7 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.floats.append((a.AddMinVal, serialized_min_val))
self.floats.append((a.AddMaxVal, serialized_max_val))
+ self.ints.append((a.AddType, dtype))
def RescaleAttribute(
self,
@@ -439,87 +444,7 @@ class TosaSerializerTensor:
fb_name = builder.CreateString(self.name)
fb_shapes = TosaSerializer.serializeInt32Vec(builder, self.shape)
if self.data:
- u8_data = list()
- # little endianess
- if self.dtype == DType.BOOL:
- for val in self.data:
- val_u8 = np.uint8(val)
- u8_data.append(val_u8)
- elif self.dtype == DType.INT4:
- in_size = len(self.data)
- out_size = (in_size + 1) // 2
- for i in range(out_size):
- val_0 = self.data[2 * i]
- if (2 * i + 1) < in_size:
- val_1 = self.data[2 * i + 1]
- else:
- val_1 = 0
- val_i8 = (val_0 & 0xF) | ((val_1 & 0xF) << 4)
- val_u8 = np.uint8(val_i8)
- u8_data.append(val_u8)
- elif self.dtype == DType.INT8:
- for val in self.data:
- val_u8 = np.array(val).astype(dtype=np.uint8)
- u8_data.append(val_u8)
- elif self.dtype == DType.INT16:
- for val in self.data:
- val_u16 = np.array(val).astype(dtype=np.uint16)
- b0 = val_u16 & ByteMask
- b1 = (val_u16 >> np.uint16(8)) & ByteMask
- u8_data.extend([b0, b1])
- elif self.dtype == DType.INT32:
- for val in self.data:
- val_u32 = np.array(val).astype(dtype=np.uint32)
- b0 = val_u32 & ByteMask
- b1 = (val_u32 >> np.uint32(8)) & ByteMask
- b2 = (val_u32 >> np.uint32(16)) & ByteMask
- b3 = (val_u32 >> np.uint32(24)) & ByteMask
- u8_data.extend([b0, b1, b2, b3])
- elif self.dtype == DType.INT48:
- for val in self.data:
- val_u64 = np.uint64(val)
- b0 = val_u64 & ByteMask
- b1 = (val_u64 >> np.uint64(8)) & ByteMask
- b2 = (val_u64 >> np.uint64(16)) & ByteMask
- b3 = (val_u64 >> np.uint64(24)) & ByteMask
- b4 = (val_u64 >> np.uint64(32)) & ByteMask
- b5 = (val_u64 >> np.uint64(40)) & ByteMask
- u8_data.extend([b0, b1, b2, b3, b4, b5])
- elif self.dtype == DType.SHAPE:
- for val in self.data:
- val_u64 = np.uint64(val)
- b0 = val_u64 & ByteMask
- b1 = (val_u64 >> np.uint64(8)) & ByteMask
- b2 = (val_u64 >> np.uint64(16)) & ByteMask
- b3 = (val_u64 >> np.uint64(24)) & ByteMask
- b4 = (val_u64 >> np.uint64(32)) & ByteMask
- b5 = (val_u64 >> np.uint64(40)) & ByteMask
- b6 = (val_u64 >> np.uint64(48)) & ByteMask
- b7 = (val_u64 >> np.uint64(56)) & ByteMask
- u8_data.extend([b0, b1, b2, b3, b4, b5, b6, b7])
- elif self.dtype == DType.FP16:
- np_arr = np.array(self.data, dtype=np.float16)
- u8_data.extend(np_arr.view(np.uint8))
- elif (
- self.dtype == DType.FP32
- or self.dtype == DType.BF16
- or self.dtype == DType.FP8E4M3
- or self.dtype == DType.FP8E5M2
- ):
- # for val in self.data:
- # b = struct.pack("!f", val)
- # u8_data.extend([b[3], b[2], b[1], b[0]])
- np_arr = np.array(self.data, dtype=np.float32)
- u8_data.extend(np_arr.view(np.uint8))
- elif self.dtype == TosaDType.DType:
- # Serialize DType enum data as uint8 bytes
- for val in self.data:
- np_arr = np.array(self.data, dtype=np.uint32)
- u8_data.extend(np_arr.view(np.uint8))
- else:
- raise Exception(
- "unsupported data type {}".format(DTypeNames[self.dtype])
- )
+ u8_data = TosaSerializer.convertDataToUint8Vec(self.dtype, self.data)
fb_data = TosaSerializer.serializeUint8Vec(builder, u8_data)
TosaTensor.Start(builder)
@@ -958,3 +883,105 @@ class TosaSerializer:
return val
else:
return [val]
+
+    @staticmethod
+    def convertDataToUint8Vec(dtype, data):
+        """Convert tensor/attribute values into a list of bytes for serialization.
+
+        Multi-byte values are emitted least significant byte first.
+
+        Bytes produced per input value:
+            BOOL, INT8, FP8E4M3, FP8E5M2: 1
+            INT4: two values packed per byte (first value in the low nibble)
+            INT16, FP16, BF16: 2
+            INT32, FP32, TosaDType.DType enum data: 4
+            INT48: 6
+            SHAPE: 8
+
+        Args:
+            dtype: DType value selecting the encoding (or TosaDType.DType
+                for DType enum data).
+            data: sequence of values to convert; presumably flat (1-D),
+                since the INT4 branch indexes it with len()/[] -- confirm
+                against callers.
+
+        Returns:
+            A list of byte-sized values (numpy scalars or ints).
+
+        Raises:
+            Exception: if dtype is not one of the supported data types.
+        """
+        u8_data = list()
+        # little endianness
+        if dtype == DType.BOOL:
+            for val in data:
+                val_u8 = np.uint8(val)
+                u8_data.append(val_u8)
+        elif dtype == DType.INT4:
+            # Pack two 4-bit values into each byte; an odd-length input is
+            # padded with a zero high nibble.
+            in_size = len(data)
+            out_size = (in_size + 1) // 2
+            for i in range(out_size):
+                val_0 = data[2 * i]
+                if (2 * i + 1) < in_size:
+                    val_1 = data[2 * i + 1]
+                else:
+                    val_1 = 0
+                val_i8 = (val_0 & 0xF) | ((val_1 & 0xF) << 4)
+                val_u8 = np.uint8(val_i8)
+                u8_data.append(val_u8)
+        elif dtype == DType.INT8:
+            for val in data:
+                val_u8 = np.array(val).astype(dtype=np.uint8)
+                u8_data.append(val_u8)
+        elif dtype == DType.INT16:
+            for val in data:
+                val_u16 = np.array(val).astype(dtype=np.uint16)
+                b0 = val_u16 & ByteMask
+                b1 = (val_u16 >> np.uint16(8)) & ByteMask
+                u8_data.extend([b0, b1])
+        elif dtype == DType.INT32:
+            for val in data:
+                val_u32 = np.array(val).astype(dtype=np.uint32)
+                b0 = val_u32 & ByteMask
+                b1 = (val_u32 >> np.uint32(8)) & ByteMask
+                b2 = (val_u32 >> np.uint32(16)) & ByteMask
+                b3 = (val_u32 >> np.uint32(24)) & ByteMask
+                u8_data.extend([b0, b1, b2, b3])
+        elif dtype == DType.INT48:
+            # Only the low 6 bytes of the 64-bit value are kept.
+            for val in data:
+                val_u64 = np.uint64(val)
+                b0 = val_u64 & ByteMask
+                b1 = (val_u64 >> np.uint64(8)) & ByteMask
+                b2 = (val_u64 >> np.uint64(16)) & ByteMask
+                b3 = (val_u64 >> np.uint64(24)) & ByteMask
+                b4 = (val_u64 >> np.uint64(32)) & ByteMask
+                b5 = (val_u64 >> np.uint64(40)) & ByteMask
+                u8_data.extend([b0, b1, b2, b3, b4, b5])
+        elif dtype == DType.SHAPE:
+            for val in data:
+                val_u64 = np.uint64(val)
+                b0 = val_u64 & ByteMask
+                b1 = (val_u64 >> np.uint64(8)) & ByteMask
+                b2 = (val_u64 >> np.uint64(16)) & ByteMask
+                b3 = (val_u64 >> np.uint64(24)) & ByteMask
+                b4 = (val_u64 >> np.uint64(32)) & ByteMask
+                b5 = (val_u64 >> np.uint64(40)) & ByteMask
+                b6 = (val_u64 >> np.uint64(48)) & ByteMask
+                b7 = (val_u64 >> np.uint64(56)) & ByteMask
+                u8_data.extend([b0, b1, b2, b3, b4, b5, b6, b7])
+        elif dtype == DType.FP16:
+            # Explicit "<" byte order so the bytes are little endian on any
+            # host, not just little-endian ones.
+            np_arr = np.array(data, dtype="<f2")
+            u8_data.extend(np_arr.view(np.uint8))
+        elif dtype == DType.FP32:
+            np_arr = np.array(data, dtype="<f4")
+            u8_data.extend(np_arr.view(np.uint8))
+        elif dtype == DType.BF16:
+            for val in data:
+                # b holds the float32 encoding of val, least significant
+                # byte first.
+                b = struct.pack("<f", val)
+                # bf16 is the most significant 2 bytes of the float32
+                # encoding; emit them in little endian order. The low 16
+                # bits are dropped (truncation, no rounding).
+                u8_data.extend([b[2], b[3]])
+        elif dtype == DType.FP8E4M3:
+            for val in data:
+                # Build the byte from the float32 bit string: sign bit, the
+                # 4 most significant exponent bits and the 3 most
+                # significant mantissa bits.
+                # NOTE(review): no exponent re-biasing or rounding is done
+                # here -- confirm this matches what the deserializer expects.
+                f32_as_int = struct.unpack(">L", struct.pack(">f", val))[0]
+                f32_bits = f"{f32_as_int:032b}"
+                fp8_bits = f32_bits[0] + f32_bits[1:5] + f32_bits[9:12]
+                fp8_bytes = int(fp8_bits, 2).to_bytes(1, byteorder="little")
+                u8_data.extend(fp8_bytes)
+        elif dtype == DType.FP8E5M2:
+            for val in data:
+                # Build the byte from the float32 bit string: sign bit, the
+                # 5 most significant exponent bits and the 2 most
+                # significant mantissa bits.
+                # NOTE(review): no exponent re-biasing or rounding is done
+                # here -- confirm this matches what the deserializer expects.
+                f32_as_int = struct.unpack(">L", struct.pack(">f", val))[0]
+                f32_bits = f"{f32_as_int:032b}"
+                fp8_bits = f32_bits[0] + f32_bits[1:6] + f32_bits[9:11]
+                fp8_bytes = int(fp8_bits, 2).to_bytes(1, byteorder="little")
+                u8_data.extend(fp8_bytes)
+        elif dtype == TosaDType.DType:
+            # Serialize DType enum data as 4 little endian bytes per value.
+            # The whole sequence is converted exactly once; converting it
+            # inside a per-element loop would append the full payload
+            # len(data) times.
+            np_arr = np.array(data, dtype="<u4")
+            u8_data.extend(np_arr.view(np.uint8))
+        else:
+            raise Exception("unsupported data type {}".format(DTypeNames[dtype]))
+        return u8_data