aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2023-04-04 20:34:05 +0000
committerTai Ly <tai.ly@arm.com>2023-04-04 21:07:10 +0000
commit3ef34fb300e7f62bdb397c605ab6c3bd30682cf8 (patch)
tree0cd4e9c64fd393f6ba23b25f31b379dae7652049
parentdce6cebbeb6c45625c4ef8fafb5a7775319101c5 (diff)
downloadserialization_lib-3ef34fb300e7f62bdb397c605ab6c3bd30682cf8.tar.gz
Add readFromNpyFile and writeToNpyFile for double data
Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: Icc023cbe6aa8843cc37d25e740bc6ce05bb7abb2
-rw-r--r--include/numpy_utils.h6
-rw-r--r--src/numpy_utils.cpp20
-rwxr-xr-xtest/scripts/test_npy_fileio.py2
-rw-r--r--test/src/serialization_npy_test.cpp43
4 files changed, 70 insertions, 1 deletions
diff --git a/include/numpy_utils.h b/include/numpy_utils.h
index 6a20eb3..8c2ed71 100644
--- a/include/numpy_utils.h
+++ b/include/numpy_utils.h
@@ -41,6 +41,8 @@ public:
static NPError readFromNpyFile(const char* filename, const uint32_t elems, float* databuf);
+ static NPError readFromNpyFile(const char* filename, const uint32_t elems, double* databuf);
+
static NPError readFromNpyFile(const char* filename, const uint32_t elems, half_float::half* databuf);
static NPError readFromNpyFile(const char* filename, const uint32_t elems, int32_t* databuf);
@@ -68,6 +70,10 @@ public:
static NPError writeToNpyFile(const char* filename, const uint32_t elems, const float* databuf);
+ static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const double* databuf);
+
+ static NPError writeToNpyFile(const char* filename, const uint32_t elems, const double* databuf);
+
private:
static NPError writeToNpyFileCommon(const char* filename,
const char* dtype_str,
diff --git a/src/numpy_utils.cpp b/src/numpy_utils.cpp
index c770d45..123908a 100644
--- a/src/numpy_utils.cpp
+++ b/src/numpy_utils.cpp
@@ -46,6 +46,12 @@ NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, co
return readFromNpyFileCommon(filename, dtype_str, sizeof(float), elems, databuf, false);
}
+NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, double* databuf)
+{
+ const char dtype_str[] = "'<f8'";
+ return readFromNpyFileCommon(filename, dtype_str, sizeof(double), elems, databuf, false);
+}
+
NumpyUtilities::NPError
NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, half_float::half* databuf)
{
@@ -315,6 +321,20 @@ NumpyUtilities::NPError
return writeToNpyFileCommon(filename, dtype_str, sizeof(float), shape, databuf, false);
}
+NumpyUtilities::NPError
+ NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const double* databuf)
+{
+ std::vector<int32_t> shape = { (int32_t)elems };
+ return writeToNpyFile(filename, shape, databuf);
+}
+
+NumpyUtilities::NPError
+ NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const double* databuf)
+{
+ const char dtype_str[] = "'<f8'";
+ return writeToNpyFileCommon(filename, dtype_str, sizeof(double), shape, databuf, false);
+}
+
NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename,
const std::vector<int32_t>& shape,
const half_float::half* databuf)
diff --git a/test/scripts/test_npy_fileio.py b/test/scripts/test_npy_fileio.py
index e0a6f5d..272c124 100755
--- a/test/scripts/test_npy_fileio.py
+++ b/test/scripts/test_npy_fileio.py
@@ -122,7 +122,7 @@ def main():
xunit_suite = xunit_result.create_suite("basic_serialization")
max_size = 128
- datatypes = ["int32", "int64", "float", "bool"]
+ datatypes = ["int32", "int64", "float", "bool", "double"]
random.seed(args.seed)
failed = 0
diff --git a/test/src/serialization_npy_test.cpp b/test/src/serialization_npy_test.cpp
index 27ec464..64536fb 100644
--- a/test/src/serialization_npy_test.cpp
+++ b/test/src/serialization_npy_test.cpp
@@ -103,6 +103,45 @@ int test_float_type(std::vector<int32_t> shape, std::default_random_engine& gen,
return 0;
}
+template <class T>
+int test_double_type(std::vector<int32_t> shape, std::default_random_engine& gen, std::string& filename)
+{
+ size_t total_size = 1;
+ std::uniform_real_distribution<T> gen_data(std::numeric_limits<T>::min(), std::numeric_limits<T>::max());
+
+ for (auto i : shape)
+ {
+ total_size *= i;
+ }
+
+ auto buffer = std::make_unique<T[]>(total_size);
+ for (int i = 0; i < total_size; i++)
+ {
+ buffer[i] = gen_data(gen);
+ }
+
+ NumpyUtilities::NPError err = NumpyUtilities::writeToNpyFile(filename.c_str(), shape, buffer.get());
+ if (err != NumpyUtilities::NO_ERROR)
+ {
+ std::cout << "Error writing file, code " << err << std::endl;
+ return 1;
+ }
+
+ auto read_buffer = std::make_unique<T[]>(total_size);
+ err = NumpyUtilities::readFromNpyFile(filename.c_str(), total_size, read_buffer.get());
+ if (err != NumpyUtilities::NO_ERROR)
+ {
+ std::cout << "Error reading file, code " << err << std::endl;
+ return 1;
+ }
+ if (memcmp(buffer.get(), read_buffer.get(), total_size * sizeof(T)))
+ {
+ std::cout << "Miscompare" << std::endl;
+ return 1;
+ }
+ return 0;
+}
+
int test_bool_type(std::vector<int32_t> shape, std::default_random_engine& gen, std::string& filename)
{
size_t total_size = 1;
@@ -212,6 +251,10 @@ int main(int argc, char** argv)
{
return test_float_type<float>(shape, gen, filename);
}
+ else if (str_type == "double")
+ {
+ return test_double_type<double>(shape, gen, filename);
+ }
else if (str_type == "bool")
{
return test_bool_type(shape, gen, filename);