aboutsummaryrefslogtreecommitdiff
path: root/src/numpy_utils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/numpy_utils.cpp')
-rw-r--r--src/numpy_utils.cpp16
1 files changed, 16 insertions, 0 deletions
diff --git a/src/numpy_utils.cpp b/src/numpy_utils.cpp
index 80c680f..c770d45 100644
--- a/src/numpy_utils.cpp
+++ b/src/numpy_utils.cpp
@@ -14,6 +14,7 @@
// limitations under the License.
#include "numpy_utils.h"
+#include "half.hpp"
// Magic NUMPY header
static const char NUMPY_HEADER_STR[] = "\x93NUMPY\x1\x0\x76\x0{";
@@ -45,6 +46,13 @@ NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, co
return readFromNpyFileCommon(filename, dtype_str, sizeof(float), elems, databuf, false);
}
+NumpyUtilities::NPError
+ NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, half_float::half* databuf)
+{
+ const char dtype_str[] = "'<f2'";
+ return readFromNpyFileCommon(filename, dtype_str, sizeof(half_float::half), elems, databuf, false);
+}
+
NumpyUtilities::NPError NumpyUtilities::readFromNpyFileCommon(const char* filename,
const char* dtype_str,
const size_t elementsize,
@@ -307,6 +315,14 @@ NumpyUtilities::NPError
return writeToNpyFileCommon(filename, dtype_str, sizeof(float), shape, databuf, false);
}
+NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename,
+ const std::vector<int32_t>& shape,
+ const half_float::half* databuf)
+{
+ const char dtype_str[] = "'<f2'";
+ return writeToNpyFileCommon(filename, dtype_str, sizeof(half_float::half), shape, databuf, false);
+}
+
NumpyUtilities::NPError NumpyUtilities::writeToNpyFileCommon(const char* filename,
const char* dtype_str,
const size_t elementsize,