diff options
Diffstat (limited to 'include/numpy_utils.h')
-rw-r--r-- | include/numpy_utils.h | 104 |
1 files changed, 84 insertions, 20 deletions
diff --git a/include/numpy_utils.h b/include/numpy_utils.h index c64bc17..60cf77e 100644 --- a/include/numpy_utils.h +++ b/include/numpy_utils.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2021, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -24,6 +24,8 @@ #include <cstring> #include <vector> +#include "half.hpp" + class NumpyUtilities { public: @@ -35,31 +37,89 @@ public: FILE_TYPE_MISMATCH, HEADER_PARSE_ERROR, BUFFER_SIZE_MISMATCH, + DATA_TYPE_NOT_SUPPORTED, }; - static NPError readFromNpyFile(const char* filename, const uint32_t elems, float* databuf); - - static NPError readFromNpyFile(const char* filename, const uint32_t elems, int32_t* databuf); - - static NPError readFromNpyFile(const char* filename, const uint32_t elems, int64_t* databuf); - - static NPError readFromNpyFile(const char* filename, const uint32_t elems, bool* databuf); - - static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const bool* databuf); - - static NPError writeToNpyFile(const char* filename, const uint32_t elems, const bool* databuf); - - static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int32_t* databuf); - - static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int32_t* databuf); + template <typename T> + static const char* getDTypeString(bool& is_bool) + { + is_bool = false; + if (std::is_same<T, bool>::value) + { + is_bool = true; + return "'|b1'"; + } + if (std::is_same<T, uint8_t>::value) + { + return "'|u1'"; + } + if (std::is_same<T, int8_t>::value) + { + return "'|i1'"; + } + if (std::is_same<T, uint16_t>::value) + { + return "'<u2'"; + } + if (std::is_same<T, int16_t>::value) + { + return "'<i2'"; + } + if (std::is_same<T, int32_t>::value) + { + return "'<i4'"; + } + if (std::is_same<T, int64_t>::value) + { + return "'<i8'"; + } + if (std::is_same<T, float>::value) + { + return "'<f4'"; + } + if (std::is_same<T, double>::value) + { + return "'<f8'"; + } + if (std::is_same<T, half_float::half>::value) + { + return "'<f2'"; + } + assert(false && "unsupported Dtype"); + }; - static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int64_t* databuf); + template <typename T> + static NPError writeToNpyFile(const char* filename, const uint32_t elems, const T* databuf) + { + std::vector<int32_t> shape = { static_cast<int32_t>(elems) }; + return writeToNpyFile(filename, shape, databuf); + } - static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int64_t* databuf); + template <typename T> + static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const T* databuf) + { + bool is_bool; + const char* dtype_str = getDTypeString<T>(is_bool); + return writeToNpyFileCommon(filename, dtype_str, sizeof(T), shape, databuf, is_bool); + } - static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const float* databuf); + template <typename T> + static NPError readFromNpyFile(const char* filename, const uint32_t elems, T* databuf) + { + bool is_bool; + const char* dtype_str = getDTypeString<T>(is_bool); + return readFromNpyFileCommon(filename, dtype_str, sizeof(T), elems, databuf, is_bool); + } - static NPError writeToNpyFile(const char* filename, const uint32_t elems, const float* databuf); + template <typename D, typename S> + static void copyBufferByElement(D* dest_buf, S* src_buf, int num) + { + static_assert(sizeof(D) >= sizeof(S), "The size of dest_buf must be equal to or larger than that of src_buf"); + for (int i = 0; i < num; ++i) + { + dest_buf[i] = src_buf[i]; + } + } private: static NPError writeToNpyFileCommon(const char* filename, @@ -75,7 +135,11 @@ private: void* databuf, bool bool_translate); static NPError checkNpyHeader(FILE* infile, const uint32_t elems, const char* dtype_str); + static NPError getHeader(FILE* infile, bool& is_signed, int& bit_length, char& byte_order); static NPError writeNpyHeader(FILE* outfile, const std::vector<int32_t>& shape, const char* dtype_str); }; +template <> +NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int32_t* databuf); + #endif // _TOSA_NUMPY_UTILS_H |