aboutsummaryrefslogtreecommitdiff
path: root/include
diff options
context:
space:
mode:
authorAnthony Barbier <anthony.barbier@arm.com>2017-11-10 16:27:32 +0000
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:35:24 +0000
commit87f21cd39639118293b51fa776d30aa7722917bd (patch)
tree9a29853b9dd1d8b4dd289334bcab14fc1bc09725 /include
parent5edbd1c5dce43b66f30c903797a91e39369c5b62 (diff)
downloadComputeLibrary-87f21cd39639118293b51fa776d30aa7722917bd.tar.gz
COMPMID-556 Updated libnpy.hpp
Change-Id: I380a11f41ca2158de1dd0a6339ed9c884feb8f69 Reviewed-on: http://mpd-gerrit.cambridge.arm.com/95385 Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com> Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
Diffstat (limited to 'include')
-rw-r--r--include/libnpy/npy.hpp506
1 files changed, 293 insertions, 213 deletions
diff --git a/include/libnpy/npy.hpp b/include/libnpy/npy.hpp
index 9b6f7fb7ba..24244ca272 100644
--- a/include/libnpy/npy.hpp
+++ b/include/libnpy/npy.hpp
@@ -20,16 +20,21 @@
SOFTWARE.
*/
+#ifndef NPY_H
+#define NPY_H
+
#include <complex>
#include <fstream>
#include <string>
#include <iostream>
#include <sstream>
#include <cstdint>
+#include <cstring>
#include <vector>
#include <stdexcept>
#include <algorithm>
#include <regex>
+#include <unordered_map>
namespace npy {
@@ -61,29 +66,28 @@ constexpr char host_endian_char = ( big_endian ?
big_endian_char :
little_endian_char );
+/* npy array length */
+typedef unsigned long int ndarray_len_t;
+
inline void write_magic(std::ostream& ostream, unsigned char v_major=1, unsigned char v_minor=0) {
ostream.write(magic_string, magic_string_length);
ostream.put(v_major);
ostream.put(v_minor);
}
-inline void read_magic(std::istream& istream, unsigned char *v_major, unsigned char *v_minor) {
- char *buf = new char[magic_string_length+2];
+inline void read_magic(std::istream& istream, unsigned char& v_major, unsigned char& v_minor) {
+ char buf[magic_string_length+2];
istream.read(buf, magic_string_length+2);
if(!istream) {
- throw std::runtime_error("io error: failed reading file");
+ throw std::runtime_error("io error: failed reading file");
}
- for (size_t i=0; i < magic_string_length; i++) {
- if(buf[i] != magic_string[i]) {
- throw std::runtime_error("this file do not have a valid npy format.");
- }
- }
+ if (0 != std::memcmp(buf, magic_string, magic_string_length))
+ throw std::runtime_error("this file does not have a valid npy format.");
- *v_major = buf[magic_string_length];
- *v_minor = buf[magic_string_length+1];
- delete[] buf;
+ v_major = buf[magic_string_length];
+ v_minor = buf[magic_string_length+1];
}
// typestring magic
@@ -101,25 +105,40 @@ struct Typestring {
return std::string(buf);
}
- Typestring(std::vector<float>& v) :c_endian {host_endian_char}, c_type {'f'}, len {sizeof(float)} {}
- Typestring(std::vector<double>& v) :c_endian {host_endian_char}, c_type {'f'}, len {sizeof(double)} {}
- Typestring(std::vector<long double>& v) :c_endian {host_endian_char}, c_type {'f'}, len {sizeof(long double)} {}
-
- Typestring(std::vector<char>& v) :c_endian {no_endian_char}, c_type {'i'}, len {sizeof(char)} {}
- Typestring(std::vector<short>& v) :c_endian {host_endian_char}, c_type {'i'}, len {sizeof(short)} {}
- Typestring(std::vector<int>& v) :c_endian {host_endian_char}, c_type {'i'}, len {sizeof(int)} {}
- Typestring(std::vector<long>& v) :c_endian {host_endian_char}, c_type {'i'}, len {sizeof(long)} {}
- Typestring(std::vector<long long>& v) :c_endian {host_endian_char}, c_type {'i'}, len {sizeof(long long)} {}
-
- Typestring(std::vector<unsigned char>& v) :c_endian {no_endian_char}, c_type {'u'}, len {sizeof(unsigned char)} {}
- Typestring(std::vector<unsigned short>& v) :c_endian {host_endian_char}, c_type {'u'}, len {sizeof(unsigned short)} {}
- Typestring(std::vector<unsigned int>& v) :c_endian {host_endian_char}, c_type {'u'}, len {sizeof(unsigned int)} {}
- Typestring(std::vector<unsigned long>& v) :c_endian {host_endian_char}, c_type {'u'}, len {sizeof(unsigned long)} {}
- Typestring(std::vector<unsigned long long>& v) :c_endian {host_endian_char}, c_type {'u'}, len {sizeof(unsigned long long)} {}
-
- Typestring(std::vector<std::complex<float>>& v) :c_endian {host_endian_char}, c_type {'c'}, len {sizeof(std::complex<float>)} {}
- Typestring(std::vector<std::complex<double>>& v) :c_endian {host_endian_char}, c_type {'c'}, len {sizeof(std::complex<double>)} {}
- Typestring(std::vector<std::complex<long double>>& v) :c_endian {host_endian_char}, c_type {'c'}, len {sizeof(std::complex<long double>)} {}
+ Typestring(const std::vector<float>& v)
+ :c_endian {host_endian_char}, c_type {'f'}, len {sizeof(float)} {}
+ Typestring(const std::vector<double>& v)
+ :c_endian {host_endian_char}, c_type {'f'}, len {sizeof(double)} {}
+ Typestring(const std::vector<long double>& v)
+ :c_endian {host_endian_char}, c_type {'f'}, len {sizeof(long double)} {}
+
+ Typestring(const std::vector<char>& v)
+ :c_endian {no_endian_char}, c_type {'i'}, len {sizeof(char)} {}
+ Typestring(const std::vector<short>& v)
+ :c_endian {host_endian_char}, c_type {'i'}, len {sizeof(short)} {}
+ Typestring(const std::vector<int>& v)
+ :c_endian {host_endian_char}, c_type {'i'}, len {sizeof(int)} {}
+ Typestring(const std::vector<long>& v)
+ :c_endian {host_endian_char}, c_type {'i'}, len {sizeof(long)} {}
+ Typestring(const std::vector<long long>& v) :c_endian {host_endian_char}, c_type {'i'}, len {sizeof(long long)} {}
+
+ Typestring(const std::vector<unsigned char>& v)
+ :c_endian {no_endian_char}, c_type {'u'}, len {sizeof(unsigned char)} {}
+ Typestring(const std::vector<unsigned short>& v)
+ :c_endian {host_endian_char}, c_type {'u'}, len {sizeof(unsigned short)} {}
+ Typestring(const std::vector<unsigned int>& v)
+ :c_endian {host_endian_char}, c_type {'u'}, len {sizeof(unsigned int)} {}
+ Typestring(const std::vector<unsigned long>& v)
+ :c_endian {host_endian_char}, c_type {'u'}, len {sizeof(unsigned long)} {}
+ Typestring(const std::vector<unsigned long long>& v)
+ :c_endian {host_endian_char}, c_type {'u'}, len {sizeof(unsigned long long)} {}
+
+ Typestring(const std::vector<std::complex<float>>& v)
+ :c_endian {host_endian_char}, c_type {'c'}, len {sizeof(std::complex<float>)} {}
+ Typestring(const std::vector<std::complex<double>>& v)
+ :c_endian {host_endian_char}, c_type {'c'}, len {sizeof(std::complex<double>)} {}
+ Typestring(const std::vector<std::complex<long double>>& v)
+ :c_endian {host_endian_char}, c_type {'c'}, len {sizeof(std::complex<long double>)} {}
};
inline void parse_typestring( std::string typestring){
@@ -133,39 +152,174 @@ inline void parse_typestring( std::string typestring){
}
}
-/* Helpers for the improvised parser */
-inline std::string unwrap_s(std::string s, char delim_front, char delim_back) {
- if ((s.back() == delim_back) && (s.front() == delim_front))
- return s.substr(1, s.length()-2);
- else
- throw std::runtime_error("unable to unwrap");
+namespace pyparse {
+
+/**
+ Removes leading and trailing whitespaces
+ */
+inline std::string trim(const std::string& str) {
+ const std::string whitespace = " \t";
+ auto begin = str.find_first_not_of(whitespace);
+
+ if (begin == std::string::npos)
+ return "";
+
+ auto end = str.find_last_not_of(whitespace);
+
+ return str.substr(begin, end-begin+1);
}
-inline std::string get_value_from_map(std::string mapstr) {
+
+inline std::string get_value_from_map(const std::string& mapstr) {
size_t sep_pos = mapstr.find_first_of(":");
if (sep_pos == std::string::npos)
return "";
- return mapstr.substr(sep_pos+1);
+ std::string tmp = mapstr.substr(sep_pos+1);
+ return trim(tmp);
+}
+
+/**
+ Parses the string representation of a Python dict
+
+ The keys need to be known and may not appear anywhere else in the data.
+ */
+inline std::unordered_map<std::string, std::string> parse_dict(std::string in, std::vector<std::string>& keys) {
+
+ std::unordered_map<std::string, std::string> map;
+
+ if (keys.size() == 0)
+ return map;
+
+ in = trim(in);
+
+ // unwrap dictionary
+ if ((in.front() == '{') && (in.back() == '}'))
+ in = in.substr(1, in.length()-2);
+ else
+ throw std::runtime_error("Not a Python dictionary.");
+
+ std::vector<std::pair<size_t, std::string>> positions;
+
+ for (auto const& value : keys) {
+ size_t pos = in.find( "'" + value + "'" );
+
+ if (pos == std::string::npos)
+ throw std::runtime_error("Missing '"+value+"' key.");
+
+ std::pair<size_t, std::string> position_pair { pos, value };
+ positions.push_back(position_pair);
+ }
+
+ // sort by position in dict
+ std::sort(positions.begin(), positions.end() );
+
+ for(size_t i = 0; i < positions.size(); ++i) {
+ std::string raw_value;
+ size_t begin { positions[i].first };
+ size_t end { std::string::npos };
+
+ std::string key = positions[i].second;
+
+ if ( i+1 < positions.size() )
+ end = positions[i+1].first;
+
+ raw_value = in.substr(begin, end-begin);
+
+ raw_value = trim(raw_value);
+
+ if (raw_value.back() == ',')
+ raw_value.pop_back();
+
+ map[key] = get_value_from_map(raw_value);
+ }
+
+ return map;
}
-inline void pop_char(std::string& s, char c) {
- if (s.back() == c)
- s.pop_back();
+/**
+ Parses the string representation of a Python boolean
+ */
+inline bool parse_bool(const std::string& in) {
+ if (in == "True")
+ return true;
+ if (in == "False")
+ return false;
+
+ throw std::runtime_error("Invalid python boolan.");
}
-inline void ParseHeader(std::string header, std::string& descr, bool *fortran_order, std::vector<unsigned long>& shape) {
+/**
+ Parses the string representation of a Python str
+ */
+inline std::string parse_str(const std::string& in) {
+ if ((in.front() == '\'') && (in.back() == '\''))
+ return in.substr(1, in.length()-2);
+
+ throw std::runtime_error("Invalid python string.");
+}
+
+/**
+ Parses the string represenatation of a Python tuple into a vector of its items
+ */
+inline std::vector<std::string> parse_tuple(std::string in) {
+ std::vector<std::string> v;
+ const char seperator = ',';
+
+ in = trim(in);
+
+ if ((in.front() == '(') && (in.back() == ')'))
+ in = in.substr(1, in.length()-2);
+ else
+ throw std::runtime_error("Invalid Python tuple.");
+
+ std::istringstream iss(in);
+
+ for (std::string token; std::getline(iss, token, seperator);) {
+ v.push_back(token);
+ }
+
+ return v;
+}
+
+template <typename T>
+inline std::string write_tuple(const std::vector<T>& v) {
+ if (v.size() == 0)
+ return "";
+
+ std::ostringstream ss;
+
+ if (v.size() == 1) {
+ ss << "(" << v.front() << ",)";
+ } else {
+ const std::string delimiter = ", ";
+ // v.size() > 1
+ ss << "(";
+ std::copy(v.begin(), v.end()-1, std::ostream_iterator<T>(ss, delimiter.c_str()));
+ ss << v.back();
+ ss << ")";
+ }
+
+ return ss.str();
+}
+
+inline std::string write_boolean(bool b) {
+ if(b)
+ return "True";
+ else
+ return "False";
+}
+
+} // namespace pyparse
+
+
+inline void parse_header(std::string header, std::string& descr, bool& fortran_order, std::vector<ndarray_len_t>& shape) {
/*
The first 6 bytes are a magic string: exactly "x93NUMPY".
-
The next 1 byte is an unsigned byte: the major version number of the file format, e.g. x01.
-
The next 1 byte is an unsigned byte: the minor version number of the file format, e.g. x00. Note: the version of the file format is not tied to the version of the numpy package.
-
The next 2 bytes form a little-endian unsigned short int: the length of the header data HEADER_LEN.
-
The next HEADER_LEN bytes form the header data describing the array's format. It is an ASCII string which contains a Python literal expression of a dictionary. It is terminated by a newline ('n') and padded with spaces ('x20') to make the total length of the magic string + 4 + HEADER_LEN be evenly divisible by 16 for alignment purposes.
-
The dictionary contains three keys:
"descr" : dtype.descr
@@ -182,128 +336,61 @@ inline void ParseHeader(std::string header, std::string& descr, bool *fortran_or
throw std::runtime_error("invalid header");
header.pop_back();
- // remove all whitespaces
- header.erase(std::remove(header.begin(), header.end(), ' '), header.end());
+ // parse the dictionary
+ std::vector<std::string> keys { "descr", "fortran_order", "shape" };
+ auto dict_map = npy::pyparse::parse_dict(header, keys);
- // unwrap dictionary
- header = unwrap_s(header, '{', '}');
-
- // find the positions of the 3 dictionary keys
- size_t keypos_descr = header.find("'descr'");
- size_t keypos_fortran = header.find("'fortran_order'");
- size_t keypos_shape = header.find("'shape'");
-
- // make sure all the keys are present
- if (keypos_descr == std::string::npos)
- throw std::runtime_error("missing 'descr' key");
- if (keypos_fortran == std::string::npos)
- throw std::runtime_error("missing 'fortran_order' key");
- if (keypos_shape == std::string::npos)
- throw std::runtime_error("missing 'shape' key");
-
- // Make sure the keys are in order.
- // Note that this violates the standard, which states that readers *must* not
- // depend on the correct order here.
- // TODO: fix
- if (keypos_descr >= keypos_fortran || keypos_fortran >= keypos_shape)
- throw std::runtime_error("header keys in wrong order");
-
- // get the 3 key-value pairs
- std::string keyvalue_descr;
- keyvalue_descr = header.substr(keypos_descr, keypos_fortran - keypos_descr);
- pop_char(keyvalue_descr, ',');
-
- std::string keyvalue_fortran;
- keyvalue_fortran = header.substr(keypos_fortran, keypos_shape - keypos_fortran);
- pop_char(keyvalue_fortran, ',');
-
- std::string keyvalue_shape;
- keyvalue_shape = header.substr(keypos_shape, std::string::npos);
- pop_char(keyvalue_shape, ',');
-
- // get the values (right side of `:')
- std::string descr_s = get_value_from_map(keyvalue_descr);
- std::string fortran_s = get_value_from_map(keyvalue_fortran);
- std::string shape_s = get_value_from_map(keyvalue_shape);
+ if (dict_map.size() == 0)
+ throw std::runtime_error("invalid dictionary in header");
+
+ std::string descr_s = dict_map["descr"];
+ std::string fortran_s = dict_map["fortran_order"];
+ std::string shape_s = dict_map["shape"];
+ // TODO: extract info from typestring
parse_typestring(descr_s);
- descr = unwrap_s(descr_s, '\'', '\'');
+ // remove
+ descr = npy::pyparse::parse_str(descr_s);
// convert literal Python bool to C++ bool
- if (fortran_s == "True")
- *fortran_order = true;
- else if (fortran_s == "False")
- *fortran_order = false;
- else
- throw std::runtime_error("invalid fortran_order value");
-
- // parse the shape Python tuple ( x, y, z,)
-
- // first clear the vector
- shape.clear();
- shape_s = unwrap_s(shape_s, '(', ')');
-
- // a tokenizer would be nice...
- size_t pos = 0;
- size_t pos_next;
- for(;;) {
- pos_next = shape_s.find_first_of(',', pos);
- std::string dim_s;
- if (pos_next != std::string::npos)
- dim_s = shape_s.substr(pos, pos_next - pos);
- else
- dim_s = shape_s.substr(pos);
- pop_char(dim_s, ',');
- if (dim_s.length() == 0) {
- if (pos_next != std::string::npos)
- throw std::runtime_error("invalid shape");
- }else{
- std::stringstream ss;
- ss << dim_s;
- unsigned long tmp;
- ss >> tmp;
- shape.push_back(tmp);
- }
- if (pos_next != std::string::npos)
- pos = ++pos_next;
- else
- break;
+ fortran_order = npy::pyparse::parse_bool(fortran_s);
+
+ // parse the shape tuple
+ auto shape_v = npy::pyparse::parse_tuple(shape_s);
+ if (shape_v.size() == 0)
+ throw std::runtime_error("invalid shape tuple in header");
+
+ for ( auto item : shape_v ) {
+ std::stringstream stream(item);
+ unsigned long value;
+ stream >> value;
+ ndarray_len_t dim = static_cast<ndarray_len_t>(value);
+ shape.push_back(dim);
}
}
-inline void WriteHeader(std::ostream& out, const std::string& descr, bool fortran_order, unsigned int n_dims, const unsigned long shape[])
-{
- std::ostringstream ss_header;
- std::string s_fortran_order;
- if (fortran_order)
- s_fortran_order = "True";
- else
- s_fortran_order = "False";
-
- std::ostringstream ss_shape;
- ss_shape << "(";
- for (unsigned int n=0; n < n_dims; n++){
- ss_shape << shape[n] << ", ";
- }
- ss_shape << ")";
- ss_header << "{'descr': '" << descr << "', 'fortran_order': " << s_fortran_order << ", 'shape': " << ss_shape.str() << " }";
+inline std::string write_header_dict(const std::string& descr, bool fortran_order, const std::vector<ndarray_len_t>& shape) {
+ std::string s_fortran_order = npy::pyparse::write_boolean(fortran_order);
+ std::string shape_s = npy::pyparse::write_tuple(shape);
+
+ return "{'descr': '" + descr + "', 'fortran_order': " + s_fortran_order + ", 'shape': " + shape_s + ", }";
+}
+
+inline void write_header(std::ostream& out, const std::string& descr, bool fortran_order, const std::vector<ndarray_len_t>& shape_v)
+{
+ std::string header_dict = write_header_dict(descr, fortran_order, shape_v);
- size_t header_len_pre = ss_header.str().length() + 1;
- size_t metadata_len = magic_string_length + 2 + 2 + header_len_pre;
+ size_t length = magic_string_length + 2 + 2 + header_dict.length() + 1;
unsigned char version[2] = {1, 0};
- if (metadata_len >= 255*255) {
- metadata_len = magic_string_length + 2 + 4 + header_len_pre;
+ if (length >= 255*255) {
+ length = magic_string_length + 2 + 4 + header_dict.length() + 1;
version[0] = 2;
version[1] = 0;
}
- size_t padding_len = 16 - metadata_len % 16;
+ size_t padding_len = 16 - length % 16;
std::string padding (padding_len, ' ');
- ss_header << padding;
- ss_header << '\n';
-
- std::string header = ss_header.str();
// write magic
write_magic(out, version[0], version[1]);
@@ -311,14 +398,14 @@ inline void WriteHeader(std::ostream& out, const std::string& descr, bool fortra
// write header length
if (version[0] == 1 && version[1] == 0) {
char header_len_le16[2];
- uint16_t header_len = header.length();
+ uint16_t header_len = header_dict.length() + padding.length() + 1;
header_len_le16[0] = (header_len >> 0) & 0xff;
header_len_le16[1] = (header_len >> 8) & 0xff;
out.write(reinterpret_cast<char *>(header_len_le16), 2);
}else{
char header_len_le32[4];
- uint32_t header_len = header.length();
+ uint32_t header_len = header_dict.length() + padding.length() + 1;
header_len_le32[0] = (header_len >> 0) & 0xff;
header_len_le32[1] = (header_len >> 8) & 0xff;
@@ -327,96 +414,89 @@ inline void WriteHeader(std::ostream& out, const std::string& descr, bool fortra
out.write(reinterpret_cast<char *>(header_len_le32), 4);
}
- out << header;
+ out << header_dict << padding << '\n';
}
-inline std::string read_header_1_0(std::istream& istream) {
- // read header length and convert from little endian
- char header_len_le16[2];
- istream.read(header_len_le16, 2);
+inline std::string read_header(std::istream& istream) {
+ // check magic bytes an version number
+ unsigned char v_major, v_minor;
+ read_magic(istream, v_major, v_minor);
- uint16_t header_length = (header_len_le16[0] << 0) | (header_len_le16[1] << 8);
+ uint32_t header_length;
+ if(v_major == 1 && v_minor == 0){
+
+ char header_len_le16[2];
+ istream.read(header_len_le16, 2);
+ header_length = (header_len_le16[0] << 0) | (header_len_le16[1] << 8);
- if((magic_string_length + 2 + 2 + header_length) % 16 != 0) {
+ if((magic_string_length + 2 + 2 + header_length) % 16 != 0) {
+ // TODO: display warning
+ }
+ }else if(v_major == 2 && v_minor == 0) {
+ char header_len_le32[4];
+ istream.read(header_len_le32, 4);
+
+ header_length = (header_len_le32[0] << 0) | (header_len_le32[1] << 8)
+ | (header_len_le32[2] << 16) | (header_len_le32[3] << 24);
+
+ if((magic_string_length + 2 + 4 + header_length) % 16 != 0) {
// TODO: display warning
+ }
+ }else{
+ throw std::runtime_error("unsupported file format version");
}
- char *buf = new char[header_length];
- istream.read(buf, header_length);
- std::string header (buf, header_length);
- delete[] buf;
+ auto buf_v = std::vector<char>();
+ buf_v.reserve(header_length);
+ istream.read(buf_v.data(), header_length);
+ std::string header(buf_v.data(), header_length);
return header;
}
-inline std::string read_header_2_0(std::istream& istream) {
- // read header length and convert from little endian
- char header_len_le32[4];
- istream.read(header_len_le32, 4);
-
- uint32_t header_length = (header_len_le32[0] << 0) | (header_len_le32[1] << 8)
- | (header_len_le32[2] << 16) | (header_len_le32[3] << 24);
+inline ndarray_len_t comp_size(const std::vector<ndarray_len_t>& shape) {
+ ndarray_len_t size = 1;
+ for (ndarray_len_t i : shape )
+ size *= i;
- if((magic_string_length + 2 + 4 + header_length) % 16 != 0) {
- // TODO: display warning
- }
-
- char *buf = new char[header_length];
- istream.read(buf, header_length);
- std::string header (buf, header_length);
- delete[] buf;
-
- return header;
+ return size;
}
template<typename Scalar>
-void SaveArrayAsNumpy( const std::string& filename, bool fortran_order, unsigned int n_dims, const unsigned long shape[], const std::vector<Scalar>& data)
+inline void SaveArrayAsNumpy( const std::string& filename, bool fortran_order, unsigned int n_dims, const unsigned long shape[], const std::vector<Scalar>& data)
{
- Typestring typestring_o {data};
+ Typestring typestring_o(data);
std::string typestring = typestring_o.str();
std::ofstream stream( filename, std::ofstream::binary);
if(!stream) {
throw std::runtime_error("io error: failed to open a file.");
}
- WriteHeader(stream, typestring, fortran_order, n_dims, shape);
- size_t size = 1;
- for (unsigned int i=0; i<n_dims; ++i)
- size *= shape[i];
- stream.write(reinterpret_cast<const char*>(&data[0]), sizeof(Scalar) * size);
-}
+ std::vector<ndarray_len_t> shape_v(shape, shape+n_dims);
+ write_header(stream, typestring, fortran_order, shape_v);
+ auto size = static_cast<size_t>(comp_size(shape_v));
+
+ stream.write(reinterpret_cast<const char*>(data.data()), sizeof(Scalar) * size);
+}
-/**
- */
template<typename Scalar>
-void LoadArrayFromNumpy(const std::string& filename, std::vector<unsigned long>& shape, std::vector<Scalar>& data)
+inline void LoadArrayFromNumpy(const std::string& filename, std::vector<unsigned long>& shape, std::vector<Scalar>& data)
{
std::ifstream stream(filename, std::ifstream::binary);
if(!stream) {
throw std::runtime_error("io error: failed to open a file.");
}
- // check magic bytes an version number
- unsigned char v_major, v_minor;
- read_magic(stream, &v_major, &v_minor);
-
- std::string header;
- if(v_major == 1 && v_minor == 0){
- header = read_header_1_0(stream);
- }else if(v_major == 2 && v_minor == 0) {
- header = read_header_2_0(stream);
- }else{
- throw std::runtime_error("unsupported file format version");
- }
+ std::string header = read_header(stream);
// parse header
bool fortran_order;
std::string typestr;
- ParseHeader(header, typestr, &fortran_order, shape);
+ parse_header(header, typestr, fortran_order, shape);
// check if the typestring matches the given one
Typestring typestring_o {data};
@@ -425,15 +505,15 @@ void LoadArrayFromNumpy(const std::string& filename, std::vector<unsigned long>&
throw std::runtime_error("formatting error: typestrings not matching");
}
+
// compute the data size based on the shape
- size_t total_size = 1;
- for(size_t i=0; i<shape.size(); ++i) {
- total_size *= shape[i];
- }
- data.resize(total_size);
+ auto size = static_cast<size_t>(comp_size(shape));
+ data.resize(size);
// read the data
- stream.read(reinterpret_cast<char*>(&data[0]), sizeof(Scalar)*total_size);
+ stream.read(reinterpret_cast<char*>(data.data()), sizeof(Scalar)*size);
}
} // namespace npy
+
+#endif // NPY_H