aboutsummaryrefslogtreecommitdiff
path: root/include/numpy_utils.h
diff options
context:
space:
mode:
Diffstat (limited to 'include/numpy_utils.h')
-rw-r--r--include/numpy_utils.h122
1 files changed, 76 insertions, 46 deletions
diff --git a/include/numpy_utils.h b/include/numpy_utils.h
index e9c4bb4..83fbd5c 100644
--- a/include/numpy_utils.h
+++ b/include/numpy_utils.h
@@ -40,56 +40,86 @@ public:
DATA_TYPE_NOT_SUPPORTED,
};
- static NPError readFromNpyFile(const char* filename, const uint32_t elems, float* databuf);
-
- static NPError readFromNpyFile(const char* filename, const uint32_t elems, double* databuf);
-
- static NPError readFromNpyFile(const char* filename, const uint32_t elems, half_float::half* databuf);
-
- static NPError readFromNpyFile(const char* filename, const uint32_t elems, int32_t* databuf);
-
- static NPError readFromNpyFile(const char* filename, const uint32_t elems, int64_t* databuf);
-
- static NPError readFromNpyFile(const char* filename, const uint32_t elems, bool* databuf);
-
- static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const bool* databuf);
-
- static NPError writeToNpyFile(const char* filename, const uint32_t elems, const bool* databuf);
-
- static NPError
- writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const half_float::half* databuf);
-
- static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const uint8_t* databuf);
-
- static NPError writeToNpyFile(const char* filename, const uint32_t elems, const uint8_t* databuf);
-
- static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int8_t* databuf);
-
- static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int8_t* databuf);
-
- static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const uint16_t* databuf);
-
- static NPError writeToNpyFile(const char* filename, const uint32_t elems, const uint16_t* databuf);
-
- static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int16_t* databuf);
-
- static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int16_t* databuf);
-
- static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int32_t* databuf);
-
- static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int32_t* databuf);
-
- static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int64_t* databuf);
-
- static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int64_t* databuf);
+ template <typename T>
+ static const char* getDTypeString(bool& is_bool)
+ {
+ is_bool = false;
+ if (std::is_same<T, bool>::value)
+ {
+ is_bool = true;
+ return "'|b1'";
+ }
+ if (std::is_same<T, uint8_t>::value)
+ {
+ return "'|u1'";
+ }
+ if (std::is_same<T, int8_t>::value)
+ {
+ return "'|i1'";
+ }
+ if (std::is_same<T, uint16_t>::value)
+ {
+ return "'<u2'";
+ }
+ if (std::is_same<T, int16_t>::value)
+ {
+ return "'<i2'";
+ }
+ if (std::is_same<T, int32_t>::value)
+ {
+ return "'<i4'";
+ }
+ if (std::is_same<T, int64_t>::value)
+ {
+ return "'<i8'";
+ }
+ if (std::is_same<T, float>::value)
+ {
+ return "'<f4'";
+ }
+ if (std::is_same<T, double>::value)
+ {
+ return "'<f8'";
+ }
+ if (std::is_same<T, half_float::half>::value)
+ {
+ return "'<f2'";
+ }
+ assert(false && "unsupported Dtype");
+ };
- static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const float* databuf);
+ template <typename T>
+ static NPError writeToNpyFile(const char* filename, const uint32_t elems, const T* databuf)
+ {
+ std::vector<int32_t> shape = { static_cast<int32_t>(elems) };
+ return writeToNpyFile(filename, shape, databuf);
+ }
- static NPError writeToNpyFile(const char* filename, const uint32_t elems, const float* databuf);
+ template <typename T>
+ static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const T* databuf)
+ {
+ bool is_bool;
+ const char* dtype_str = getDTypeString<T>(is_bool);
+ return writeToNpyFileCommon(filename, dtype_str, sizeof(T), shape, databuf, is_bool);
+ }
- static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const double* databuf);
+ template <typename T>
+ static NPError readFromNpyFile(const char* filename, const uint32_t elems, T* databuf)
+ {
+ bool is_bool;
+ const char* dtype_str = getDTypeString<T>(is_bool);
+ return readFromNpyFileCommon(filename, dtype_str, sizeof(T), elems, databuf, is_bool);
+ }
- static NPError writeToNpyFile(const char* filename, const uint32_t elems, const double* databuf);
+ template <typename D, typename S>
+ static void copyBufferByElement(D* dest_buf, S* src_buf, int num)
+ {
+ static_assert(sizeof(D) >= sizeof(S));
+ for (int i = 0; i < num; ++i)
+ {
+ dest_buf[i] = src_buf[i];
+ }
+ }
private:
static NPError writeToNpyFileCommon(const char* filename,