diff options
author | Won Jeon <won.jeon@arm.com> | 2024-04-29 23:57:27 +0000 |
---|---|---|
committer | Won Jeon <won.jeon@arm.com> | 2024-05-03 13:33:29 +0000 |
commit | a814152b68a286f5bb9ddc095bb1897ec0e3d8ff (patch) | |
tree | c8aa9a42e3d9fdf978e5d366a301b1f8d9716d83 /src/numpy_utils.cpp | |
parent | 3aebe2bd863d6e0cb82171984cd49e5ad516d0db (diff) | |
download | serialization_lib-a814152b68a286f5bb9ddc095bb1897ec0e3d8ff.tar.gz |
Use native size of Bfloat16 and Float8 for serialization/deserialization
Signed-off-by: Won Jeon <won.jeon@arm.com>
Change-Id: I0d2075f90988d4fd1139a11b5c154bdd600bb2cd
Diffstat (limited to 'src/numpy_utils.cpp')
-rw-r--r-- | src/numpy_utils.cpp | 29 |
1 files changed, 28 insertions, 1 deletions
diff --git a/src/numpy_utils.cpp b/src/numpy_utils.cpp
index e4171d7..7cf5f94 100644
--- a/src/numpy_utils.cpp
+++ b/src/numpy_utils.cpp
@@ -247,6 +247,14 @@ NumpyUtilities::NPError NumpyUtilities::checkNpyHeader(FILE* infile, const uint3
     while (isspace(*ptr))
         ptr++;
 
+    // ml_dtypes writes '<f1' for 'numpy.dtype' in the header for float8_e5m2, but
+    // default NumPy does not understand this notation, which causes trouble
+    // when other code tries to open this file.
+    // To avoid this, '|u1' notation is used when the file is written, and the uint8
+    // data is viewed as float8_e5m2 later when the file is read.
+    if (!strcmp(dtype_str, "'<f1'"))
+        dtype_str = "'|u1'";
+
     if (strcmp(ptr, dtype_str))
     {
         return FILE_TYPE_MISMATCH;
@@ -430,6 +438,13 @@ NumpyUtilities::NPError
     memcpy(header, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1);
     headerPos += sizeof(NUMPY_HEADER_STR) - 1;
 
+    // NumPy does not understand float8_e5m2, so change it to uint8 type, so that
+    // Python can read .npy files.
+    if (!strcmp(dtype_str, "'<f1'"))
+    {
+        dtype_str = "'|u1'";
+    }
+
     // Output the format dictionary
     // Hard-coded for I32 for now
     headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos,
@@ -438,7 +453,19 @@ NumpyUtilities::NPError
     // Add shape contents (if any - as this will be empty for rank 0)
     for (i = 0; i < shape.size(); i++)
     {
-        headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, " %d,", shape[i]);
+        // Output NumPy file from tosa_refmodel_sut_run generates the shape information
+        // without a trailing comma when the rank is greater than 1.
+        if (i == 0)
+        {
+            if (shape.size() == 1)
+                headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "%d,", shape[i]);
+            else
+                headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "%d", shape[i]);
+        }
+        else
+        {
+            headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, ", %d", shape[i]);
+        }
     }
 
     // Close off the dictionary |