diff options
Diffstat (limited to 'python')
-rw-r--r-- | python/serializer/tosa_serializer.py | 18 | ||||
-rw-r--r-- | python/tosa/DType.py | 2 |
2 files changed, 17 insertions, 3 deletions
diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py index c44b225..1aadbff 100644 --- a/python/serializer/tosa_serializer.py +++ b/python/serializer/tosa_serializer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2023, ARM Limited. +# Copyright (c) 2020-2024, ARM Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -63,6 +63,8 @@ DTypeNames = [ "FP16", "BF16", "SHAPE", + "FP8E4M3", + "FP8E5M2", ] ByteMask = np.uint64(0xFF) @@ -425,7 +427,12 @@ class TosaSerializerTensor: self.shape = shape self.dtype = dtype - if dtype == DType.FP32 or dtype == DType.BF16: + if ( + dtype == DType.FP32 + or dtype == DType.BF16 + or dtype == DType.FP8E4M3 + or dtype == DType.FP8E5M2 + ): fntype = np.float32 elif dtype == DType.FP16: fntype = np.float16 @@ -525,7 +532,12 @@ class TosaSerializerTensor: elif self.dtype == DType.FP16: np_arr = np.array(self.data, dtype=np.float16) u8_data.extend(np_arr.view(np.uint8)) - elif self.dtype == DType.FP32 or self.dtype == DType.BF16: + elif ( + self.dtype == DType.FP32 + or self.dtype == DType.BF16 + or self.dtype == DType.FP8E4M3 + or self.dtype == DType.FP8E5M2 + ): # for val in self.data: # b = struct.pack("!f", val) # u8_data.extend([b[3], b[2], b[1], b[0]]) diff --git a/python/tosa/DType.py b/python/tosa/DType.py index 6df2dcb..e585cb9 100644 --- a/python/tosa/DType.py +++ b/python/tosa/DType.py @@ -16,3 +16,5 @@ class DType(object): FP16 = 10 BF16 = 11 SHAPE = 12 + FP8E4M3 = 13 + FP8E5M2 = 14 |