diff options
Diffstat (limited to 'src/numpy_utils.cpp')
-rw-r--r-- | src/numpy_utils.cpp | 20 |
1 files changed, 20 insertions, 0 deletions
diff --git a/src/numpy_utils.cpp b/src/numpy_utils.cpp index c770d45..123908a 100644 --- a/src/numpy_utils.cpp +++ b/src/numpy_utils.cpp @@ -46,6 +46,12 @@ NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, co return readFromNpyFileCommon(filename, dtype_str, sizeof(float), elems, databuf, false); } +NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, double* databuf) +{ + const char dtype_str[] = "'<f8'"; + return readFromNpyFileCommon(filename, dtype_str, sizeof(double), elems, databuf, false); +} + NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, half_float::half* databuf) { @@ -315,6 +321,20 @@ NumpyUtilities::NPError return writeToNpyFileCommon(filename, dtype_str, sizeof(float), shape, databuf, false); } +NumpyUtilities::NPError + NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const double* databuf) +{ + std::vector<int32_t> shape = { (int32_t)elems }; + return writeToNpyFile(filename, shape, databuf); +} + +NumpyUtilities::NPError + NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const double* databuf) +{ + const char dtype_str[] = "'<f8'"; + return writeToNpyFileCommon(filename, dtype_str, sizeof(double), shape, databuf, false); +} + NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const half_float::half* databuf) |