From a029f1f02707f40f6990df53fd4f56684490d58f Mon Sep 17 00:00:00 2001 From: Won Jeon Date: Fri, 29 Dec 2023 22:43:11 +0000 Subject: [serialization_lib] Add support for FP8E4M3 and FP8E5M2 Signed-off-by: Won Jeon Signed-off-by: Tai Ly Change-Id: Ife50592890be020b6c6122581eeb2175c8f331e0 --- include/tosa_generated.h | 16 +++++++++++----- python/serializer/tosa_serializer.py | 18 +++++++++++++++--- python/tosa/DType.py | 2 ++ schema/tosa.fbs | 2 ++ 4 files changed, 30 insertions(+), 8 deletions(-) diff --git a/include/tosa_generated.h b/include/tosa_generated.h index 2ecd35a..b5a8bd5 100644 --- a/include/tosa_generated.h +++ b/include/tosa_generated.h @@ -116,11 +116,13 @@ enum DType : uint32_t { DType_FP16 = 10, DType_BF16 = 11, DType_SHAPE = 12, + DType_FP8E4M3 = 13, + DType_FP8E5M2 = 14, DType_MIN = DType_UNKNOWN, - DType_MAX = DType_SHAPE + DType_MAX = DType_FP8E5M2 }; -inline const DType (&EnumValuesDType())[13] { +inline const DType (&EnumValuesDType())[15] { static const DType values[] = { DType_UNKNOWN, DType_BOOL, @@ -134,13 +136,15 @@ inline const DType (&EnumValuesDType())[13] { DType_UINT16, DType_FP16, DType_BF16, - DType_SHAPE + DType_SHAPE, + DType_FP8E4M3, + DType_FP8E5M2 }; return values; } inline const char * const *EnumNamesDType() { - static const char * const names[14] = { + static const char * const names[16] = { "UNKNOWN", "BOOL", "UINT8", @@ -154,13 +158,15 @@ inline const char * const *EnumNamesDType() { "FP16", "BF16", "SHAPE", + "FP8E4M3", + "FP8E5M2", nullptr }; return names; } inline const char *EnumNameDType(DType e) { - if (::flatbuffers::IsOutRange(e, DType_UNKNOWN, DType_SHAPE)) return ""; + if (::flatbuffers::IsOutRange(e, DType_UNKNOWN, DType_FP8E5M2)) return ""; const size_t index = static_cast(e); return EnumNamesDType()[index]; } diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py index c44b225..1aadbff 100644 --- a/python/serializer/tosa_serializer.py +++ b/python/serializer/tosa_serializer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2023, ARM Limited. +# Copyright (c) 2020-2024, ARM Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -63,6 +63,8 @@ DTypeNames = [ "FP16", "BF16", "SHAPE", + "FP8E4M3", + "FP8E5M2", ] ByteMask = np.uint64(0xFF) @@ -425,7 +427,12 @@ class TosaSerializerTensor: self.shape = shape self.dtype = dtype - if dtype == DType.FP32 or dtype == DType.BF16: + if ( + dtype == DType.FP32 + or dtype == DType.BF16 + or dtype == DType.FP8E4M3 + or dtype == DType.FP8E5M2 + ): fntype = np.float32 elif dtype == DType.FP16: fntype = np.float16 @@ -525,7 +532,12 @@ class TosaSerializerTensor: elif self.dtype == DType.FP16: np_arr = np.array(self.data, dtype=np.float16) u8_data.extend(np_arr.view(np.uint8)) - elif self.dtype == DType.FP32 or self.dtype == DType.BF16: + elif ( + self.dtype == DType.FP32 + or self.dtype == DType.BF16 + or self.dtype == DType.FP8E4M3 + or self.dtype == DType.FP8E5M2 + ): # for val in self.data: # b = struct.pack("!f", val) # u8_data.extend([b[3], b[2], b[1], b[0]]) diff --git a/python/tosa/DType.py b/python/tosa/DType.py index 6df2dcb..e585cb9 100644 --- a/python/tosa/DType.py +++ b/python/tosa/DType.py @@ -16,3 +16,5 @@ class DType(object): FP16 = 10 BF16 = 11 SHAPE = 12 + FP8E4M3 = 13 + FP8E5M2 = 14 diff --git a/schema/tosa.fbs b/schema/tosa.fbs index 171818d..1c2a85e 100644 --- a/schema/tosa.fbs +++ b/schema/tosa.fbs @@ -37,6 +37,8 @@ enum DType:uint32 { FP16, BF16, SHAPE, + FP8E4M3, + FP8E5M2, } enum ResizeMode:uint32 { -- cgit v1.2.1