aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
Diffstat (limited to 'python')
-rw-r--r--python/serializer/tosa_serializer.py48
-rw-r--r--python/tosa/Op.py1
-rw-r--r--python/tosa/ResizeAttribute.py117
3 files changed, 16 insertions, 150 deletions
diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py
index c328aaf..c417fce 100644
--- a/python/serializer/tosa_serializer.py
+++ b/python/serializer/tosa_serializer.py
@@ -17,7 +17,7 @@ import serializer.tosa_serializer as ts
import json
import flatbuffers
import numpy as np
-import struct
+from ml_dtypes import bfloat16, float8_e4m3fn, float8_e5m2
from enum import IntEnum, unique
from tosa import (
TosaGraph,
@@ -225,15 +225,12 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.ints.append((a.AddAxis, axis))
- def ResizeAttribute(self, scale, offset, border, mode):
+ def ResizeAttribute(self, mode):
from tosa import ResizeAttribute as a, Attribute
self.utype = Attribute.Attribute().ResizeAttribute
self.optFcns = (a.Start, a.End)
- self.int16vecs.append((a.AddScale, scale))
- self.int16vecs.append((a.AddOffset, offset))
- self.int16vecs.append((a.AddBorder, border))
self.ints.append((a.AddMode, mode))
def ClampAttribute(self, serializer_builder, min_val_as_bytes, max_val_as_bytes):
@@ -392,13 +389,14 @@ class TosaSerializerTensor:
self.shape = shape
self.dtype = dtype
- if (
- dtype == DType.FP32
- or dtype == DType.BF16
- or dtype == DType.FP8E4M3
- or dtype == DType.FP8E5M2
- ):
+ if dtype == DType.FP32:
fntype = np.float32
+ elif dtype == DType.BF16:
+ fntype = bfloat16
+ elif dtype == DType.FP8E4M3:
+ fntype = float8_e4m3fn
+ elif dtype == DType.FP8E5M2:
+ fntype = float8_e5m2
elif dtype == DType.FP16:
fntype = np.float16
else:
@@ -943,35 +941,19 @@ class TosaSerializer:
np_arr = np.array(data, dtype=np.float16)
u8_data.extend(np_arr.view(np.uint8))
elif dtype == DType.FP32:
- # for val in data:
- # b = struct.pack("!f", val)
- # u8_data.extend([b[3], b[2], b[1], b[0]])
np_arr = np.array(data, dtype=np.float32)
u8_data.extend(np_arr.view(np.uint8))
elif dtype == DType.BF16:
- for val in data:
- # convert val to little endian byte arrays b
- b = struct.pack("<f", val)
- # val => [ b[3], b[2], b[1], b[0] ]
- # keep only most significant 2 bytes for bf16
- # in little endian ordering
- u8_data.extend([b[2], b[3]])
+ np_arr = np.array(data, dtype=bfloat16)
+ u8_data.extend(np_arr.view(np.uint8))
elif dtype == DType.FP8E4M3:
for val in data:
- # convert val to fp8_bits then to single byte
- f32_as_int = struct.unpack(">L", struct.pack(">f", val))[0]
- f32_bits = f"{f32_as_int:032b}"
- fp8_bits = f32_bits[0] + f32_bits[1:5] + f32_bits[9:12]
- fp8_bytes = int(fp8_bits, 2).to_bytes(1, byteorder="little")
- u8_data.extend(fp8_bytes)
+ val_f8 = np.array(val).astype(float8_e4m3fn).view(np.uint8)
+ u8_data.append(val_f8)
elif dtype == DType.FP8E5M2:
for val in data:
- # convert val to fp8_bits then to single byte
- f32_as_int = struct.unpack(">L", struct.pack(">f", val))[0]
- f32_bits = f"{f32_as_int:032b}"
- fp8_bits = f32_bits[0] + f32_bits[1:6] + f32_bits[9:11]
- fp8_bytes = int(fp8_bits, 2).to_bytes(1, byteorder="little")
- u8_data.extend(fp8_bytes)
+ val_f8 = np.array(val).astype(float8_e5m2).view(np.uint8)
+ u8_data.append(val_f8)
elif dtype == TosaDType.DType:
# Serialize DType enum data as uint8 bytes
for val in data:
diff --git a/python/tosa/Op.py b/python/tosa/Op.py
index 35b2b80..021d528 100644
--- a/python/tosa/Op.py
+++ b/python/tosa/Op.py
@@ -84,3 +84,4 @@ class Op(object):
DIV_SHAPE = 78
COS = 79
SIN = 80
+ CAST_STOCHASTIC = 81
diff --git a/python/tosa/ResizeAttribute.py b/python/tosa/ResizeAttribute.py
index 44f7d31..f2a6992 100644
--- a/python/tosa/ResizeAttribute.py
+++ b/python/tosa/ResizeAttribute.py
@@ -29,87 +29,6 @@ class ResizeAttribute(object):
self._tab = flatbuffers.table.Table(buf, pos)
# ResizeAttribute
- def Scale(self, j):
- o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
- if o != 0:
- a = self._tab.Vector(o)
- return self._tab.Get(flatbuffers.number_types.Int16Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 2))
- return 0
-
- # ResizeAttribute
- def ScaleAsNumpy(self):
- o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
- if o != 0:
- return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int16Flags, o)
- return 0
-
- # ResizeAttribute
- def ScaleLength(self):
- o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
- if o != 0:
- return self._tab.VectorLen(o)
- return 0
-
- # ResizeAttribute
- def ScaleIsNone(self):
- o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
- return o == 0
-
- # ResizeAttribute
- def Offset(self, j):
- o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
- if o != 0:
- a = self._tab.Vector(o)
- return self._tab.Get(flatbuffers.number_types.Int16Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 2))
- return 0
-
- # ResizeAttribute
- def OffsetAsNumpy(self):
- o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
- if o != 0:
- return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int16Flags, o)
- return 0
-
- # ResizeAttribute
- def OffsetLength(self):
- o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
- if o != 0:
- return self._tab.VectorLen(o)
- return 0
-
- # ResizeAttribute
- def OffsetIsNone(self):
- o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
- return o == 0
-
- # ResizeAttribute
- def Border(self, j):
- o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
- if o != 0:
- a = self._tab.Vector(o)
- return self._tab.Get(flatbuffers.number_types.Int16Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 2))
- return 0
-
- # ResizeAttribute
- def BorderAsNumpy(self):
- o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
- if o != 0:
- return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int16Flags, o)
- return 0
-
- # ResizeAttribute
- def BorderLength(self):
- o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
- if o != 0:
- return self._tab.VectorLen(o)
- return 0
-
- # ResizeAttribute
- def BorderIsNone(self):
- o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
- return o == 0
-
- # ResizeAttribute
def Mode(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
if o != 0:
@@ -122,42 +41,6 @@ def ResizeAttributeStart(builder):
def Start(builder):
ResizeAttributeStart(builder)
-def ResizeAttributeAddScale(builder, scale):
- builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(scale), 0)
-
-def AddScale(builder, scale):
- ResizeAttributeAddScale(builder, scale)
-
-def ResizeAttributeStartScaleVector(builder, numElems):
- return builder.StartVector(2, numElems, 2)
-
-def StartScaleVector(builder, numElems):
- return ResizeAttributeStartScaleVector(builder, numElems)
-
-def ResizeAttributeAddOffset(builder, offset):
- builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(offset), 0)
-
-def AddOffset(builder, offset):
- ResizeAttributeAddOffset(builder, offset)
-
-def ResizeAttributeStartOffsetVector(builder, numElems):
- return builder.StartVector(2, numElems, 2)
-
-def StartOffsetVector(builder, numElems):
- return ResizeAttributeStartOffsetVector(builder, numElems)
-
-def ResizeAttributeAddBorder(builder, border):
- builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(border), 0)
-
-def AddBorder(builder, border):
- ResizeAttributeAddBorder(builder, border)
-
-def ResizeAttributeStartBorderVector(builder, numElems):
- return builder.StartVector(2, numElems, 2)
-
-def StartBorderVector(builder, numElems):
- return ResizeAttributeStartBorderVector(builder, numElems)
-
def ResizeAttributeAddMode(builder, mode):
builder.PrependUint32Slot(3, mode, 0)