about | summary | refs | log | tree | commit | diff
path: root/src/numpy_utils.cpp
diff options
context:
space:
mode:
author: Won Jeon <won.jeon@arm.com> 2024-04-29 23:57:27 +0000
committer: Won Jeon <won.jeon@arm.com> 2024-05-03 13:33:29 +0000
commit: a814152b68a286f5bb9ddc095bb1897ec0e3d8ff (patch)
tree: c8aa9a42e3d9fdf978e5d366a301b1f8d9716d83 /src/numpy_utils.cpp
parent: 3aebe2bd863d6e0cb82171984cd49e5ad516d0db (diff)
download: serialization_lib-a814152b68a286f5bb9ddc095bb1897ec0e3d8ff.tar.gz
Use native size of Bfloat16 and Float8 for serialization/deserialization
Signed-off-by: Won Jeon <won.jeon@arm.com>
Change-Id: I0d2075f90988d4fd1139a11b5c154bdd600bb2cd
Diffstat (limited to 'src/numpy_utils.cpp')
-rw-r--r-- src/numpy_utils.cpp 29
1 files changed, 28 insertions, 1 deletions
diff --git a/src/numpy_utils.cpp b/src/numpy_utils.cpp
index e4171d7..7cf5f94 100644
--- a/src/numpy_utils.cpp
+++ b/src/numpy_utils.cpp
@@ -247,6 +247,14 @@ NumpyUtilities::NPError NumpyUtilities::checkNpyHeader(FILE* infile, const uint3
while (isspace(*ptr))
ptr++;
+ // ml_dtypes writes '<f1' for 'numpy.dtype' in the header for float8_e5m2, but
+ // default NumPy does not understand this notation, which causes trouble
+ // when other code tries to open this file.
+ // To avoid this, '|u1' notation is used when the file is written, and the uint8
+ // data is viewed as float8_e5m2 later when the file is read.
+ if (!strcmp(dtype_str, "'<f1'"))
+ dtype_str = "'|u1'";
+
if (strcmp(ptr, dtype_str))
{
return FILE_TYPE_MISMATCH;
@@ -430,6 +438,13 @@ NumpyUtilities::NPError
memcpy(header, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1);
headerPos += sizeof(NUMPY_HEADER_STR) - 1;
+ // NumPy does not understand float8_e5m2, so change it to uint8 type, so that
+ // Python can read .npy files.
+ if (!strcmp(dtype_str, "'<f1'"))
+ {
+ dtype_str = "'|u1'";
+ }
+
// Output the format dictionary
// Hard-coded for I32 for now
headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos,
@@ -438,7 +453,19 @@ NumpyUtilities::NPError
// Add shape contents (if any - as this will be empty for rank 0)
for (i = 0; i < shape.size(); i++)
{
- headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, " %d,", shape[i]);
+ // Output NumPy file from tosa_refmodel_sut_run generates the shape information
+ // without a trailing comma when the rank is greater than 1.
+ if (i == 0)
+ {
+ if (shape.size() == 1)
+ headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "%d,", shape[i]);
+ else
+ headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "%d", shape[i]);
+ }
+ else
+ {
+ headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, ", %d", shape[i]);
+ }
}
// Close off the dictionary