aboutsummaryrefslogtreecommitdiff
path: root/include/libnpy/npy.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'include/libnpy/npy.hpp')
-rw-r--r--include/libnpy/npy.hpp504
1 files changed, 272 insertions, 232 deletions
diff --git a/include/libnpy/npy.hpp b/include/libnpy/npy.hpp
index 24244ca272..4399426de1 100644
--- a/include/libnpy/npy.hpp
+++ b/include/libnpy/npy.hpp
@@ -20,8 +20,8 @@
SOFTWARE.
*/
-#ifndef NPY_H
-#define NPY_H
+#ifndef NPY_HPP_
+#define NPY_HPP_
#include <complex>
#include <fstream>
@@ -30,18 +30,23 @@
#include <sstream>
#include <cstdint>
#include <cstring>
+#include <array>
#include <vector>
#include <stdexcept>
#include <algorithm>
-#include <regex>
#include <unordered_map>
+#include <type_traits>
+#include <typeinfo>
+#include <typeindex>
+#include <iterator>
+#include <utility>
namespace npy {
/* Compile-time test for byte order.
If your compiler does not define these per default, you may want to define
- one of these constants manually.
+ one of these constants manually.
Defaults to little endian order. */
#if defined(__BYTE_ORDER) && __BYTE_ORDER == __BIG_ENDIAN || \
defined(__BIG_ENDIAN__) || \
@@ -62,94 +67,123 @@ const char little_endian_char = '<';
const char big_endian_char = '>';
const char no_endian_char = '|';
-constexpr char host_endian_char = ( big_endian ?
- big_endian_char :
- little_endian_char );
+constexpr std::array<char, 3>
+endian_chars = {little_endian_char, big_endian_char, no_endian_char};
+constexpr std::array<char, 4>
+numtype_chars = {'f', 'i', 'u', 'c'};
+
+constexpr char host_endian_char = (big_endian ?
+ big_endian_char :
+ little_endian_char);
/* npy array length */
typedef unsigned long int ndarray_len_t;
-inline void write_magic(std::ostream& ostream, unsigned char v_major=1, unsigned char v_minor=0) {
+typedef std::pair<char, char> version_t;
+
+struct dtype_t {
+ const char byteorder;
+ const char kind;
+ const unsigned int itemsize;
+
+// TODO(llohse): implement as constexpr
+ inline std::string str() const {
+ const size_t max_buflen = 16;
+ char buf[max_buflen];
+ std::snprintf(buf, max_buflen, "%c%c%u", byteorder, kind, itemsize);
+ return std::string(buf);
+ }
+
+ inline std::tuple<const char, const char, const unsigned int> tie() const {
+ return std::tie(byteorder, kind, itemsize);
+ }
+};
+
+
+struct header_t {
+ const dtype_t dtype;
+ const bool fortran_order;
+ const std::vector <ndarray_len_t> shape;
+};
+
+inline void write_magic(std::ostream &ostream, version_t version) {
ostream.write(magic_string, magic_string_length);
- ostream.put(v_major);
- ostream.put(v_minor);
+ ostream.put(version.first);
+ ostream.put(version.second);
}
-inline void read_magic(std::istream& istream, unsigned char& v_major, unsigned char& v_minor) {
- char buf[magic_string_length+2];
- istream.read(buf, magic_string_length+2);
+inline version_t read_magic(std::istream &istream) {
+ char buf[magic_string_length + 2];
+ istream.read(buf, magic_string_length + 2);
- if(!istream) {
+ if (!istream) {
throw std::runtime_error("io error: failed reading file");
}
if (0 != std::memcmp(buf, magic_string, magic_string_length))
throw std::runtime_error("this file does not have a valid npy format.");
- v_major = buf[magic_string_length];
- v_minor = buf[magic_string_length+1];
-}
+ version_t version;
+ version.first = buf[magic_string_length];
+ version.second = buf[magic_string_length + 1];
-// typestring magic
-struct Typestring {
- private:
- char c_endian;
- char c_type;
- int len;
-
- public:
- inline std::string str() {
- const size_t max_buflen = 16;
- char buf[max_buflen];
- std::sprintf(buf, "%c%c%u", c_endian, c_type, len);
- return std::string(buf);
- }
+ return version;
+}
- Typestring(const std::vector<float>& v)
- :c_endian {host_endian_char}, c_type {'f'}, len {sizeof(float)} {}
- Typestring(const std::vector<double>& v)
- :c_endian {host_endian_char}, c_type {'f'}, len {sizeof(double)} {}
- Typestring(const std::vector<long double>& v)
- :c_endian {host_endian_char}, c_type {'f'}, len {sizeof(long double)} {}
-
- Typestring(const std::vector<char>& v)
- :c_endian {no_endian_char}, c_type {'i'}, len {sizeof(char)} {}
- Typestring(const std::vector<short>& v)
- :c_endian {host_endian_char}, c_type {'i'}, len {sizeof(short)} {}
- Typestring(const std::vector<int>& v)
- :c_endian {host_endian_char}, c_type {'i'}, len {sizeof(int)} {}
- Typestring(const std::vector<long>& v)
- :c_endian {host_endian_char}, c_type {'i'}, len {sizeof(long)} {}
- Typestring(const std::vector<long long>& v) :c_endian {host_endian_char}, c_type {'i'}, len {sizeof(long long)} {}
-
- Typestring(const std::vector<unsigned char>& v)
- :c_endian {no_endian_char}, c_type {'u'}, len {sizeof(unsigned char)} {}
- Typestring(const std::vector<unsigned short>& v)
- :c_endian {host_endian_char}, c_type {'u'}, len {sizeof(unsigned short)} {}
- Typestring(const std::vector<unsigned int>& v)
- :c_endian {host_endian_char}, c_type {'u'}, len {sizeof(unsigned int)} {}
- Typestring(const std::vector<unsigned long>& v)
- :c_endian {host_endian_char}, c_type {'u'}, len {sizeof(unsigned long)} {}
- Typestring(const std::vector<unsigned long long>& v)
- :c_endian {host_endian_char}, c_type {'u'}, len {sizeof(unsigned long long)} {}
-
- Typestring(const std::vector<std::complex<float>>& v)
- :c_endian {host_endian_char}, c_type {'c'}, len {sizeof(std::complex<float>)} {}
- Typestring(const std::vector<std::complex<double>>& v)
- :c_endian {host_endian_char}, c_type {'c'}, len {sizeof(std::complex<double>)} {}
- Typestring(const std::vector<std::complex<long double>>& v)
- :c_endian {host_endian_char}, c_type {'c'}, len {sizeof(std::complex<long double>)} {}
+const std::unordered_map<std::type_index, dtype_t> dtype_map = {
+ {std::type_index(typeid(float)), {host_endian_char, 'f', sizeof(float)}},
+ {std::type_index(typeid(double)), {host_endian_char, 'f', sizeof(double)}},
+ {std::type_index(typeid(long double)), {host_endian_char, 'f', sizeof(long double)}},
+ {std::type_index(typeid(char)), {no_endian_char, 'i', sizeof(char)}},
+ {std::type_index(typeid(signed char)), {no_endian_char, 'i', sizeof(signed char)}},
+ {std::type_index(typeid(short)), {host_endian_char, 'i', sizeof(short)}},
+ {std::type_index(typeid(int)), {host_endian_char, 'i', sizeof(int)}},
+ {std::type_index(typeid(long)), {host_endian_char, 'i', sizeof(long)}},
+ {std::type_index(typeid(long long)), {host_endian_char, 'i', sizeof(long long)}},
+ {std::type_index(typeid(unsigned char)), {no_endian_char, 'u', sizeof(unsigned char)}},
+ {std::type_index(typeid(unsigned short)), {host_endian_char, 'u', sizeof(unsigned short)}},
+ {std::type_index(typeid(unsigned int)), {host_endian_char, 'u', sizeof(unsigned int)}},
+ {std::type_index(typeid(unsigned long)), {host_endian_char, 'u', sizeof(unsigned long)}},
+ {std::type_index(typeid(unsigned long long)), {host_endian_char, 'u', sizeof(unsigned long long)}},
+ {std::type_index(typeid(std::complex<float>)), {host_endian_char, 'c', sizeof(std::complex<float>)}},
+ {std::type_index(typeid(std::complex<double>)), {host_endian_char, 'c', sizeof(std::complex<double>)}},
+ {std::type_index(typeid(std::complex<long double>)), {host_endian_char, 'c', sizeof(std::complex<long double>)}}
};
-inline void parse_typestring( std::string typestring){
- std::regex re ("'([<>|])([ifuc])(\\d+)'");
- std::smatch sm;
- std::regex_match(typestring, sm, re );
+// helpers
+inline bool is_digits(const std::string &str) {
+ return std::all_of(str.begin(), str.end(), ::isdigit);
+}
+
+template<typename T, size_t N>
+inline bool in_array(T val, const std::array <T, N> &arr) {
+ return std::find(std::begin(arr), std::end(arr), val) != std::end(arr);
+}
+
+inline dtype_t parse_descr(std::string typestring) {
+ if (typestring.length() < 3) {
+ throw std::runtime_error("invalid typestring (length)");
+ }
+
+ char byteorder_c = typestring.at(0);
+ char kind_c = typestring.at(1);
+ std::string itemsize_s = typestring.substr(2);
+
+ if (!in_array(byteorder_c, endian_chars)) {
+ throw std::runtime_error("invalid typestring (byteorder)");
+ }
+
+ if (!in_array(kind_c, numtype_chars)) {
+ throw std::runtime_error("invalid typestring (kind)");
+ }
- if ( sm.size() != 4 ) {
- throw std::runtime_error("invalid typestring");
+ if (!is_digits(itemsize_s)) {
+ throw std::runtime_error("invalid typestring (itemsize)");
}
+ unsigned int itemsize = std::stoul(itemsize_s);
+
+ return {byteorder_c, kind_c, itemsize};
}
namespace pyparse {
@@ -157,7 +191,7 @@ namespace pyparse {
/**
Removes leading and trailing whitespaces
*/
-inline std::string trim(const std::string& str) {
+inline std::string trim(const std::string &str) {
const std::string whitespace = " \t";
auto begin = str.find_first_not_of(whitespace);
@@ -166,16 +200,16 @@ inline std::string trim(const std::string& str) {
auto end = str.find_last_not_of(whitespace);
- return str.substr(begin, end-begin+1);
+ return str.substr(begin, end - begin + 1);
}
-inline std::string get_value_from_map(const std::string& mapstr) {
+inline std::string get_value_from_map(const std::string &mapstr) {
size_t sep_pos = mapstr.find_first_of(":");
if (sep_pos == std::string::npos)
return "";
- std::string tmp = mapstr.substr(sep_pos+1);
+ std::string tmp = mapstr.substr(sep_pos + 1);
return trim(tmp);
}
@@ -184,9 +218,8 @@ inline std::string get_value_from_map(const std::string& mapstr) {
The keys need to be known and may not appear anywhere else in the data.
*/
-inline std::unordered_map<std::string, std::string> parse_dict(std::string in, std::vector<std::string>& keys) {
-
- std::unordered_map<std::string, std::string> map;
+inline std::unordered_map <std::string, std::string> parse_dict(std::string in, const std::vector <std::string> &keys) {
+ std::unordered_map <std::string, std::string> map;
if (keys.size() == 0)
return map;
@@ -195,36 +228,36 @@ inline std::unordered_map<std::string, std::string> parse_dict(std::string in, s
// unwrap dictionary
if ((in.front() == '{') && (in.back() == '}'))
- in = in.substr(1, in.length()-2);
+ in = in.substr(1, in.length() - 2);
else
throw std::runtime_error("Not a Python dictionary.");
- std::vector<std::pair<size_t, std::string>> positions;
+ std::vector <std::pair<size_t, std::string>> positions;
- for (auto const& value : keys) {
- size_t pos = in.find( "'" + value + "'" );
+ for (auto const &value : keys) {
+ size_t pos = in.find("'" + value + "'");
if (pos == std::string::npos)
- throw std::runtime_error("Missing '"+value+"' key.");
+ throw std::runtime_error("Missing '" + value + "' key.");
- std::pair<size_t, std::string> position_pair { pos, value };
+ std::pair <size_t, std::string> position_pair{pos, value};
positions.push_back(position_pair);
}
// sort by position in dict
- std::sort(positions.begin(), positions.end() );
+ std::sort(positions.begin(), positions.end());
- for(size_t i = 0; i < positions.size(); ++i) {
+ for (size_t i = 0; i < positions.size(); ++i) {
std::string raw_value;
- size_t begin { positions[i].first };
- size_t end { std::string::npos };
+ size_t begin{positions[i].first};
+ size_t end{std::string::npos};
std::string key = positions[i].second;
- if ( i+1 < positions.size() )
- end = positions[i+1].first;
+ if (i + 1 < positions.size())
+ end = positions[i + 1].first;
- raw_value = in.substr(begin, end-begin);
+ raw_value = in.substr(begin, end - begin);
raw_value = trim(raw_value);
@@ -240,7 +273,7 @@ inline std::unordered_map<std::string, std::string> parse_dict(std::string in, s
/**
Parses the string representation of a Python boolean
*/
-inline bool parse_bool(const std::string& in) {
+inline bool parse_bool(const std::string &in) {
if (in == "True")
return true;
if (in == "False")
@@ -252,9 +285,9 @@ inline bool parse_bool(const std::string& in) {
/**
Parses the string representation of a Python str
*/
-inline std::string parse_str(const std::string& in) {
+inline std::string parse_str(const std::string &in) {
if ((in.front() == '\'') && (in.back() == '\''))
- return in.substr(1, in.length()-2);
+ return in.substr(1, in.length() - 2);
throw std::runtime_error("Invalid python string.");
}
@@ -262,30 +295,30 @@ inline std::string parse_str(const std::string& in) {
/**
Parses the string represenatation of a Python tuple into a vector of its items
*/
-inline std::vector<std::string> parse_tuple(std::string in) {
- std::vector<std::string> v;
+inline std::vector <std::string> parse_tuple(std::string in) {
+ std::vector <std::string> v;
const char seperator = ',';
in = trim(in);
if ((in.front() == '(') && (in.back() == ')'))
- in = in.substr(1, in.length()-2);
+ in = in.substr(1, in.length() - 2);
else
throw std::runtime_error("Invalid Python tuple.");
std::istringstream iss(in);
for (std::string token; std::getline(iss, token, seperator);) {
- v.push_back(token);
+ v.push_back(token);
}
return v;
}
-template <typename T>
-inline std::string write_tuple(const std::vector<T>& v) {
+template<typename T>
+inline std::string write_tuple(const std::vector <T> &v) {
if (v.size() == 0)
- return "";
+ return "()";
std::ostringstream ss;
@@ -295,7 +328,7 @@ inline std::string write_tuple(const std::vector<T>& v) {
const std::string delimiter = ", ";
// v.size() > 1
ss << "(";
- std::copy(v.begin(), v.end()-1, std::ostream_iterator<T>(ss, delimiter.c_str()));
+ std::copy(v.begin(), v.end() - 1, std::ostream_iterator<T>(ss, delimiter.c_str()));
ss << v.back();
ss << ")";
}
@@ -304,16 +337,16 @@ inline std::string write_tuple(const std::vector<T>& v) {
}
inline std::string write_boolean(bool b) {
- if(b)
+ if (b)
return "True";
else
return "False";
}
-} // namespace pyparse
+} // namespace pyparse
-inline void parse_header(std::string header, std::string& descr, bool& fortran_order, std::vector<ndarray_len_t>& shape) {
+inline header_t parse_header(std::string header) {
/*
The first 6 bytes are a magic string: exactly "x93NUMPY".
The next 1 byte is an unsigned byte: the major version number of the file format, e.g. x01.
@@ -337,7 +370,7 @@ inline void parse_header(std::string header, std::string& descr, bool& fortran_o
header.pop_back();
// parse the dictionary
- std::vector<std::string> keys { "descr", "fortran_order", "shape" };
+ std::vector <std::string> keys{"descr", "fortran_order", "shape"};
auto dict_map = npy::pyparse::parse_dict(header, keys);
if (dict_map.size() == 0)
@@ -347,173 +380,180 @@ inline void parse_header(std::string header, std::string& descr, bool& fortran_o
std::string fortran_s = dict_map["fortran_order"];
std::string shape_s = dict_map["shape"];
- // TODO: extract info from typestring
- parse_typestring(descr_s);
- // remove
- descr = npy::pyparse::parse_str(descr_s);
+ std::string descr = npy::pyparse::parse_str(descr_s);
+ dtype_t dtype = parse_descr(descr);
// convert literal Python bool to C++ bool
- fortran_order = npy::pyparse::parse_bool(fortran_s);
+ bool fortran_order = npy::pyparse::parse_bool(fortran_s);
// parse the shape tuple
auto shape_v = npy::pyparse::parse_tuple(shape_s);
- if (shape_v.size() == 0)
- throw std::runtime_error("invalid shape tuple in header");
-
- for ( auto item : shape_v ) {
- std::stringstream stream(item);
- unsigned long value;
- stream >> value;
- ndarray_len_t dim = static_cast<ndarray_len_t>(value);
+
+ std::vector <ndarray_len_t> shape;
+ for (auto item : shape_v) {
+ ndarray_len_t dim = static_cast<ndarray_len_t>(std::stoul(item));
shape.push_back(dim);
}
+
+ return {dtype, fortran_order, shape};
}
-inline std::string write_header_dict(const std::string& descr, bool fortran_order, const std::vector<ndarray_len_t>& shape) {
- std::string s_fortran_order = npy::pyparse::write_boolean(fortran_order);
- std::string shape_s = npy::pyparse::write_tuple(shape);
+inline std::string
+write_header_dict(const std::string &descr, bool fortran_order, const std::vector <ndarray_len_t> &shape) {
+ std::string s_fortran_order = npy::pyparse::write_boolean(fortran_order);
+ std::string shape_s = npy::pyparse::write_tuple(shape);
- return "{'descr': '" + descr + "', 'fortran_order': " + s_fortran_order + ", 'shape': " + shape_s + ", }";
+ return "{'descr': '" + descr + "', 'fortran_order': " + s_fortran_order + ", 'shape': " + shape_s + ", }";
}
-inline void write_header(std::ostream& out, const std::string& descr, bool fortran_order, const std::vector<ndarray_len_t>& shape_v)
-{
- std::string header_dict = write_header_dict(descr, fortran_order, shape_v);
+inline void write_header(std::ostream &out, const header_t &header) {
+ std::string header_dict = write_header_dict(header.dtype.str(), header.fortran_order, header.shape);
- size_t length = magic_string_length + 2 + 2 + header_dict.length() + 1;
+ size_t length = magic_string_length + 2 + 2 + header_dict.length() + 1;
- unsigned char version[2] = {1, 0};
- if (length >= 255*255) {
- length = magic_string_length + 2 + 4 + header_dict.length() + 1;
- version[0] = 2;
- version[1] = 0;
- }
- size_t padding_len = 16 - length % 16;
- std::string padding (padding_len, ' ');
-
- // write magic
- write_magic(out, version[0], version[1]);
-
- // write header length
- if (version[0] == 1 && version[1] == 0) {
- char header_len_le16[2];
- uint16_t header_len = header_dict.length() + padding.length() + 1;
-
- header_len_le16[0] = (header_len >> 0) & 0xff;
- header_len_le16[1] = (header_len >> 8) & 0xff;
- out.write(reinterpret_cast<char *>(header_len_le16), 2);
- }else{
- char header_len_le32[4];
- uint32_t header_len = header_dict.length() + padding.length() + 1;
-
- header_len_le32[0] = (header_len >> 0) & 0xff;
- header_len_le32[1] = (header_len >> 8) & 0xff;
- header_len_le32[2] = (header_len >> 16) & 0xff;
- header_len_le32[3] = (header_len >> 24) & 0xff;
- out.write(reinterpret_cast<char *>(header_len_le32), 4);
- }
+ version_t version{1, 0};
+ if (length >= 255 * 255) {
+ length = magic_string_length + 2 + 4 + header_dict.length() + 1;
+ version = {2, 0};
+ }
+ size_t padding_len = 16 - length % 16;
+ std::string padding(padding_len, ' ');
+
+ // write magic
+ write_magic(out, version);
+
+ // write header length
+ if (version == version_t{1, 0}) {
+ uint8_t header_len_le16[2];
+ uint16_t header_len = static_cast<uint16_t>(header_dict.length() + padding.length() + 1);
+
+ header_len_le16[0] = (header_len >> 0) & 0xff;
+ header_len_le16[1] = (header_len >> 8) & 0xff;
+ out.write(reinterpret_cast<char *>(header_len_le16), 2);
+ } else {
+ uint8_t header_len_le32[4];
+ uint32_t header_len = static_cast<uint32_t>(header_dict.length() + padding.length() + 1);
+
+ header_len_le32[0] = (header_len >> 0) & 0xff;
+ header_len_le32[1] = (header_len >> 8) & 0xff;
+ header_len_le32[2] = (header_len >> 16) & 0xff;
+ header_len_le32[3] = (header_len >> 24) & 0xff;
+ out.write(reinterpret_cast<char *>(header_len_le32), 4);
+ }
- out << header_dict << padding << '\n';
+ out << header_dict << padding << '\n';
}
-inline std::string read_header(std::istream& istream) {
- // check magic bytes an version number
- unsigned char v_major, v_minor;
- read_magic(istream, v_major, v_minor);
-
- uint32_t header_length;
- if(v_major == 1 && v_minor == 0){
-
- char header_len_le16[2];
- istream.read(header_len_le16, 2);
- header_length = (header_len_le16[0] << 0) | (header_len_le16[1] << 8);
-
- if((magic_string_length + 2 + 2 + header_length) % 16 != 0) {
- // TODO: display warning
- }
- }else if(v_major == 2 && v_minor == 0) {
- char header_len_le32[4];
- istream.read(header_len_le32, 4);
-
- header_length = (header_len_le32[0] << 0) | (header_len_le32[1] << 8)
- | (header_len_le32[2] << 16) | (header_len_le32[3] << 24);
-
- if((magic_string_length + 2 + 4 + header_length) % 16 != 0) {
- // TODO: display warning
- }
- }else{
- throw std::runtime_error("unsupported file format version");
+inline std::string read_header(std::istream &istream) {
+ // check magic bytes an version number
+ version_t version = read_magic(istream);
+
+ uint32_t header_length;
+ if (version == version_t{1, 0}) {
+ uint8_t header_len_le16[2];
+ istream.read(reinterpret_cast<char *>(header_len_le16), 2);
+ header_length = (header_len_le16[0] << 0) | (header_len_le16[1] << 8);
+
+ if ((magic_string_length + 2 + 2 + header_length) % 16 != 0) {
+ // TODO(llohse): display warning
}
+ } else if (version == version_t{2, 0}) {
+ uint8_t header_len_le32[4];
+ istream.read(reinterpret_cast<char *>(header_len_le32), 4);
- auto buf_v = std::vector<char>();
- buf_v.reserve(header_length);
- istream.read(buf_v.data(), header_length);
- std::string header(buf_v.data(), header_length);
+ header_length = (header_len_le32[0] << 0) | (header_len_le32[1] << 8)
+ | (header_len_le32[2] << 16) | (header_len_le32[3] << 24);
- return header;
+ if ((magic_string_length + 2 + 4 + header_length) % 16 != 0) {
+ // TODO(llohse): display warning
+ }
+ } else {
+ throw std::runtime_error("unsupported file format version");
+ }
+
+ auto buf_v = std::vector<char>(header_length);
+ istream.read(buf_v.data(), header_length);
+ std::string header(buf_v.data(), header_length);
+
+ return header;
}
-inline ndarray_len_t comp_size(const std::vector<ndarray_len_t>& shape) {
- ndarray_len_t size = 1;
- for (ndarray_len_t i : shape )
- size *= i;
+inline ndarray_len_t comp_size(const std::vector <ndarray_len_t> &shape) {
+ ndarray_len_t size = 1;
+ for (ndarray_len_t i : shape)
+ size *= i;
- return size;
+ return size;
}
template<typename Scalar>
-inline void SaveArrayAsNumpy( const std::string& filename, bool fortran_order, unsigned int n_dims, const unsigned long shape[], const std::vector<Scalar>& data)
-{
- Typestring typestring_o(data);
- std::string typestring = typestring_o.str();
-
- std::ofstream stream( filename, std::ofstream::binary);
- if(!stream) {
- throw std::runtime_error("io error: failed to open a file.");
- }
+inline void
+SaveArrayAsNumpy(const std::string &filename, bool fortran_order, unsigned int n_dims, const unsigned long shape[],
+ const Scalar* data) {
+// static_assert(has_typestring<Scalar>::value, "scalar type not understood");
+ const dtype_t dtype = dtype_map.at(std::type_index(typeid(Scalar)));
+
+ std::ofstream stream(filename, std::ofstream::binary);
+ if (!stream) {
+ throw std::runtime_error("io error: failed to open a file.");
+ }
- std::vector<ndarray_len_t> shape_v(shape, shape+n_dims);
- write_header(stream, typestring, fortran_order, shape_v);
+ std::vector <ndarray_len_t> shape_v(shape, shape + n_dims);
+ header_t header{dtype, fortran_order, shape_v};
+ write_header(stream, header);
- auto size = static_cast<size_t>(comp_size(shape_v));
+ auto size = static_cast<size_t>(comp_size(shape_v));
- stream.write(reinterpret_cast<const char*>(data.data()), sizeof(Scalar) * size);
+ stream.write(reinterpret_cast<const char *>(data), sizeof(Scalar) * size);
}
+template<typename Scalar>
+inline void
+SaveArrayAsNumpy(const std::string &filename, bool fortran_order, unsigned int n_dims, const unsigned long shape[],
+ const std::vector <Scalar> &data) {
+ SaveArrayAsNumpy(filename, fortran_order, n_dims, shape, data.data());
+}
template<typename Scalar>
-inline void LoadArrayFromNumpy(const std::string& filename, std::vector<unsigned long>& shape, std::vector<Scalar>& data)
-{
- std::ifstream stream(filename, std::ifstream::binary);
- if(!stream) {
- throw std::runtime_error("io error: failed to open a file.");
- }
+inline void
+LoadArrayFromNumpy(const std::string &filename, std::vector<unsigned long> &shape, std::vector <Scalar> &data) {
+ bool fortran_order;
+ LoadArrayFromNumpy<Scalar>(filename, shape, fortran_order, data);
+}
+
+template<typename Scalar>
+inline void LoadArrayFromNumpy(const std::string &filename, std::vector<unsigned long> &shape, bool &fortran_order,
+ std::vector <Scalar> &data) {
+ std::ifstream stream(filename, std::ifstream::binary);
+ if (!stream) {
+ throw std::runtime_error("io error: failed to open a file.");
+ }
- std::string header = read_header(stream);
+ std::string header_s = read_header(stream);
- // parse header
- bool fortran_order;
- std::string typestr;
+ // parse header
+ header_t header = parse_header(header_s);
- parse_header(header, typestr, fortran_order, shape);
+ // check if the typestring matches the given one
+// static_assert(has_typestring<Scalar>::value, "scalar type not understood");
+ const dtype_t dtype = dtype_map.at(std::type_index(typeid(Scalar)));
- // check if the typestring matches the given one
- Typestring typestring_o {data};
- std::string expect_typestr = typestring_o.str();
- if (typestr != expect_typestr) {
- throw std::runtime_error("formatting error: typestrings not matching");
- }
+ if (header.dtype.tie() != dtype.tie()) {
+ throw std::runtime_error("formatting error: typestrings not matching");
+ }
+ shape = header.shape;
+ fortran_order = header.fortran_order;
- // compute the data size based on the shape
- auto size = static_cast<size_t>(comp_size(shape));
- data.resize(size);
+ // compute the data size based on the shape
+ auto size = static_cast<size_t>(comp_size(shape));
+ data.resize(size);
- // read the data
- stream.read(reinterpret_cast<char*>(data.data()), sizeof(Scalar)*size);
+ // read the data
+ stream.read(reinterpret_cast<char *>(data.data()), sizeof(Scalar) * size);
}
-} // namespace npy
+} // namespace npy
-#endif // NPY_H
+#endif // NPY_HPP_