aboutsummaryrefslogtreecommitdiff
path: root/src/numpy_utils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/numpy_utils.cpp')
-rw-r--r--src/numpy_utils.cpp29
1 files changed, 28 insertions, 1 deletions
diff --git a/src/numpy_utils.cpp b/src/numpy_utils.cpp
index e4171d7..7cf5f94 100644
--- a/src/numpy_utils.cpp
+++ b/src/numpy_utils.cpp
@@ -247,6 +247,14 @@ NumpyUtilities::NPError NumpyUtilities::checkNpyHeader(FILE* infile, const uint3
while (isspace(*ptr))
ptr++;
+ // ml_dtypes writes '<f1' for 'numpy.dtype' in the header for float8_e5m2, but
+ // default NumPy does not understand this notation, which causes trouble
+ // when other code tries to open this file.
+ // To avoid this, '|u1' notation is used when the file is written, and the uint8
+ // data is viewed as float8_e5m2 later when the file is read.
+ if (!strcmp(dtype_str, "'<f1'"))
+ dtype_str = "'|u1'";
+
if (strcmp(ptr, dtype_str))
{
return FILE_TYPE_MISMATCH;
@@ -430,6 +438,13 @@ NumpyUtilities::NPError
memcpy(header, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1);
headerPos += sizeof(NUMPY_HEADER_STR) - 1;
+ // NumPy does not understand float8_e5m2, so change it to uint8 type, so that
+ // Python can read .npy files.
+ if (!strcmp(dtype_str, "'<f1'"))
+ {
+ dtype_str = "'|u1'";
+ }
+
// Output the format dictionary
// Hard-coded for I32 for now
headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos,
@@ -438,7 +453,19 @@ NumpyUtilities::NPError
// Add shape contents (if any - as this will be empty for rank 0)
for (i = 0; i < shape.size(); i++)
{
- headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, " %d,", shape[i]);
+ // Output NumPy file from tosa_refmodel_sut_run generates the shape information
+ // without a trailing comma when the rank is greater than 1.
+ if (i == 0)
+ {
+ if (shape.size() == 1)
+ headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "%d,", shape[i]);
+ else
+ headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "%d", shape[i]);
+ }
+ else
+ {
+ headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, ", %d", shape[i]);
+ }
}
// Close off the dictionary