diff options
author | Tai Ly <tai.ly@arm.com> | 2023-04-04 20:34:05 +0000 |
---|---|---|
committer | Tai Ly <tai.ly@arm.com> | 2023-04-04 21:07:10 +0000 |
commit | 3ef34fb300e7f62bdb397c605ab6c3bd30682cf8 (patch) | |
tree | 0cd4e9c64fd393f6ba23b25f31b379dae7652049 | |
parent | dce6cebbeb6c45625c4ef8fafb5a7775319101c5 (diff) | |
download | serialization_lib-3ef34fb300e7f62bdb397c605ab6c3bd30682cf8.tar.gz |
Add readFromNpyFile and writeToNpyFile for double data
Signed-off-by: Tai Ly <tai.ly@arm.com>
Change-Id: Icc023cbe6aa8843cc37d25e740bc6ce05bb7abb2
-rw-r--r-- | include/numpy_utils.h | 6 | ||||
-rw-r--r-- | src/numpy_utils.cpp | 20 | ||||
-rwxr-xr-x | test/scripts/test_npy_fileio.py | 2 | ||||
-rw-r--r-- | test/src/serialization_npy_test.cpp | 43 |
4 files changed, 70 insertions, 1 deletions
diff --git a/include/numpy_utils.h b/include/numpy_utils.h index 6a20eb3..8c2ed71 100644 --- a/include/numpy_utils.h +++ b/include/numpy_utils.h @@ -41,6 +41,8 @@ public: static NPError readFromNpyFile(const char* filename, const uint32_t elems, float* databuf); + static NPError readFromNpyFile(const char* filename, const uint32_t elems, double* databuf); + static NPError readFromNpyFile(const char* filename, const uint32_t elems, half_float::half* databuf); static NPError readFromNpyFile(const char* filename, const uint32_t elems, int32_t* databuf); @@ -68,6 +70,10 @@ public: static NPError writeToNpyFile(const char* filename, const uint32_t elems, const float* databuf); + static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const double* databuf); + + static NPError writeToNpyFile(const char* filename, const uint32_t elems, const double* databuf); + private: static NPError writeToNpyFileCommon(const char* filename, const char* dtype_str, diff --git a/src/numpy_utils.cpp b/src/numpy_utils.cpp index c770d45..123908a 100644 --- a/src/numpy_utils.cpp +++ b/src/numpy_utils.cpp @@ -46,6 +46,12 @@ NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, co return readFromNpyFileCommon(filename, dtype_str, sizeof(float), elems, databuf, false); } +NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, double* databuf) +{ + const char dtype_str[] = "'<f8'"; + return readFromNpyFileCommon(filename, dtype_str, sizeof(double), elems, databuf, false); +} + NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, half_float::half* databuf) { @@ -315,6 +321,20 @@ NumpyUtilities::NPError return writeToNpyFileCommon(filename, dtype_str, sizeof(float), shape, databuf, false); } +NumpyUtilities::NPError + NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const double* databuf) +{ + std::vector<int32_t> shape = { (int32_t)elems }; + return writeToNpyFile(filename, shape, databuf); +} + +NumpyUtilities::NPError + NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const double* databuf) +{ + const char dtype_str[] = "'<f8'"; + return writeToNpyFileCommon(filename, dtype_str, sizeof(double), shape, databuf, false); +} + NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const half_float::half* databuf) diff --git a/test/scripts/test_npy_fileio.py b/test/scripts/test_npy_fileio.py index e0a6f5d..272c124 100755 --- a/test/scripts/test_npy_fileio.py +++ b/test/scripts/test_npy_fileio.py @@ -122,7 +122,7 @@ def main(): xunit_suite = xunit_result.create_suite("basic_serialization") max_size = 128 - datatypes = ["int32", "int64", "float", "bool"] + datatypes = ["int32", "int64", "float", "bool", "double"] random.seed(args.seed) failed = 0 diff --git a/test/src/serialization_npy_test.cpp b/test/src/serialization_npy_test.cpp index 27ec464..64536fb 100644 --- a/test/src/serialization_npy_test.cpp +++ b/test/src/serialization_npy_test.cpp @@ -103,6 +103,45 @@ int test_float_type(std::vector<int32_t> shape, std::default_random_engine& gen, return 0; } +template <class T> +int test_double_type(std::vector<int32_t> shape, std::default_random_engine& gen, std::string& filename) +{ + size_t total_size = 1; + std::uniform_real_distribution<T> gen_data(std::numeric_limits<T>::min(), std::numeric_limits<T>::max()); + + for (auto i : shape) + { + total_size *= i; + } + + auto buffer = std::make_unique<T[]>(total_size); + for (int i = 0; i < total_size; i++) + { + buffer[i] = gen_data(gen); + } + + NumpyUtilities::NPError err = NumpyUtilities::writeToNpyFile(filename.c_str(), shape, buffer.get()); + if (err != NumpyUtilities::NO_ERROR) + { + std::cout << "Error writing file, code " << err << std::endl; + return 1; + } + + auto read_buffer = std::make_unique<T[]>(total_size); + err = NumpyUtilities::readFromNpyFile(filename.c_str(), total_size, read_buffer.get()); + if (err != NumpyUtilities::NO_ERROR) + { + std::cout << "Error reading file, code " << err << std::endl; + return 1; + } + if (memcmp(buffer.get(), read_buffer.get(), total_size * sizeof(T))) + { + std::cout << "Miscompare" << std::endl; + return 1; + } + return 0; +} + int test_bool_type(std::vector<int32_t> shape, std::default_random_engine& gen, std::string& filename) { size_t total_size = 1; @@ -212,6 +251,10 @@ int main(int argc, char** argv) { return test_float_type<float>(shape, gen, filename); } + else if (str_type == "double") + { + return test_double_type<double>(shape, gen, filename); + } else if (str_type == "bool") { return test_bool_type(shape, gen, filename); |