diff options
Diffstat (limited to 'python/serializer')
-rw-r--r-- | python/serializer/tosa_serializer.py | 58 |
1 files changed, 19 insertions, 39 deletions
diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py index 298907e..34178c5 100644 --- a/python/serializer/tosa_serializer.py +++ b/python/serializer/tosa_serializer.py @@ -17,7 +17,7 @@ import serializer.tosa_serializer as ts import json import flatbuffers import numpy as np -import struct +from ml_dtypes import bfloat16, float8_e4m3fn, float8_e5m2 from enum import IntEnum, unique from tosa import ( TosaGraph, @@ -31,8 +31,8 @@ import tosa.DType as TosaDType import tosa.Op as TosaOp # Keep version number in sync with the version default value with schema/tosa.fbs -TOSA_VERSION_MAJOR = 0 -TOSA_VERSION_MINOR = 100 +TOSA_VERSION_MAJOR = 1 +TOSA_VERSION_MINOR = 1 TOSA_VERSION_PATCH = 0 TOSA_VERSION_DRAFT = True TOSA_VERSION = [ @@ -190,7 +190,7 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.ints.append((a.AddAccType, acc_type)) def TransposeConvAttribute( - self, outpad, stride, output_shape, input_zp, weight_zp, local_bound, acc_type + self, outpad, stride, input_zp, weight_zp, local_bound, acc_type ): from tosa import TransposeConvAttribute as a, Attribute @@ -199,13 +199,12 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.intvecs.append((a.AddOutPad, outpad)) self.intvecs.append((a.AddStride, stride)) - self.intvecs.append((a.AddOutputShape, output_shape)) self.ints.append((a.AddInputZp, input_zp)) self.ints.append((a.AddWeightZp, weight_zp)) self.bools.append((a.AddLocalBound, local_bound)) self.ints.append((a.AddAccType, acc_type)) - def PadAttribute(self, serializer_builder, pad_const_val_as_bytes, dtype): + def PadAttribute(self, serializer_builder, pad_const_val_as_bytes): from tosa import PadAttribute as a, Attribute self.utype = Attribute.Attribute().PadAttribute @@ -217,7 +216,6 @@ class TosaSerializerAttribute(TosaSerializerUnion): ) self.floats.append((a.AddPadConst, serialized_pad_const_val)) - self.ints.append((a.AddType, dtype)) def AxisAttribute(self, axis): from tosa import AxisAttribute as a, Attribute @@ -238,9 +236,7 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.int16vecs.append((a.AddBorder, border)) self.ints.append((a.AddMode, mode)) - def ClampAttribute( - self, serializer_builder, min_val_as_bytes, max_val_as_bytes, dtype - ): + def ClampAttribute(self, serializer_builder, min_val_as_bytes, max_val_as_bytes): from tosa import ClampAttribute as a, Attribute self.utype = Attribute.Attribute().ClampAttribute @@ -256,7 +252,6 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.floats.append((a.AddMinVal, serialized_min_val)) self.floats.append((a.AddMaxVal, serialized_max_val)) - self.ints.append((a.AddType, dtype)) def RescaleAttribute( self, @@ -397,13 +392,14 @@ class TosaSerializerTensor: self.shape = shape self.dtype = dtype - if ( - dtype == DType.FP32 - or dtype == DType.BF16 - or dtype == DType.FP8E4M3 - or dtype == DType.FP8E5M2 - ): + if dtype == DType.FP32: fntype = np.float32 + elif dtype == DType.BF16: + fntype = bfloat16 + elif dtype == DType.FP8E4M3: + fntype = float8_e4m3fn + elif dtype == DType.FP8E5M2: + fntype = float8_e5m2 elif dtype == DType.FP16: fntype = np.float16 else: @@ -948,35 +944,19 @@ class TosaSerializer: np_arr = np.array(data, dtype=np.float16) u8_data.extend(np_arr.view(np.uint8)) elif dtype == DType.FP32: - # for val in data: - # b = struct.pack("!f", val) - # u8_data.extend([b[3], b[2], b[1], b[0]]) np_arr = np.array(data, dtype=np.float32) u8_data.extend(np_arr.view(np.uint8)) elif dtype == DType.BF16: - for val in data: - # convert val to little endian byte arrays b - b = struct.pack("<f", val) - # val => [ b[3], b[2], b[1], b[0] ] - # keep only most significant 2 bytes for bf16 - # in little endian ordering - u8_data.extend([b[2], b[3]]) + np_arr = np.array(data, dtype=bfloat16) + u8_data.extend(np_arr.view(np.uint8)) elif dtype == DType.FP8E4M3: for val in data: - # convert val to fp8_bits then to single byte - f32_as_int = struct.unpack(">L", struct.pack(">f", val))[0] - f32_bits = f"{f32_as_int:032b}" - fp8_bits = f32_bits[0] + f32_bits[1:5] + f32_bits[9:12] - fp8_bytes = int(fp8_bits, 2).to_bytes(1, byteorder="little") - u8_data.extend(fp8_bytes) + val_f8 = np.array(val).astype(float8_e4m3fn).view(np.uint8) + u8_data.append(val_f8) elif dtype == DType.FP8E5M2: for val in data: - # convert val to fp8_bits then to single byte - f32_as_int = struct.unpack(">L", struct.pack(">f", val))[0] - f32_bits = f"{f32_as_int:032b}" - fp8_bits = f32_bits[0] + f32_bits[1:6] + f32_bits[9:11] - fp8_bytes = int(fp8_bits, 2).to_bytes(1, byteorder="little") - u8_data.extend(fp8_bytes) + val_f8 = np.array(val).astype(float8_e5m2).view(np.uint8) + u8_data.append(val_f8) elif dtype == TosaDType.DType: # Serialize DType enum data as uint8 bytes for val in data: |