aboutsummaryrefslogtreecommitdiff
path: root/python/serializer
diff options
context:
space:
mode:
authorWon Jeon <won.jeon@arm.com>2024-05-09 06:00:31 +0000
committerWon Jeon <won.jeon@arm.com>2024-05-09 18:21:17 +0000
commit07098d64498723e889e8f8deaad4952020fa9450 (patch)
tree94f4bd6a070c73bbd629118fe6bfd8390bdf7213 /python/serializer
parentb386815fcf36092be832281821af7ad9f2119e07 (diff)
downloadserialization_lib-07098d64498723e889e8f8deaad4952020fa9450.tar.gz
Fix Bfloat16 data conversion for serialization
Signed-off-by: Won Jeon <won.jeon@arm.com> Change-Id: I52f6fea3e8b4cd5ff0886ccfa12396a680558670
Diffstat (limited to 'python/serializer')
-rw-r--r--python/serializer/tosa_serializer.py5
1 files changed, 2 insertions, 3 deletions
diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py
index 7122216..34178c5 100644
--- a/python/serializer/tosa_serializer.py
+++ b/python/serializer/tosa_serializer.py
@@ -947,9 +947,8 @@ class TosaSerializer:
np_arr = np.array(data, dtype=np.float32)
u8_data.extend(np_arr.view(np.uint8))
elif dtype == DType.BF16:
- for val in data:
- np_arr = np.array(data, dtype=bfloat16)
- u8_data.extend(np_arr.view(np.uint8))
+ np_arr = np.array(data, dtype=bfloat16)
+ u8_data.extend(np_arr.view(np.uint8))
elif dtype == DType.FP8E4M3:
for val in data:
val_f8 = np.array(val).astype(float8_e4m3fn).view(np.uint8)