aboutsummaryrefslogtreecommitdiff
path: root/test/src/serialization_npy_test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'test/src/serialization_npy_test.cpp')
-rw-r--r--test/src/serialization_npy_test.cpp61
1 files changed, 49 insertions, 12 deletions
diff --git a/test/src/serialization_npy_test.cpp b/test/src/serialization_npy_test.cpp
index 27ec464..24e3aff 100644
--- a/test/src/serialization_npy_test.cpp
+++ b/test/src/serialization_npy_test.cpp
@@ -1,4 +1,4 @@
-// Copyright (c) 2021, ARM Limited.
+// Copyright (c) 2021,2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -37,7 +37,7 @@ int test_int_type(std::vector<int32_t> shape, std::default_random_engine& gen, s
}
auto buffer = std::make_unique<T[]>(total_size);
- for (int i = 0; i < total_size; i++)
+ for (size_t i = 0; i < total_size; i++)
{
buffer[i] = gen_data(gen);
}
@@ -76,7 +76,46 @@ int test_float_type(std::vector<int32_t> shape, std::default_random_engine& gen,
}
auto buffer = std::make_unique<T[]>(total_size);
- for (int i = 0; i < total_size; i++)
+ for (size_t i = 0; i < total_size; i++)
+ {
+ buffer[i] = gen_data(gen);
+ }
+
+ NumpyUtilities::NPError err = NumpyUtilities::writeToNpyFile(filename.c_str(), shape, buffer.get());
+ if (err != NumpyUtilities::NO_ERROR)
+ {
+ std::cout << "Error writing file, code " << err << std::endl;
+ return 1;
+ }
+
+ auto read_buffer = std::make_unique<T[]>(total_size);
+ err = NumpyUtilities::readFromNpyFile(filename.c_str(), total_size, read_buffer.get());
+ if (err != NumpyUtilities::NO_ERROR)
+ {
+ std::cout << "Error reading file, code " << err << std::endl;
+ return 1;
+ }
+ if (memcmp(buffer.get(), read_buffer.get(), total_size * sizeof(T)))
+ {
+ std::cout << "Miscompare" << std::endl;
+ return 1;
+ }
+ return 0;
+}
+
+template <class T>
+int test_double_type(std::vector<int32_t> shape, std::default_random_engine& gen, std::string& filename)
+{
+ size_t total_size = 1;
+ std::uniform_real_distribution<T> gen_data(std::numeric_limits<T>::min(), std::numeric_limits<T>::max());
+
+ for (auto i : shape)
+ {
+ total_size *= i;
+ }
+
+ auto buffer = std::make_unique<T[]>(total_size);
+ for (size_t i = 0; i < total_size; i++)
{
buffer[i] = gen_data(gen);
}
@@ -114,7 +153,7 @@ int test_bool_type(std::vector<int32_t> shape, std::default_random_engine& gen,
}
auto buffer = std::make_unique<bool[]>(total_size);
- for (int i = 0; i < total_size; i++)
+ for (size_t i = 0; i < total_size; i++)
{
buffer[i] = (gen_data(gen)) ? true : false;
}
@@ -144,15 +183,13 @@ int test_bool_type(std::vector<int32_t> shape, std::default_random_engine& gen,
int main(int argc, char** argv)
{
- size_t total_size = 1;
- int32_t seed = 1;
+ int32_t seed = 1;
std::string str_type;
std::string str_shape;
std::string filename = "npytest.npy";
std::vector<int32_t> shape;
- bool verbose = false;
int opt;
- while ((opt = getopt(argc, argv, "d:f:s:t:v")) != -1)
+ while ((opt = getopt(argc, argv, "d:f:s:t:")) != -1)
{
switch (opt)
{
@@ -168,9 +205,6 @@ int main(int argc, char** argv)
case 't':
str_shape = optarg;
break;
- case 'v':
- verbose = true;
- break;
default:
std::cerr << "Invalid argument" << std::endl;
break;
@@ -193,7 +227,6 @@ int main(int argc, char** argv)
break;
int val = stoi(substr, &pos, 0);
assert(val);
- total_size *= val;
shape.push_back(val);
}
@@ -212,6 +245,10 @@ int main(int argc, char** argv)
{
return test_float_type<float>(shape, gen, filename);
}
+ else if (str_type == "double")
+ {
+ return test_double_type<double>(shape, gen, filename);
+ }
else if (str_type == "bool")
{
return test_bool_type(shape, gen, filename);