From 34a627959a61b4eccbeea4400cf9684debb331dc Mon Sep 17 00:00:00 2001 From: James Ward Date: Tue, 18 Oct 2022 17:27:40 +0100 Subject: BF16 support in TOSA serialization Change-Id: I98072019e3dbbf1eab0bc95f74a4546ed82519db Signed-off-by: James Ward --- include/tosa_generated.h | 13 ++++++++----- python/serializer/tosa_serializer.py | 5 +++-- python/tosa/DType.py | 1 + schema/tosa.fbs | 1 + 4 files changed, 13 insertions(+), 7 deletions(-) diff --git a/include/tosa_generated.h b/include/tosa_generated.h index 2b9d0ea..f36ed37 100644 --- a/include/tosa_generated.h +++ b/include/tosa_generated.h @@ -95,11 +95,12 @@ enum DType : uint32_t { DType_FP32 = 8, DType_UINT16 = 9, DType_FP16 = 10, + DType_BF16 = 11, DType_MIN = DType_UNKNOWN, - DType_MAX = DType_FP16 + DType_MAX = DType_BF16 }; -inline const DType (&EnumValuesDType())[11] { +inline const DType (&EnumValuesDType())[12] { static const DType values[] = { DType_UNKNOWN, DType_BOOL, @@ -111,13 +112,14 @@ inline const DType (&EnumValuesDType())[11] { DType_INT48, DType_FP32, DType_UINT16, - DType_FP16 + DType_FP16, + DType_BF16 }; return values; } inline const char * const *EnumNamesDType() { - static const char * const names[12] = { + static const char * const names[13] = { "UNKNOWN", "BOOL", "UINT8", @@ -129,13 +131,14 @@ inline const char * const *EnumNamesDType() { "FP32", "UINT16", "FP16", + "BF16", nullptr }; return names; } inline const char *EnumNameDType(DType e) { - if (flatbuffers::IsOutRange(e, DType_UNKNOWN, DType_FP16)) return ""; + if (flatbuffers::IsOutRange(e, DType_UNKNOWN, DType_BF16)) return ""; const size_t index = static_cast(e); return EnumNamesDType()[index]; } diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py index f4e146c..861ea46 100644 --- a/python/serializer/tosa_serializer.py +++ b/python/serializer/tosa_serializer.py @@ -59,6 +59,7 @@ DTypeNames = [ "FP32", "UINT16", "FP16", + "BF16", ] ByteMask = np.uint64(0xFF) @@ -378,7 +379,7 @@ class TosaSerializerTensor: self.shape = shape self.dtype = dtype - if dtype == DType.FP32: + if dtype == DType.FP32 or dtype == DType.BF16: fntype = np.float32 elif dtype == DType.FP16: fntype = np.float16 @@ -466,7 +467,7 @@ class TosaSerializerTensor: elif self.dtype == DType.FP16: np_arr = np.array(self.data, dtype=np.float16) u8_data.extend(np_arr.view(np.uint8)) - elif self.dtype == DType.FP32: + elif self.dtype == DType.FP32 or self.dtype == DType.BF16: for val in self.data: b = struct.pack("!f", val) u8_data.extend([b[3], b[2], b[1], b[0]]) diff --git a/python/tosa/DType.py b/python/tosa/DType.py index 89669b7..15da2f6 100644 --- a/python/tosa/DType.py +++ b/python/tosa/DType.py @@ -14,3 +14,4 @@ class DType(object): FP32 = 8 UINT16 = 9 FP16 = 10 + BF16 = 11 diff --git a/schema/tosa.fbs b/schema/tosa.fbs index eb76f75..e871562 100644 --- a/schema/tosa.fbs +++ b/schema/tosa.fbs @@ -32,6 +32,7 @@ enum DType:uint32 { FP32, UINT16, FP16, + BF16, } enum ResizeMode:uint32 { -- cgit v1.2.1