diff options
author | Won Jeon <won.jeon@arm.com> | 2023-12-29 22:43:11 +0000 |
---|---|---|
committer | Tai Ly <tai.ly@arm.com> | 2024-02-06 19:07:37 +0000 |
commit | a029f1f02707f40f6990df53fd4f56684490d58f (patch) | |
tree | 7054cd1d91d68ec87d1f052efec82dbd7d025009 /python/serializer/tosa_serializer.py | |
parent | 8137a4369acefa4c01f08db95a2b1b290e8dd70a (diff) | |
download | serialization_lib-a029f1f02707f40f6990df53fd4f56684490d58f.tar.gz |
[serialization_lib] Add support for FP8E4M3 and FP8E5M2
Signed-off-by: Won Jeon <won.jeon@arm.com>
Signed-off-by: Tai Ly <tai.ly@arm.com>
Change-Id: Ife50592890be020b6c6122581eeb2175c8f331e0
Diffstat (limited to 'python/serializer/tosa_serializer.py')
-rw-r--r-- | python/serializer/tosa_serializer.py | 18 |
1 files changed, 15 insertions, 3 deletions
diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py index c44b225..1aadbff 100644 --- a/python/serializer/tosa_serializer.py +++ b/python/serializer/tosa_serializer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2023, ARM Limited. +# Copyright (c) 2020-2024, ARM Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -63,6 +63,8 @@ DTypeNames = [ "FP16", "BF16", "SHAPE", + "FP8E4M3", + "FP8E5M2", ] ByteMask = np.uint64(0xFF) @@ -425,7 +427,12 @@ class TosaSerializerTensor: self.shape = shape self.dtype = dtype - if dtype == DType.FP32 or dtype == DType.BF16: + if ( + dtype == DType.FP32 + or dtype == DType.BF16 + or dtype == DType.FP8E4M3 + or dtype == DType.FP8E5M2 + ): fntype = np.float32 elif dtype == DType.FP16: fntype = np.float16 @@ -525,7 +532,12 @@ class TosaSerializerTensor: elif self.dtype == DType.FP16: np_arr = np.array(self.data, dtype=np.float16) u8_data.extend(np_arr.view(np.uint8)) - elif self.dtype == DType.FP32 or self.dtype == DType.BF16: + elif ( + self.dtype == DType.FP32 + or self.dtype == DType.BF16 + or self.dtype == DType.FP8E4M3 + or self.dtype == DType.FP8E5M2 + ): # for val in self.data: # b = struct.pack("!f", val) # u8_data.extend([b[3], b[2], b[1], b[0]]) |