aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJames Ward <james.ward@arm.com>2022-10-18 17:27:40 +0100
committerJames Ward <james.ward@arm.com>2022-10-26 11:57:21 +0100
commit34a627959a61b4eccbeea4400cf9684debb331dc (patch)
tree7b6be68e49010f9a621c8e8f67f55163a534fe69
parente1072a9ed871fd474e7b09b7a74ae7be5f0a6f78 (diff)
downloadserialization_lib-34a627959a61b4eccbeea4400cf9684debb331dc.tar.gz
BF16 support in TOSA serialization
Change-Id: I98072019e3dbbf1eab0bc95f74a4546ed82519db Signed-off-by: James Ward <james.ward@arm.com>
-rw-r--r--include/tosa_generated.h13
-rw-r--r--python/serializer/tosa_serializer.py5
-rw-r--r--python/tosa/DType.py1
-rw-r--r--schema/tosa.fbs1
4 files changed, 13 insertions, 7 deletions
diff --git a/include/tosa_generated.h b/include/tosa_generated.h
index 2b9d0ea..f36ed37 100644
--- a/include/tosa_generated.h
+++ b/include/tosa_generated.h
@@ -95,11 +95,12 @@ enum DType : uint32_t {
DType_FP32 = 8,
DType_UINT16 = 9,
DType_FP16 = 10,
+ DType_BF16 = 11,
DType_MIN = DType_UNKNOWN,
- DType_MAX = DType_FP16
+ DType_MAX = DType_BF16
};
-inline const DType (&EnumValuesDType())[11] {
+inline const DType (&EnumValuesDType())[12] {
static const DType values[] = {
DType_UNKNOWN,
DType_BOOL,
@@ -111,13 +112,14 @@ inline const DType (&EnumValuesDType())[11] {
DType_INT48,
DType_FP32,
DType_UINT16,
- DType_FP16
+ DType_FP16,
+ DType_BF16
};
return values;
}
inline const char * const *EnumNamesDType() {
- static const char * const names[12] = {
+ static const char * const names[13] = {
"UNKNOWN",
"BOOL",
"UINT8",
@@ -129,13 +131,14 @@ inline const char * const *EnumNamesDType() {
"FP32",
"UINT16",
"FP16",
+ "BF16",
nullptr
};
return names;
}
inline const char *EnumNameDType(DType e) {
- if (flatbuffers::IsOutRange(e, DType_UNKNOWN, DType_FP16)) return "";
+ if (flatbuffers::IsOutRange(e, DType_UNKNOWN, DType_BF16)) return "";
const size_t index = static_cast<size_t>(e);
return EnumNamesDType()[index];
}
diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py
index f4e146c..861ea46 100644
--- a/python/serializer/tosa_serializer.py
+++ b/python/serializer/tosa_serializer.py
@@ -59,6 +59,7 @@ DTypeNames = [
"FP32",
"UINT16",
"FP16",
+ "BF16",
]
ByteMask = np.uint64(0xFF)
@@ -378,7 +379,7 @@ class TosaSerializerTensor:
self.shape = shape
self.dtype = dtype
- if dtype == DType.FP32:
+ if dtype == DType.FP32 or dtype == DType.BF16:
fntype = np.float32
elif dtype == DType.FP16:
fntype = np.float16
@@ -466,7 +467,7 @@ class TosaSerializerTensor:
elif self.dtype == DType.FP16:
np_arr = np.array(self.data, dtype=np.float16)
u8_data.extend(np_arr.view(np.uint8))
- elif self.dtype == DType.FP32:
+ elif self.dtype == DType.FP32 or self.dtype == DType.BF16:
for val in self.data:
b = struct.pack("!f", val)
u8_data.extend([b[3], b[2], b[1], b[0]])
diff --git a/python/tosa/DType.py b/python/tosa/DType.py
index 89669b7..15da2f6 100644
--- a/python/tosa/DType.py
+++ b/python/tosa/DType.py
@@ -14,3 +14,4 @@ class DType(object):
FP32 = 8
UINT16 = 9
FP16 = 10
+ BF16 = 11
diff --git a/schema/tosa.fbs b/schema/tosa.fbs
index eb76f75..e871562 100644
--- a/schema/tosa.fbs
+++ b/schema/tosa.fbs
@@ -32,6 +32,7 @@ enum DType:uint32 {
FP32,
UINT16,
FP16,
+ BF16,
}
enum ResizeMode:uint32 {