diff options
Diffstat (limited to 'test/src/serialization_npy_test.cpp')
-rw-r--r-- | test/src/serialization_npy_test.cpp | 61 |
1 files changed, 49 insertions, 12 deletions
diff --git a/test/src/serialization_npy_test.cpp b/test/src/serialization_npy_test.cpp index 27ec464..24e3aff 100644 --- a/test/src/serialization_npy_test.cpp +++ b/test/src/serialization_npy_test.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2021, ARM Limited. +// Copyright (c) 2021,2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -37,7 +37,7 @@ int test_int_type(std::vector<int32_t> shape, std::default_random_engine& gen, s } auto buffer = std::make_unique<T[]>(total_size); - for (int i = 0; i < total_size; i++) + for (size_t i = 0; i < total_size; i++) { buffer[i] = gen_data(gen); } @@ -76,7 +76,46 @@ int test_float_type(std::vector<int32_t> shape, std::default_random_engine& gen, } auto buffer = std::make_unique<T[]>(total_size); - for (int i = 0; i < total_size; i++) + for (size_t i = 0; i < total_size; i++) + { + buffer[i] = gen_data(gen); + } + + NumpyUtilities::NPError err = NumpyUtilities::writeToNpyFile(filename.c_str(), shape, buffer.get()); + if (err != NumpyUtilities::NO_ERROR) + { + std::cout << "Error writing file, code " << err << std::endl; + return 1; + } + + auto read_buffer = std::make_unique<T[]>(total_size); + err = NumpyUtilities::readFromNpyFile(filename.c_str(), total_size, read_buffer.get()); + if (err != NumpyUtilities::NO_ERROR) + { + std::cout << "Error reading file, code " << err << std::endl; + return 1; + } + if (memcmp(buffer.get(), read_buffer.get(), total_size * sizeof(T))) + { + std::cout << "Miscompare" << std::endl; + return 1; + } + return 0; +} + +template <class T> +int test_double_type(std::vector<int32_t> shape, std::default_random_engine& gen, std::string& filename) +{ + size_t total_size = 1; + std::uniform_real_distribution<T> gen_data(std::numeric_limits<T>::min(), std::numeric_limits<T>::max()); + + for (auto i : shape) + { + total_size *= i; + } + + auto buffer = std::make_unique<T[]>(total_size); + for (size_t i = 0; i < total_size; i++) { buffer[i] = gen_data(gen); } @@ -114,7 +153,7 @@ int test_bool_type(std::vector<int32_t> shape, std::default_random_engine& gen, } auto buffer = std::make_unique<bool[]>(total_size); - for (int i = 0; i < total_size; i++) + for (size_t i = 0; i < total_size; i++) { buffer[i] = (gen_data(gen)) ? true : false; } @@ -144,15 +183,13 @@ int test_bool_type(std::vector<int32_t> shape, std::default_random_engine& gen, int main(int argc, char** argv) { - size_t total_size = 1; - int32_t seed = 1; + int32_t seed = 1; std::string str_type; std::string str_shape; std::string filename = "npytest.npy"; std::vector<int32_t> shape; - bool verbose = false; int opt; - while ((opt = getopt(argc, argv, "d:f:s:t:v")) != -1) + while ((opt = getopt(argc, argv, "d:f:s:t:")) != -1) { switch (opt) { @@ -168,9 +205,6 @@ int main(int argc, char** argv) case 't': str_shape = optarg; break; - case 'v': - verbose = true; - break; default: std::cerr << "Invalid argument" << std::endl; break; @@ -193,7 +227,6 @@ int main(int argc, char** argv) break; int val = stoi(substr, &pos, 0); assert(val); - total_size *= val; shape.push_back(val); } @@ -212,6 +245,10 @@ int main(int argc, char** argv) { return test_float_type<float>(shape, gen, filename); } + else if (str_type == "double") + { + return test_double_type<double>(shape, gen, filename); + } else if (str_type == "bool") { return test_bool_type(shape, gen, filename); |