From 679bdadb6b51b14013a00588cec2452d6ee1d1ac Mon Sep 17 00:00:00 2001 From: TatWai Chong Date: Mon, 31 Jul 2023 15:15:12 -0700 Subject: Simplify overloaded writeToNpyFiles and readFromNpyFiles templatize these functions instead to reduce redundant code. Signed-off-by: TatWai Chong Change-Id: Ie8b6f7d2b489c3508fea72481ce38f0db6d0c490 --- include/numpy_utils.h | 122 +++++++++++++++++++++++++++++++------------------- 1 file changed, 76 insertions(+), 46 deletions(-) (limited to 'include') diff --git a/include/numpy_utils.h b/include/numpy_utils.h index e9c4bb4..83fbd5c 100644 --- a/include/numpy_utils.h +++ b/include/numpy_utils.h @@ -40,56 +40,86 @@ public: DATA_TYPE_NOT_SUPPORTED, }; - static NPError readFromNpyFile(const char* filename, const uint32_t elems, float* databuf); - - static NPError readFromNpyFile(const char* filename, const uint32_t elems, double* databuf); - - static NPError readFromNpyFile(const char* filename, const uint32_t elems, half_float::half* databuf); - - static NPError readFromNpyFile(const char* filename, const uint32_t elems, int32_t* databuf); - - static NPError readFromNpyFile(const char* filename, const uint32_t elems, int64_t* databuf); - - static NPError readFromNpyFile(const char* filename, const uint32_t elems, bool* databuf); - - static NPError writeToNpyFile(const char* filename, const std::vector& shape, const bool* databuf); - - static NPError writeToNpyFile(const char* filename, const uint32_t elems, const bool* databuf); - - static NPError - writeToNpyFile(const char* filename, const std::vector& shape, const half_float::half* databuf); - - static NPError writeToNpyFile(const char* filename, const std::vector& shape, const uint8_t* databuf); - - static NPError writeToNpyFile(const char* filename, const uint32_t elems, const uint8_t* databuf); - - static NPError writeToNpyFile(const char* filename, const std::vector& shape, const int8_t* databuf); - - static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int8_t* databuf); - - static NPError writeToNpyFile(const char* filename, const std::vector& shape, const uint16_t* databuf); - - static NPError writeToNpyFile(const char* filename, const uint32_t elems, const uint16_t* databuf); - - static NPError writeToNpyFile(const char* filename, const std::vector& shape, const int16_t* databuf); - - static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int16_t* databuf); - - static NPError writeToNpyFile(const char* filename, const std::vector& shape, const int32_t* databuf); - - static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int32_t* databuf); - - static NPError writeToNpyFile(const char* filename, const std::vector& shape, const int64_t* databuf); - - static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int64_t* databuf); + template + static const char* getDTypeString(bool& is_bool) + { + is_bool = false; + if (std::is_same::value) + { + is_bool = true; + return "'|b1'"; + } + if (std::is_same::value) + { + return "'|u1'"; + } + if (std::is_same::value) + { + return "'|i1'"; + } + if (std::is_same::value) + { + return "'::value) + { + return "'::value) + { + return "'::value) + { + return "'::value) + { + return "'::value) + { + return "'::value) + { + return "'& shape, const float* databuf); + template + static NPError writeToNpyFile(const char* filename, const uint32_t elems, const T* databuf) + { + std::vector shape = { static_cast(elems) }; + return writeToNpyFile(filename, shape, databuf); + } - static NPError writeToNpyFile(const char* filename, const uint32_t elems, const float* databuf); + template + static NPError writeToNpyFile(const char* filename, const std::vector& shape, const T* databuf) + { + bool is_bool; + const char* dtype_str = getDTypeString(is_bool); + return writeToNpyFileCommon(filename, dtype_str, sizeof(T), shape, databuf, is_bool); + } - static NPError writeToNpyFile(const char* filename, const std::vector& shape, const double* databuf); + template + static NPError readFromNpyFile(const char* filename, const uint32_t elems, T* databuf) + { + bool is_bool; + const char* dtype_str = getDTypeString(is_bool); + return readFromNpyFileCommon(filename, dtype_str, sizeof(T), elems, databuf, is_bool); + } - static NPError writeToNpyFile(const char* filename, const uint32_t elems, const double* databuf); + template + static void copyBufferByElement(D* dest_buf, S* src_buf, int num) + { + static_assert(sizeof(D) >= sizeof(S)); + for (int i = 0; i < num; ++i) + { + dest_buf[i] = src_buf[i]; + } + } private: static NPError writeToNpyFileCommon(const char* filename, -- cgit v1.2.1