aboutsummaryrefslogtreecommitdiff
path: root/python/serializer/tosa_serializer.py
diff options
context:
space:
mode:
author Won Jeon <won.jeon@arm.com> 2023-12-29 22:43:11 +0000
committer Tai Ly <tai.ly@arm.com> 2024-02-06 19:07:37 +0000
commit a029f1f02707f40f6990df53fd4f56684490d58f (patch)
tree 7054cd1d91d68ec87d1f052efec82dbd7d025009 /python/serializer/tosa_serializer.py
parent 8137a4369acefa4c01f08db95a2b1b290e8dd70a (diff)
download serialization_lib-a029f1f02707f40f6990df53fd4f56684490d58f.tar.gz
[serialization_lib] Add support for FP8E4M3 and FP8E5M2
Signed-off-by: Won Jeon <won.jeon@arm.com>
Signed-off-by: Tai Ly <tai.ly@arm.com>
Change-Id: Ife50592890be020b6c6122581eeb2175c8f331e0
Diffstat (limited to 'python/serializer/tosa_serializer.py')
-rw-r--r-- python/serializer/tosa_serializer.py 18
1 files changed, 15 insertions, 3 deletions
diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py
index c44b225..1aadbff 100644
--- a/python/serializer/tosa_serializer.py
+++ b/python/serializer/tosa_serializer.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2023, ARM Limited.
+# Copyright (c) 2020-2024, ARM Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -63,6 +63,8 @@ DTypeNames = [
"FP16",
"BF16",
"SHAPE",
+ "FP8E4M3",
+ "FP8E5M2",
]
ByteMask = np.uint64(0xFF)
@@ -425,7 +427,12 @@ class TosaSerializerTensor:
self.shape = shape
self.dtype = dtype
- if dtype == DType.FP32 or dtype == DType.BF16:
+ if (
+ dtype == DType.FP32
+ or dtype == DType.BF16
+ or dtype == DType.FP8E4M3
+ or dtype == DType.FP8E5M2
+ ):
fntype = np.float32
elif dtype == DType.FP16:
fntype = np.float16
@@ -525,7 +532,12 @@ class TosaSerializerTensor:
elif self.dtype == DType.FP16:
np_arr = np.array(self.data, dtype=np.float16)
u8_data.extend(np_arr.view(np.uint8))
- elif self.dtype == DType.FP32 or self.dtype == DType.BF16:
+ elif (
+ self.dtype == DType.FP32
+ or self.dtype == DType.BF16
+ or self.dtype == DType.FP8E4M3
+ or self.dtype == DType.FP8E5M2
+ ):
# for val in self.data:
# b = struct.pack("!f", val)
# u8_data.extend([b[3], b[2], b[1], b[0]])