aboutsummaryrefslogtreecommitdiff
path: root/python/serializer
diff options
context:
space:
mode:
Diffstat (limited to 'python/serializer')
-rw-r--r--python/serializer/tosa_serializer.py58
1 files changed, 19 insertions, 39 deletions
diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py
index 298907e..34178c5 100644
--- a/python/serializer/tosa_serializer.py
+++ b/python/serializer/tosa_serializer.py
@@ -17,7 +17,7 @@ import serializer.tosa_serializer as ts
import json
import flatbuffers
import numpy as np
-import struct
+from ml_dtypes import bfloat16, float8_e4m3fn, float8_e5m2
from enum import IntEnum, unique
from tosa import (
TosaGraph,
@@ -31,8 +31,8 @@ import tosa.DType as TosaDType
import tosa.Op as TosaOp
# Keep version number in sync with the version default value with schema/tosa.fbs
-TOSA_VERSION_MAJOR = 0
-TOSA_VERSION_MINOR = 100
+TOSA_VERSION_MAJOR = 1
+TOSA_VERSION_MINOR = 1
TOSA_VERSION_PATCH = 0
TOSA_VERSION_DRAFT = True
TOSA_VERSION = [
@@ -190,7 +190,7 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.ints.append((a.AddAccType, acc_type))
def TransposeConvAttribute(
- self, outpad, stride, output_shape, input_zp, weight_zp, local_bound, acc_type
+ self, outpad, stride, input_zp, weight_zp, local_bound, acc_type
):
from tosa import TransposeConvAttribute as a, Attribute
@@ -199,13 +199,12 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.intvecs.append((a.AddOutPad, outpad))
self.intvecs.append((a.AddStride, stride))
- self.intvecs.append((a.AddOutputShape, output_shape))
self.ints.append((a.AddInputZp, input_zp))
self.ints.append((a.AddWeightZp, weight_zp))
self.bools.append((a.AddLocalBound, local_bound))
self.ints.append((a.AddAccType, acc_type))
- def PadAttribute(self, serializer_builder, pad_const_val_as_bytes, dtype):
+ def PadAttribute(self, serializer_builder, pad_const_val_as_bytes):
from tosa import PadAttribute as a, Attribute
self.utype = Attribute.Attribute().PadAttribute
@@ -217,7 +216,6 @@ class TosaSerializerAttribute(TosaSerializerUnion):
)
self.floats.append((a.AddPadConst, serialized_pad_const_val))
- self.ints.append((a.AddType, dtype))
def AxisAttribute(self, axis):
from tosa import AxisAttribute as a, Attribute
@@ -238,9 +236,7 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.int16vecs.append((a.AddBorder, border))
self.ints.append((a.AddMode, mode))
- def ClampAttribute(
- self, serializer_builder, min_val_as_bytes, max_val_as_bytes, dtype
- ):
+ def ClampAttribute(self, serializer_builder, min_val_as_bytes, max_val_as_bytes):
from tosa import ClampAttribute as a, Attribute
self.utype = Attribute.Attribute().ClampAttribute
@@ -256,7 +252,6 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.floats.append((a.AddMinVal, serialized_min_val))
self.floats.append((a.AddMaxVal, serialized_max_val))
- self.ints.append((a.AddType, dtype))
def RescaleAttribute(
self,
@@ -397,13 +392,14 @@ class TosaSerializerTensor:
self.shape = shape
self.dtype = dtype
- if (
- dtype == DType.FP32
- or dtype == DType.BF16
- or dtype == DType.FP8E4M3
- or dtype == DType.FP8E5M2
- ):
+ if dtype == DType.FP32:
fntype = np.float32
+ elif dtype == DType.BF16:
+ fntype = bfloat16
+ elif dtype == DType.FP8E4M3:
+ fntype = float8_e4m3fn
+ elif dtype == DType.FP8E5M2:
+ fntype = float8_e5m2
elif dtype == DType.FP16:
fntype = np.float16
else:
@@ -948,35 +944,19 @@ class TosaSerializer:
np_arr = np.array(data, dtype=np.float16)
u8_data.extend(np_arr.view(np.uint8))
elif dtype == DType.FP32:
- # for val in data:
- # b = struct.pack("!f", val)
- # u8_data.extend([b[3], b[2], b[1], b[0]])
np_arr = np.array(data, dtype=np.float32)
u8_data.extend(np_arr.view(np.uint8))
elif dtype == DType.BF16:
- for val in data:
- # convert val to little endian byte arrays b
- b = struct.pack("<f", val)
- # val => [ b[3], b[2], b[1], b[0] ]
- # keep only most significant 2 bytes for bf16
- # in little endian ordering
- u8_data.extend([b[2], b[3]])
+ np_arr = np.array(data, dtype=bfloat16)
+ u8_data.extend(np_arr.view(np.uint8))
elif dtype == DType.FP8E4M3:
for val in data:
- # convert val to fp8_bits then to single byte
- f32_as_int = struct.unpack(">L", struct.pack(">f", val))[0]
- f32_bits = f"{f32_as_int:032b}"
- fp8_bits = f32_bits[0] + f32_bits[1:5] + f32_bits[9:12]
- fp8_bytes = int(fp8_bits, 2).to_bytes(1, byteorder="little")
- u8_data.extend(fp8_bytes)
+ val_f8 = np.array(val).astype(float8_e4m3fn).view(np.uint8)
+ u8_data.append(val_f8)
elif dtype == DType.FP8E5M2:
for val in data:
- # convert val to fp8_bits then to single byte
- f32_as_int = struct.unpack(">L", struct.pack(">f", val))[0]
- f32_bits = f"{f32_as_int:032b}"
- fp8_bits = f32_bits[0] + f32_bits[1:6] + f32_bits[9:11]
- fp8_bytes = int(fp8_bits, 2).to_bytes(1, byteorder="little")
- u8_data.extend(fp8_bytes)
+ val_f8 = np.array(val).astype(float8_e5m2).view(np.uint8)
+ u8_data.append(val_f8)
elif dtype == TosaDType.DType:
# Serialize DType enum data as uint8 bytes
for val in data: