diff options
-rw-r--r-- | include/numpy_utils.h | 122 | ||||
-rw-r--r-- | src/numpy_utils.cpp | 259 |
2 files changed, 106 insertions, 275 deletions
diff --git a/include/numpy_utils.h b/include/numpy_utils.h index e9c4bb4..83fbd5c 100644 --- a/include/numpy_utils.h +++ b/include/numpy_utils.h @@ -40,56 +40,86 @@ public: DATA_TYPE_NOT_SUPPORTED, }; - static NPError readFromNpyFile(const char* filename, const uint32_t elems, float* databuf); - - static NPError readFromNpyFile(const char* filename, const uint32_t elems, double* databuf); - - static NPError readFromNpyFile(const char* filename, const uint32_t elems, half_float::half* databuf); - - static NPError readFromNpyFile(const char* filename, const uint32_t elems, int32_t* databuf); - - static NPError readFromNpyFile(const char* filename, const uint32_t elems, int64_t* databuf); - - static NPError readFromNpyFile(const char* filename, const uint32_t elems, bool* databuf); - - static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const bool* databuf); - - static NPError writeToNpyFile(const char* filename, const uint32_t elems, const bool* databuf); - - static NPError - writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const half_float::half* databuf); - - static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const uint8_t* databuf); - - static NPError writeToNpyFile(const char* filename, const uint32_t elems, const uint8_t* databuf); - - static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int8_t* databuf); - - static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int8_t* databuf); - - static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const uint16_t* databuf); - - static NPError writeToNpyFile(const char* filename, const uint32_t elems, const uint16_t* databuf); - - static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int16_t* databuf); - - static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int16_t* databuf); - - static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int32_t* databuf); - - static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int32_t* databuf); - - static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int64_t* databuf); - - static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int64_t* databuf); + template <typename T> + static const char* getDTypeString(bool& is_bool) + { + is_bool = false; + if (std::is_same<T, bool>::value) + { + is_bool = true; + return "'|b1'"; + } + if (std::is_same<T, uint8_t>::value) + { + return "'|u1'"; + } + if (std::is_same<T, int8_t>::value) + { + return "'|i1'"; + } + if (std::is_same<T, uint16_t>::value) + { + return "'<u2'"; + } + if (std::is_same<T, int16_t>::value) + { + return "'<i2'"; + } + if (std::is_same<T, int32_t>::value) + { + return "'<i4'"; + } + if (std::is_same<T, int64_t>::value) + { + return "'<i8'"; + } + if (std::is_same<T, float>::value) + { + return "'<f4'"; + } + if (std::is_same<T, double>::value) + { + return "'<f8'"; + } + if (std::is_same<T, half_float::half>::value) + { + return "'<f2'"; + } + assert(false && "unsupported Dtype"); + }; - static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const float* databuf); + template <typename T> + static NPError writeToNpyFile(const char* filename, const uint32_t elems, const T* databuf) + { + std::vector<int32_t> shape = { static_cast<int32_t>(elems) }; + return writeToNpyFile(filename, shape, databuf); + } - static NPError writeToNpyFile(const char* filename, const uint32_t elems, const float* databuf); + template <typename T> + static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const T* databuf) + { + bool is_bool; + const char* dtype_str = getDTypeString<T>(is_bool); + return writeToNpyFileCommon(filename, dtype_str, sizeof(T), shape, databuf, is_bool); + } - static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const double* databuf); + template <typename T> + static NPError readFromNpyFile(const char* filename, const uint32_t elems, T* databuf) + { + bool is_bool; + const char* dtype_str = getDTypeString<T>(is_bool); + return readFromNpyFileCommon(filename, dtype_str, sizeof(T), elems, databuf, is_bool); + } - static NPError writeToNpyFile(const char* filename, const uint32_t elems, const double* databuf); + template <typename D, typename S> + static void copyBufferByElement(D* dest_buf, S* src_buf, int num) + { + static_assert(sizeof(D) >= sizeof(S)); + for (int i = 0; i < num; ++i) + { + dest_buf[i] = src_buf[i]; + } + } private: static NPError writeToNpyFileCommon(const char* filename, diff --git a/src/numpy_utils.cpp b/src/numpy_utils.cpp index 0002fd9..64460bd 100644 --- a/src/numpy_utils.cpp +++ b/src/numpy_utils.cpp @@ -16,6 +16,7 @@ #include "numpy_utils.h" #include "half.hpp" #include <algorithm> + // Magic NUMPY header static const char NUMPY_HEADER_STR[] = "\x93NUMPY\x1\x0\x76\x0{"; static const int NUMPY_HEADER_SZ = 128; @@ -24,20 +25,10 @@ static const int NUMPY_MAX_DIMS_SUPPORTED = 10; // Offset for NUMPY header desc dictionary string static const int NUMPY_HEADER_DESC_OFFSET = 8; -NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, bool* databuf) -{ - const char dtype_str[] = "'|b1'"; - return readFromNpyFileCommon(filename, dtype_str, 1, elems, databuf, true); -} - +// This is an entry function for reading 8-/16-/32-bit npy file. +template <> NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int32_t* databuf) { - const char dtype_str_uint8[] = "'|u1'"; - const char dtype_str_int8[] = "'|i1'"; - const char dtype_str_uint16[] = "'<u2'"; - const char dtype_str_int16[] = "'<i2'"; - const char dtype_str_int32[] = "'<i4'"; - FILE* infile = nullptr; NPError rc = HEADER_PARSE_ERROR; assert(filename); @@ -49,91 +40,58 @@ NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, co return FILE_NOT_FOUND; } - bool is_signed = false; - int bit_length; + bool is_signed = false; + int length_per_byte = 0; char byte_order; - rc = getHeader(infile, is_signed, bit_length, byte_order); + rc = getHeader(infile, is_signed, length_per_byte, byte_order); if (rc != NO_ERROR) return rc; - switch (bit_length) + switch (length_per_byte) { - case 1: // 8-bit + case 1: if (is_signed) { - // int8 - int8_t* i8databuf = nullptr; - i8databuf = (int8_t*)calloc(sizeof(i8databuf), elems); - - rc = readFromNpyFileCommon(filename, dtype_str_int8, sizeof(int8_t), elems, i8databuf, false); - - for (unsigned i = 0; i < elems; ++i) - { - databuf[i] = (int32_t)i8databuf[i]; - } - free(i8databuf); - - return rc; + int8_t* tmp_buf = new int8_t[elems]; + rc = readFromNpyFile<int8_t>(filename, elems, tmp_buf); + copyBufferByElement(databuf, tmp_buf, elems); + free(tmp_buf); } else { - // uint8 - uint8_t* ui8databuf = nullptr; - ui8databuf = (uint8_t*)calloc(sizeof(ui8databuf), elems); - - rc = readFromNpyFileCommon(filename, dtype_str_uint8, sizeof(uint8_t), elems, ui8databuf, false); - - for (unsigned i = 0; i < elems; ++i) - { - databuf[i] = (int32_t)ui8databuf[i]; - } - free(ui8databuf); + uint8_t* tmp_buf = new uint8_t[elems]; + rc = readFromNpyFile<uint8_t>(filename, elems, tmp_buf); + copyBufferByElement(databuf, tmp_buf, elems); + free(tmp_buf); } break; - case 2: // 16-bit + case 2: if (is_signed) { - // int16 - int16_t* i16databuf = nullptr; - i16databuf = (int16_t*)calloc(sizeof(i16databuf), elems); - - rc = readFromNpyFileCommon(filename, dtype_str_int16, sizeof(int16_t), elems, i16databuf, false); - - for (unsigned i = 0; i < elems; ++i) - { - databuf[i] = (int32_t)i16databuf[i]; - } - free(i16databuf); - - return rc; + int16_t* tmp_buf = new int16_t[elems]; + rc = readFromNpyFile<int16_t>(filename, elems, tmp_buf); + copyBufferByElement(databuf, tmp_buf, elems); + free(tmp_buf); } else { - // uint16 - uint16_t* ui16databuf = nullptr; - ui16databuf = (uint16_t*)calloc(sizeof(ui16databuf), elems); - - rc = readFromNpyFileCommon(filename, dtype_str_uint16, sizeof(uint16_t), elems, ui16databuf, false); - - for (unsigned i = 0; i < elems; ++i) - { - databuf[i] = (int32_t)ui16databuf[i]; - } - free(ui16databuf); - - return rc; + uint16_t* tmp_buf = new uint16_t[elems]; + rc = readFromNpyFile<uint16_t>(filename, elems, tmp_buf); + copyBufferByElement(databuf, tmp_buf, elems); + free(tmp_buf); } break; - case 4: // 32-bit + case 4: if (is_signed) { - // int32 - return readFromNpyFileCommon(filename, dtype_str_int32, sizeof(int32_t), elems, databuf, false); + bool is_bool; + const char* dtype_str = getDTypeString<int32_t>(is_bool); + rc = readFromNpyFileCommon(filename, dtype_str, sizeof(int32_t), elems, databuf, is_bool); } else { // uint32, not supported - return DATA_TYPE_NOT_SUPPORTED; + rc = DATA_TYPE_NOT_SUPPORTED; } break; default: @@ -144,31 +102,6 @@ NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, co return rc; } -NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int64_t* databuf) -{ - const char dtype_str[] = "'<i8'"; - return readFromNpyFileCommon(filename, dtype_str, sizeof(int64_t), elems, databuf, false); -} - -NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, float* databuf) -{ - const char dtype_str[] = "'<f4'"; - return readFromNpyFileCommon(filename, dtype_str, sizeof(float), elems, databuf, false); -} - -NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, double* databuf) -{ - const char dtype_str[] = "'<f8'"; - return readFromNpyFileCommon(filename, dtype_str, sizeof(double), elems, databuf, false); -} - -NumpyUtilities::NPError - NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, half_float::half* databuf) -{ - const char dtype_str[] = "'<f2'"; - return readFromNpyFileCommon(filename, dtype_str, sizeof(half_float::half), elems, databuf, false); -} - NumpyUtilities::NPError NumpyUtilities::readFromNpyFileCommon(const char* filename, const char* dtype_str, const size_t elementsize, @@ -418,138 +351,6 @@ NumpyUtilities::NPError NumpyUtilities::checkNpyHeader(FILE* infile, const uint3 return rc; } -NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const bool* databuf) -{ - std::vector<int32_t> shape = { (int32_t)elems }; - return writeToNpyFile(filename, shape, databuf); -} - -NumpyUtilities::NPError - NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const bool* databuf) -{ - const char dtype_str[] = "'|b1'"; - return writeToNpyFileCommon(filename, dtype_str, 1, shape, databuf, true); // bools written as size 1 -} - -NumpyUtilities::NPError - NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const uint8_t* databuf) -{ - std::vector<int32_t> shape = { (int32_t)elems }; - return writeToNpyFile(filename, shape, databuf); -} - -NumpyUtilities::NPError - NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const uint8_t* databuf) -{ - const char dtype_str[] = "'|u1'"; - return writeToNpyFileCommon(filename, dtype_str, sizeof(uint8_t), shape, databuf, false); -} - -NumpyUtilities::NPError - NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int8_t* databuf) -{ - std::vector<int32_t> shape = { (int32_t)elems }; - return writeToNpyFile(filename, shape, databuf); -} - -NumpyUtilities::NPError - NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int8_t* databuf) -{ - const char dtype_str[] = "'|i1'"; - return writeToNpyFileCommon(filename, dtype_str, sizeof(int8_t), shape, databuf, false); -} - -NumpyUtilities::NPError - NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const uint16_t* databuf) -{ - std::vector<int32_t> shape = { (int32_t)elems }; - return writeToNpyFile(filename, shape, databuf); -} - -NumpyUtilities::NPError - NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const uint16_t* databuf) -{ - const char dtype_str[] = "'<u2'"; - return writeToNpyFileCommon(filename, dtype_str, sizeof(uint16_t), shape, databuf, false); -} - -NumpyUtilities::NPError - NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int16_t* databuf) -{ - std::vector<int32_t> shape = { (int32_t)elems }; - return writeToNpyFile(filename, shape, databuf); -} - -NumpyUtilities::NPError - NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int16_t* databuf) -{ - const char dtype_str[] = "'<i2'"; - return writeToNpyFileCommon(filename, dtype_str, sizeof(int16_t), shape, databuf, false); -} - -NumpyUtilities::NPError - NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int32_t* databuf) -{ - std::vector<int32_t> shape = { (int32_t)elems }; - return writeToNpyFile(filename, shape, databuf); -} - -NumpyUtilities::NPError - NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int32_t* databuf) -{ - const char dtype_str[] = "'<i4'"; - return writeToNpyFileCommon(filename, dtype_str, sizeof(int32_t), shape, databuf, false); -} - -NumpyUtilities::NPError - NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int64_t* databuf) -{ - std::vector<int32_t> shape = { (int32_t)elems }; - return writeToNpyFile(filename, shape, databuf); -} - -NumpyUtilities::NPError - NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int64_t* databuf) -{ - const char dtype_str[] = "'<i8'"; - return writeToNpyFileCommon(filename, dtype_str, sizeof(int64_t), shape, databuf, false); -} - -NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const float* databuf) -{ - std::vector<int32_t> shape = { (int32_t)elems }; - return writeToNpyFile(filename, shape, databuf); -} - -NumpyUtilities::NPError - NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const float* databuf) -{ - const char dtype_str[] = "'<f4'"; - return writeToNpyFileCommon(filename, dtype_str, sizeof(float), shape, databuf, false); -} - -NumpyUtilities::NPError - NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const double* databuf) -{ - std::vector<int32_t> shape = { (int32_t)elems }; - return writeToNpyFile(filename, shape, databuf); -} - -NumpyUtilities::NPError - NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const double* databuf) -{ - const char dtype_str[] = "'<f8'"; - return writeToNpyFileCommon(filename, dtype_str, sizeof(double), shape, databuf, false); -} - -NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename, - const std::vector<int32_t>& shape, - const half_float::half* databuf) -{ - const char dtype_str[] = "'<f2'"; - return writeToNpyFileCommon(filename, dtype_str, sizeof(half_float::half), shape, databuf, false); -} - NumpyUtilities::NPError NumpyUtilities::writeToNpyFileCommon(const char* filename, const char* dtype_str, const size_t elementsize, |