aboutsummaryrefslogtreecommitdiff
path: root/test/src/serialization_npy_test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'test/src/serialization_npy_test.cpp')
-rw-r--r--test/src/serialization_npy_test.cpp225
1 files changed, 225 insertions, 0 deletions
diff --git a/test/src/serialization_npy_test.cpp b/test/src/serialization_npy_test.cpp
new file mode 100644
index 0000000..27ec464
--- /dev/null
+++ b/test/src/serialization_npy_test.cpp
@@ -0,0 +1,225 @@
+// Copyright (c) 2021, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <getopt.h>
+#include <iostream>
+#include <random>
+#include <sstream>
+#include <tosa_serialization_handler.h>
+
+using namespace tosa;
+
+void usage()
+{
+ std::cout << "Usage: serialization_npy_test -f <filename> -t <shape> -d <datatype> -s <seed>" << std::endl;
+}
+
+template <class T>
+int test_int_type(std::vector<int32_t> shape, std::default_random_engine& gen, std::string& filename)
+{
+ size_t total_size = 1;
+ std::uniform_int_distribution<T> gen_data(std::numeric_limits<T>::min(), std::numeric_limits<T>::max());
+
+ for (auto i : shape)
+ {
+ total_size *= i;
+ }
+
+ auto buffer = std::make_unique<T[]>(total_size);
+ for (int i = 0; i < total_size; i++)
+ {
+ buffer[i] = gen_data(gen);
+ }
+
+ NumpyUtilities::NPError err = NumpyUtilities::writeToNpyFile(filename.c_str(), shape, buffer.get());
+ if (err != NumpyUtilities::NO_ERROR)
+ {
+ std::cout << "Error writing file, code " << err << std::endl;
+ return 1;
+ }
+
+ auto read_buffer = std::make_unique<T[]>(total_size);
+ err = NumpyUtilities::readFromNpyFile(filename.c_str(), total_size, read_buffer.get());
+ if (err != NumpyUtilities::NO_ERROR)
+ {
+ std::cout << "Error reading file, code " << err << std::endl;
+ return 1;
+ }
+ if (memcmp(buffer.get(), read_buffer.get(), total_size * sizeof(T)))
+ {
+ std::cout << "Miscompare" << std::endl;
+ return 1;
+ }
+ return 0;
+}
+
+template <class T>
+int test_float_type(std::vector<int32_t> shape, std::default_random_engine& gen, std::string& filename)
+{
+ size_t total_size = 1;
+ std::uniform_real_distribution<T> gen_data(std::numeric_limits<T>::min(), std::numeric_limits<T>::max());
+
+ for (auto i : shape)
+ {
+ total_size *= i;
+ }
+
+ auto buffer = std::make_unique<T[]>(total_size);
+ for (int i = 0; i < total_size; i++)
+ {
+ buffer[i] = gen_data(gen);
+ }
+
+ NumpyUtilities::NPError err = NumpyUtilities::writeToNpyFile(filename.c_str(), shape, buffer.get());
+ if (err != NumpyUtilities::NO_ERROR)
+ {
+ std::cout << "Error writing file, code " << err << std::endl;
+ return 1;
+ }
+
+ auto read_buffer = std::make_unique<T[]>(total_size);
+ err = NumpyUtilities::readFromNpyFile(filename.c_str(), total_size, read_buffer.get());
+ if (err != NumpyUtilities::NO_ERROR)
+ {
+ std::cout << "Error reading file, code " << err << std::endl;
+ return 1;
+ }
+ if (memcmp(buffer.get(), read_buffer.get(), total_size * sizeof(T)))
+ {
+ std::cout << "Miscompare" << std::endl;
+ return 1;
+ }
+ return 0;
+}
+
+int test_bool_type(std::vector<int32_t> shape, std::default_random_engine& gen, std::string& filename)
+{
+ size_t total_size = 1;
+ std::uniform_int_distribution<uint32_t> gen_data(0, 1);
+
+ for (auto i : shape)
+ {
+ total_size *= i;
+ }
+
+ auto buffer = std::make_unique<bool[]>(total_size);
+ for (int i = 0; i < total_size; i++)
+ {
+ buffer[i] = (gen_data(gen)) ? true : false;
+ }
+
+ NumpyUtilities::NPError err = NumpyUtilities::writeToNpyFile(filename.c_str(), shape, buffer.get());
+ if (err != NumpyUtilities::NO_ERROR)
+ {
+ std::cout << "Error writing file, code " << err << std::endl;
+ return 1;
+ }
+
+ auto read_buffer = std::make_unique<bool[]>(total_size);
+ err = NumpyUtilities::readFromNpyFile(filename.c_str(), total_size, read_buffer.get());
+ if (err != NumpyUtilities::NO_ERROR)
+ {
+ std::cout << "Error reading file, code " << err << std::endl;
+ return 1;
+ }
+
+ if (memcmp(buffer.get(), read_buffer.get(), total_size * sizeof(bool)))
+ {
+ std::cout << "Miscompare" << std::endl;
+ return 1;
+ }
+ return 0;
+}
+
+int main(int argc, char** argv)
+{
+ size_t total_size = 1;
+ int32_t seed = 1;
+ std::string str_type;
+ std::string str_shape;
+ std::string filename = "npytest.npy";
+ std::vector<int32_t> shape;
+ bool verbose = false;
+ int opt;
+ while ((opt = getopt(argc, argv, "d:f:s:t:v")) != -1)
+ {
+ switch (opt)
+ {
+ case 'd':
+ str_type = optarg;
+ break;
+ case 'f':
+ filename = optarg;
+ break;
+ case 's':
+ seed = strtol(optarg, nullptr, 0);
+ break;
+ case 't':
+ str_shape = optarg;
+ break;
+ case 'v':
+ verbose = true;
+ break;
+ default:
+ std::cerr << "Invalid argument" << std::endl;
+ break;
+ }
+ }
+ if (str_shape == "")
+ {
+ usage();
+ return 1;
+ }
+
+ // parse shape from argument
+ std::stringstream ss(str_shape);
+ while (ss.good())
+ {
+ std::string substr;
+ size_t pos;
+ std::getline(ss, substr, ',');
+ if (substr == "")
+ break;
+ int val = stoi(substr, &pos, 0);
+ assert(val);
+ total_size *= val;
+ shape.push_back(val);
+ }
+
+ std::default_random_engine gen(seed);
+
+ // run with type from argument
+ if (str_type == "int32")
+ {
+ return test_int_type<int32_t>(shape, gen, filename);
+ }
+ else if (str_type == "int64")
+ {
+ return test_int_type<int64_t>(shape, gen, filename);
+ }
+ else if (str_type == "float")
+ {
+ return test_float_type<float>(shape, gen, filename);
+ }
+ else if (str_type == "bool")
+ {
+ return test_bool_type(shape, gen, filename);
+ }
+ else
+ {
+ std::cout << "Unknown type " << str_type << std::endl;
+ usage();
+ return 1;
+ }
+}