aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
Diffstat (limited to 'python')
-rw-r--r--python/serializer/tosa_serializer.py5
-rw-r--r--python/tosa/DType.py1
2 files changed, 4 insertions, 2 deletions
diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py
index f4e146c..861ea46 100644
--- a/python/serializer/tosa_serializer.py
+++ b/python/serializer/tosa_serializer.py
@@ -59,6 +59,7 @@ DTypeNames = [
"FP32",
"UINT16",
"FP16",
+ "BF16",
]
ByteMask = np.uint64(0xFF)
@@ -378,7 +379,7 @@ class TosaSerializerTensor:
self.shape = shape
self.dtype = dtype
- if dtype == DType.FP32:
+ if dtype == DType.FP32 or dtype == DType.BF16:
fntype = np.float32
elif dtype == DType.FP16:
fntype = np.float16
@@ -466,7 +467,7 @@ class TosaSerializerTensor:
elif self.dtype == DType.FP16:
np_arr = np.array(self.data, dtype=np.float16)
u8_data.extend(np_arr.view(np.uint8))
- elif self.dtype == DType.FP32:
+ elif self.dtype == DType.FP32 or self.dtype == DType.BF16:
for val in self.data:
b = struct.pack("!f", val)
u8_data.extend([b[3], b[2], b[1], b[0]])
diff --git a/python/tosa/DType.py b/python/tosa/DType.py
index 89669b7..15da2f6 100644
--- a/python/tosa/DType.py
+++ b/python/tosa/DType.py
@@ -14,3 +14,4 @@ class DType(object):
FP32 = 8
UINT16 = 9
FP16 = 10
+ BF16 = 11