about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Won Jeon <won.jeon@arm.com> 2023-12-29 22:43:11 +0000
committer: Tai Ly <tai.ly@arm.com> 2024-02-06 19:07:37 +0000
commit: a029f1f02707f40f6990df53fd4f56684490d58f (patch)
tree: 7054cd1d91d68ec87d1f052efec82dbd7d025009
parent: 8137a4369acefa4c01f08db95a2b1b290e8dd70a (diff)
download: serialization_lib-a029f1f02707f40f6990df53fd4f56684490d58f.tar.gz
[serialization_lib] Add support for FP8E4M3 and FP8E5M2
Signed-off-by: Won Jeon <won.jeon@arm.com>
Signed-off-by: Tai Ly <tai.ly@arm.com>
Change-Id: Ife50592890be020b6c6122581eeb2175c8f331e0
-rw-r--r-- include/tosa_generated.h 16
-rw-r--r-- python/serializer/tosa_serializer.py 18
-rw-r--r-- python/tosa/DType.py 2
-rw-r--r-- schema/tosa.fbs 2
4 files changed, 30 insertions, 8 deletions
diff --git a/include/tosa_generated.h b/include/tosa_generated.h
index 2ecd35a..b5a8bd5 100644
--- a/include/tosa_generated.h
+++ b/include/tosa_generated.h
@@ -116,11 +116,13 @@ enum DType : uint32_t {
DType_FP16 = 10,
DType_BF16 = 11,
DType_SHAPE = 12,
+ DType_FP8E4M3 = 13,
+ DType_FP8E5M2 = 14,
DType_MIN = DType_UNKNOWN,
- DType_MAX = DType_SHAPE
+ DType_MAX = DType_FP8E5M2
};
-inline const DType (&EnumValuesDType())[13] {
+inline const DType (&EnumValuesDType())[15] {
static const DType values[] = {
DType_UNKNOWN,
DType_BOOL,
@@ -134,13 +136,15 @@ inline const DType (&EnumValuesDType())[13] {
DType_UINT16,
DType_FP16,
DType_BF16,
- DType_SHAPE
+ DType_SHAPE,
+ DType_FP8E4M3,
+ DType_FP8E5M2
};
return values;
}
inline const char * const *EnumNamesDType() {
- static const char * const names[14] = {
+ static const char * const names[16] = {
"UNKNOWN",
"BOOL",
"UINT8",
@@ -154,13 +158,15 @@ inline const char * const *EnumNamesDType() {
"FP16",
"BF16",
"SHAPE",
+ "FP8E4M3",
+ "FP8E5M2",
nullptr
};
return names;
}
inline const char *EnumNameDType(DType e) {
- if (::flatbuffers::IsOutRange(e, DType_UNKNOWN, DType_SHAPE)) return "";
+ if (::flatbuffers::IsOutRange(e, DType_UNKNOWN, DType_FP8E5M2)) return "";
const size_t index = static_cast<size_t>(e);
return EnumNamesDType()[index];
}
diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py
index c44b225..1aadbff 100644
--- a/python/serializer/tosa_serializer.py
+++ b/python/serializer/tosa_serializer.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2023, ARM Limited.
+# Copyright (c) 2020-2024, ARM Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -63,6 +63,8 @@ DTypeNames = [
"FP16",
"BF16",
"SHAPE",
+ "FP8E4M3",
+ "FP8E5M2",
]
ByteMask = np.uint64(0xFF)
@@ -425,7 +427,12 @@ class TosaSerializerTensor:
self.shape = shape
self.dtype = dtype
- if dtype == DType.FP32 or dtype == DType.BF16:
+ if (
+ dtype == DType.FP32
+ or dtype == DType.BF16
+ or dtype == DType.FP8E4M3
+ or dtype == DType.FP8E5M2
+ ):
fntype = np.float32
elif dtype == DType.FP16:
fntype = np.float16
@@ -525,7 +532,12 @@ class TosaSerializerTensor:
elif self.dtype == DType.FP16:
np_arr = np.array(self.data, dtype=np.float16)
u8_data.extend(np_arr.view(np.uint8))
- elif self.dtype == DType.FP32 or self.dtype == DType.BF16:
+ elif (
+ self.dtype == DType.FP32
+ or self.dtype == DType.BF16
+ or self.dtype == DType.FP8E4M3
+ or self.dtype == DType.FP8E5M2
+ ):
# for val in self.data:
# b = struct.pack("!f", val)
# u8_data.extend([b[3], b[2], b[1], b[0]])
diff --git a/python/tosa/DType.py b/python/tosa/DType.py
index 6df2dcb..e585cb9 100644
--- a/python/tosa/DType.py
+++ b/python/tosa/DType.py
@@ -16,3 +16,5 @@ class DType(object):
FP16 = 10
BF16 = 11
SHAPE = 12
+ FP8E4M3 = 13
+ FP8E5M2 = 14
diff --git a/schema/tosa.fbs b/schema/tosa.fbs
index 171818d..1c2a85e 100644
--- a/schema/tosa.fbs
+++ b/schema/tosa.fbs
@@ -37,6 +37,8 @@ enum DType:uint32 {
FP16,
BF16,
SHAPE,
+ FP8E4M3,
+ FP8E5M2,
}
enum ResizeMode:uint32 {