aboutsummaryrefslogtreecommitdiff
path: root/test/src/serialization_npy_test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'test/src/serialization_npy_test.cpp')
-rw-r--r--test/src/serialization_npy_test.cpp43
1 files changed, 43 insertions, 0 deletions
diff --git a/test/src/serialization_npy_test.cpp b/test/src/serialization_npy_test.cpp
index 27ec464..64536fb 100644
--- a/test/src/serialization_npy_test.cpp
+++ b/test/src/serialization_npy_test.cpp
@@ -103,6 +103,45 @@ int test_float_type(std::vector<int32_t> shape, std::default_random_engine& gen,
return 0;
}
+template <class T>
+int test_double_type(std::vector<int32_t> shape, std::default_random_engine& gen, std::string& filename)
+{
+ size_t total_size = 1;
+ std::uniform_real_distribution<T> gen_data(std::numeric_limits<T>::min(), std::numeric_limits<T>::max());
+
+ for (auto i : shape)
+ {
+ total_size *= i;
+ }
+
+ auto buffer = std::make_unique<T[]>(total_size);
+ for (int i = 0; i < total_size; i++)
+ {
+ buffer[i] = gen_data(gen);
+ }
+
+ NumpyUtilities::NPError err = NumpyUtilities::writeToNpyFile(filename.c_str(), shape, buffer.get());
+ if (err != NumpyUtilities::NO_ERROR)
+ {
+ std::cout << "Error writing file, code " << err << std::endl;
+ return 1;
+ }
+
+ auto read_buffer = std::make_unique<T[]>(total_size);
+ err = NumpyUtilities::readFromNpyFile(filename.c_str(), total_size, read_buffer.get());
+ if (err != NumpyUtilities::NO_ERROR)
+ {
+ std::cout << "Error reading file, code " << err << std::endl;
+ return 1;
+ }
+ if (memcmp(buffer.get(), read_buffer.get(), total_size * sizeof(T)))
+ {
+ std::cout << "Miscompare" << std::endl;
+ return 1;
+ }
+ return 0;
+}
+
int test_bool_type(std::vector<int32_t> shape, std::default_random_engine& gen, std::string& filename)
{
size_t total_size = 1;
@@ -212,6 +251,10 @@ int main(int argc, char** argv)
{
return test_float_type<float>(shape, gen, filename);
}
+ else if (str_type == "double")
+ {
+ return test_double_type<double>(shape, gen, filename);
+ }
else if (str_type == "bool")
{
return test_bool_type(shape, gen, filename);