aboutsummaryrefslogtreecommitdiff
path: root/src/numpy_utils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/numpy_utils.cpp')
-rw-r--r--src/numpy_utils.cpp197
1 files changed, 119 insertions, 78 deletions
diff --git a/src/numpy_utils.cpp b/src/numpy_utils.cpp
index 80c680f..e4171d7 100644
--- a/src/numpy_utils.cpp
+++ b/src/numpy_utils.cpp
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2021, ARM Limited.
+// Copyright (c) 2020-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -14,6 +14,9 @@
// limitations under the License.
#include "numpy_utils.h"
+#include "half.hpp"
+#include <algorithm>
+#include <memory>
// Magic NUMPY header
static const char NUMPY_HEADER_STR[] = "\x93NUMPY\x1\x0\x76\x0{";
@@ -21,28 +24,81 @@ static const int NUMPY_HEADER_SZ = 128;
// Maximum shape dimensions supported
static const int NUMPY_MAX_DIMS_SUPPORTED = 10;
-NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, bool* databuf)
-{
- const char dtype_str[] = "'|b1'";
- return readFromNpyFileCommon(filename, dtype_str, 1, elems, databuf, true);
-}
-
+// Entry point for reading an .npy file into an int32_t buffer.
+//
+// The header is probed first (getHeader) to find out how the elements are stored:
+//   - 1- and 2-byte integers (signed or unsigned) are read into a temporary buffer of the stored
+//     type and then copied element by element into databuf (copyBufferByElement);
+//   - signed 4-byte integers are passed to readFromNpyFileCommon together with databuf;
+//   - everything else, including unsigned 4-byte data, yields DATA_TYPE_NOT_SUPPORTED.
+//
+// filename: path of the .npy file, must not be null
+// elems:    number of elements to read
+// databuf:  destination buffer, must not be null; elems entries are written on success
+//
+// Returns FILE_NOT_FOUND if the file cannot be opened, the error reported by getHeader if the
+// header cannot be parsed, otherwise the result of the type-specific read. databuf is left
+// untouched when a 1- or 2-byte read fails.
+template <>
NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int32_t* databuf)
{
- const char dtype_str[] = "'<i4'";
- return readFromNpyFileCommon(filename, dtype_str, sizeof(int32_t), elems, databuf, false);
-}
+    assert(filename);
+    assert(databuf);
-NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int64_t* databuf)
-{
- const char dtype_str[] = "'<i8'";
- return readFromNpyFileCommon(filename, dtype_str, sizeof(int64_t), elems, databuf, false);
-}
+    FILE* infile = fopen(filename, "rb");
+    if (!infile)
+    {
+        return FILE_NOT_FOUND;
+    }
-NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, float* databuf)
-{
- const char dtype_str[] = "'<f4'";
- return readFromNpyFileCommon(filename, dtype_str, sizeof(float), elems, databuf, false);
+    // Probe the header, then release the handle straight away: the type-specific readers below
+    // are given the filename rather than this handle, so keeping it open would only leak it.
+    bool is_signed     = false;
+    int bytes_per_elem = 0;
+    char byte_order    = '\0';
+    const NPError header_rc = getHeader(infile, is_signed, bytes_per_elem, byte_order);
+    fclose(infile);
+    if (header_rc != NO_ERROR)
+        return header_rc;
+
+    NPError rc = DATA_TYPE_NOT_SUPPORTED;
+    switch (bytes_per_elem)
+    {
+        case 1:
+            if (is_signed)
+            {
+                std::unique_ptr<int8_t[]> tmp_buf(new int8_t[elems]);
+                rc = readFromNpyFile<int8_t>(filename, elems, tmp_buf.get());
+                // Only copy when the read succeeded; otherwise tmp_buf holds indeterminate values.
+                if (rc == NO_ERROR)
+                    copyBufferByElement(databuf, tmp_buf.get(), elems);
+            }
+            else
+            {
+                std::unique_ptr<uint8_t[]> tmp_buf(new uint8_t[elems]);
+                rc = readFromNpyFile<uint8_t>(filename, elems, tmp_buf.get());
+                if (rc == NO_ERROR)
+                    copyBufferByElement(databuf, tmp_buf.get(), elems);
+            }
+            break;
+        case 2:
+            if (is_signed)
+            {
+                std::unique_ptr<int16_t[]> tmp_buf(new int16_t[elems]);
+                rc = readFromNpyFile<int16_t>(filename, elems, tmp_buf.get());
+                if (rc == NO_ERROR)
+                    copyBufferByElement(databuf, tmp_buf.get(), elems);
+            }
+            else
+            {
+                std::unique_ptr<uint16_t[]> tmp_buf(new uint16_t[elems]);
+                rc = readFromNpyFile<uint16_t>(filename, elems, tmp_buf.get());
+                if (rc == NO_ERROR)
+                    copyBufferByElement(databuf, tmp_buf.get(), elems);
+            }
+            break;
+        case 4:
+            if (is_signed)
+            {
+                bool is_bool          = false;
+                const char* dtype_str = getDTypeString<int32_t>(is_bool);
+                rc = readFromNpyFileCommon(filename, dtype_str, sizeof(int32_t), elems, databuf, is_bool);
+            }
+            else
+            {
+                // uint32, not supported
+                rc = DATA_TYPE_NOT_SUPPORTED;
+            }
+            break;
+        default:
+            // Any other element width is not supported.
+            rc = DATA_TYPE_NOT_SUPPORTED;
+            break;
+    }
+
+    return rc;
}
NumpyUtilities::NPError NumpyUtilities::readFromNpyFileCommon(const char* filename,
@@ -101,6 +157,46 @@ NumpyUtilities::NPError NumpyUtilities::readFromNpyFileCommon(const char* filena
return rc;
}
+// Extracts basic element-type information from the 'descr' entry of an .npy header.
+//
+// Reads the first NUMPY_HEADER_SZ bytes of infile, checks the numpy magic string and then picks
+// three consecutive characters out of the dtype descriptor (e.g. '<i4'):
+//   byte_order - first character of the descriptor
+//   is_signed  - false only when the type character is 'u'; true for every other type character
+//   bit_length - the single character after the type character, minus '0'.
+//                NOTE(review): despite the name, the caller in this file switches on this value
+//                as a size in bytes (1/2/4) - confirm and consider renaming.
+//
+// On success the stream is rewound to its start and NO_ERROR is returned. HEADER_PARSE_ERROR is
+// returned (without rewinding) if the header cannot be read, the magic string does not match,
+// the 'descr' key is missing, or the descriptor is truncated.
+NumpyUtilities::NPError NumpyUtilities::getHeader(FILE* infile, bool& is_signed, int& bit_length, char& byte_order)
+{
+    assert(infile);
+
+    char buf[NUMPY_HEADER_SZ + 1];
+    if (fread(buf, NUMPY_HEADER_SZ, 1, infile) != 1)
+    {
+        return HEADER_PARSE_ERROR;
+    }
+
+    // Validate the numpy magic number
+    if (memcmp(buf, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1))
+    {
+        return HEADER_PARSE_ERROR;
+    }
+
+    const std::string dic_string(buf, NUMPY_HEADER_SZ);
+    const std::string desc_str("descr':");
+
+    size_t offset = dic_string.find(desc_str);
+    if (offset == std::string::npos)
+        return HEADER_PARSE_ERROR;
+
+    // Step past the key and the single character that follows it
+    offset += desc_str.size() + 1;
+    // Skip whitespace and the opening '
+    // (the unsigned char cast keeps std::isspace well-defined for bytes >= 0x80)
+    while (offset < dic_string.size() &&
+           (std::isspace(static_cast<unsigned char>(dic_string[offset])) || dic_string[offset] == '\''))
+        offset++;
+    // Three characters are consumed below (offset .. offset + 2); all of them must be in range
+    if (offset + 2 >= dic_string.size())
+        return HEADER_PARSE_ERROR;
+
+    byte_order = dic_string[offset];
+    is_signed  = dic_string[offset + 1] != 'u';
+    bit_length = static_cast<int>(dic_string[offset + 2]) - '0';
+
+    // Leave the stream positioned at the start for whoever reads the file next
+    rewind(infile);
+    return NO_ERROR;
+}
+
NumpyUtilities::NPError NumpyUtilities::checkNpyHeader(FILE* infile, const uint32_t elems, const char* dtype_str)
{
char buf[NUMPY_HEADER_SZ + 1];
@@ -253,60 +349,6 @@ NumpyUtilities::NPError NumpyUtilities::checkNpyHeader(FILE* infile, const uint3
return rc;
}
-NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const bool* databuf)
-{
- std::vector<int32_t> shape = { (int32_t)elems };
- return writeToNpyFile(filename, shape, databuf);
-}
-
-NumpyUtilities::NPError
- NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const bool* databuf)
-{
- const char dtype_str[] = "'|b1'";
- return writeToNpyFileCommon(filename, dtype_str, 1, shape, databuf, true); // bools written as size 1
-}
-
-NumpyUtilities::NPError
- NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int32_t* databuf)
-{
- std::vector<int32_t> shape = { (int32_t)elems };
- return writeToNpyFile(filename, shape, databuf);
-}
-
-NumpyUtilities::NPError
- NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int32_t* databuf)
-{
- const char dtype_str[] = "'<i4'";
- return writeToNpyFileCommon(filename, dtype_str, sizeof(int32_t), shape, databuf, false);
-}
-
-NumpyUtilities::NPError
- NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int64_t* databuf)
-{
- std::vector<int32_t> shape = { (int32_t)elems };
- return writeToNpyFile(filename, shape, databuf);
-}
-
-NumpyUtilities::NPError
- NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int64_t* databuf)
-{
- const char dtype_str[] = "'<i8'";
- return writeToNpyFileCommon(filename, dtype_str, sizeof(int64_t), shape, databuf, false);
-}
-
-NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const float* databuf)
-{
- std::vector<int32_t> shape = { (int32_t)elems };
- return writeToNpyFile(filename, shape, databuf);
-}
-
-NumpyUtilities::NPError
- NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const float* databuf)
-{
- const char dtype_str[] = "'<f4'";
- return writeToNpyFileCommon(filename, dtype_str, sizeof(float), shape, databuf, false);
-}
-
NumpyUtilities::NPError NumpyUtilities::writeToNpyFileCommon(const char* filename,
const char* dtype_str,
const size_t elementsize,
@@ -390,12 +432,11 @@ NumpyUtilities::NPError
// Output the format dictionary
// Hard-coded for I32 for now
- headerPos +=
- snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "'descr': %s, 'fortran_order': False, 'shape': (%d,",
- dtype_str, shape.empty() ? 1 : shape[0]);
+ headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos,
+ "'descr': %s, 'fortran_order': False, 'shape': (", dtype_str);
- // Remainder of shape array
- for (i = 1; i < shape.size(); i++)
+ // Add shape contents (if any - as this will be empty for rank 0)
+ for (i = 0; i < shape.size(); i++)
{
headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, " %d,", shape[i]);
}