aboutsummaryrefslogtreecommitdiff
path: root/include
diff options
context:
space:
mode:
authorTatWai Chong <tatwai.chong@arm.com>2023-07-31 15:15:12 -0700
committerEric Kunze <eric.kunze@arm.com>2023-08-02 16:59:13 +0000
commit679bdadb6b51b14013a00588cec2452d6ee1d1ac (patch)
tree34f71ccf0db10fbea332ae1f5972eabb036422b6 /include
parente2b20e4eb9e91d9fd4a155880f3bf6085b8ffaac (diff)
downloadserialization_lib-679bdadb6b51b14013a00588cec2452d6ee1d1ac.tar.gz
Simplify overloaded writeToNpyFiles and readFromNpyFiles
templatize these functions instead to reduce redundant code. Signed-off-by: TatWai Chong <tatwai.chong@arm.com> Change-Id: Ie8b6f7d2b489c3508fea72481ce38f0db6d0c490
Diffstat (limited to 'include')
-rw-r--r--include/numpy_utils.h122
1 files changed, 76 insertions, 46 deletions
diff --git a/include/numpy_utils.h b/include/numpy_utils.h
index e9c4bb4..83fbd5c 100644
--- a/include/numpy_utils.h
+++ b/include/numpy_utils.h
@@ -40,56 +40,86 @@ public:
DATA_TYPE_NOT_SUPPORTED,
};
- static NPError readFromNpyFile(const char* filename, const uint32_t elems, float* databuf);
-
- static NPError readFromNpyFile(const char* filename, const uint32_t elems, double* databuf);
-
- static NPError readFromNpyFile(const char* filename, const uint32_t elems, half_float::half* databuf);
-
- static NPError readFromNpyFile(const char* filename, const uint32_t elems, int32_t* databuf);
-
- static NPError readFromNpyFile(const char* filename, const uint32_t elems, int64_t* databuf);
-
- static NPError readFromNpyFile(const char* filename, const uint32_t elems, bool* databuf);
-
- static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const bool* databuf);
-
- static NPError writeToNpyFile(const char* filename, const uint32_t elems, const bool* databuf);
-
- static NPError
- writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const half_float::half* databuf);
-
- static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const uint8_t* databuf);
-
- static NPError writeToNpyFile(const char* filename, const uint32_t elems, const uint8_t* databuf);
-
- static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int8_t* databuf);
-
- static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int8_t* databuf);
-
- static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const uint16_t* databuf);
-
- static NPError writeToNpyFile(const char* filename, const uint32_t elems, const uint16_t* databuf);
-
- static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int16_t* databuf);
-
- static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int16_t* databuf);
-
- static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int32_t* databuf);
-
- static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int32_t* databuf);
-
- static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int64_t* databuf);
-
- static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int64_t* databuf);
+ template <typename T>
+ static const char* getDTypeString(bool& is_bool)
+ {
+ is_bool = false;
+ if (std::is_same<T, bool>::value)
+ {
+ is_bool = true;
+ return "'|b1'";
+ }
+ if (std::is_same<T, uint8_t>::value)
+ {
+ return "'|u1'";
+ }
+ if (std::is_same<T, int8_t>::value)
+ {
+ return "'|i1'";
+ }
+ if (std::is_same<T, uint16_t>::value)
+ {
+ return "'<u2'";
+ }
+ if (std::is_same<T, int16_t>::value)
+ {
+ return "'<i2'";
+ }
+ if (std::is_same<T, int32_t>::value)
+ {
+ return "'<i4'";
+ }
+ if (std::is_same<T, int64_t>::value)
+ {
+ return "'<i8'";
+ }
+ if (std::is_same<T, float>::value)
+ {
+ return "'<f4'";
+ }
+ if (std::is_same<T, double>::value)
+ {
+ return "'<f8'";
+ }
+ if (std::is_same<T, half_float::half>::value)
+ {
+ return "'<f2'";
+ }
+ assert(false && "unsupported Dtype");
+ };
- static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const float* databuf);
+ template <typename T>
+ static NPError writeToNpyFile(const char* filename, const uint32_t elems, const T* databuf)
+ {
+ std::vector<int32_t> shape = { static_cast<int32_t>(elems) };
+ return writeToNpyFile(filename, shape, databuf);
+ }
- static NPError writeToNpyFile(const char* filename, const uint32_t elems, const float* databuf);
+ template <typename T>
+ static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const T* databuf)
+ {
+ bool is_bool;
+ const char* dtype_str = getDTypeString<T>(is_bool);
+ return writeToNpyFileCommon(filename, dtype_str, sizeof(T), shape, databuf, is_bool);
+ }
- static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const double* databuf);
+ template <typename T>
+ static NPError readFromNpyFile(const char* filename, const uint32_t elems, T* databuf)
+ {
+ bool is_bool;
+ const char* dtype_str = getDTypeString<T>(is_bool);
+ return readFromNpyFileCommon(filename, dtype_str, sizeof(T), elems, databuf, is_bool);
+ }
- static NPError writeToNpyFile(const char* filename, const uint32_t elems, const double* databuf);
+ template <typename D, typename S>
+ static void copyBufferByElement(D* dest_buf, S* src_buf, int num)
+ {
+ static_assert(sizeof(D) >= sizeof(S));
+ for (int i = 0; i < num; ++i)
+ {
+ dest_buf[i] = src_buf[i];
+ }
+ }
private:
static NPError writeToNpyFileCommon(const char* filename,