aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTatWai Chong <tatwai.chong@arm.com>2023-07-31 15:15:12 -0700
committerEric Kunze <eric.kunze@arm.com>2023-08-02 16:59:13 +0000
commit679bdadb6b51b14013a00588cec2452d6ee1d1ac (patch)
tree34f71ccf0db10fbea332ae1f5972eabb036422b6 /src
parente2b20e4eb9e91d9fd4a155880f3bf6085b8ffaac (diff)
downloadserialization_lib-679bdadb6b51b14013a00588cec2452d6ee1d1ac.tar.gz
Simplify overloaded writeToNpyFiles and readFromNpyFiles
templatize these functions instead to reduce redundant code. Signed-off-by: TatWai Chong <tatwai.chong@arm.com> Change-Id: Ie8b6f7d2b489c3508fea72481ce38f0db6d0c490
Diffstat (limited to 'src')
-rw-r--r--src/numpy_utils.cpp259
1 files changed, 30 insertions, 229 deletions
diff --git a/src/numpy_utils.cpp b/src/numpy_utils.cpp
index 0002fd9..64460bd 100644
--- a/src/numpy_utils.cpp
+++ b/src/numpy_utils.cpp
@@ -16,6 +16,7 @@
#include "numpy_utils.h"
#include "half.hpp"
#include <algorithm>
+
// Magic NUMPY header
static const char NUMPY_HEADER_STR[] = "\x93NUMPY\x1\x0\x76\x0{";
static const int NUMPY_HEADER_SZ = 128;
@@ -24,20 +25,10 @@ static const int NUMPY_MAX_DIMS_SUPPORTED = 10;
// Offset for NUMPY header desc dictionary string
static const int NUMPY_HEADER_DESC_OFFSET = 8;
-NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, bool* databuf)
-{
- const char dtype_str[] = "'|b1'";
- return readFromNpyFileCommon(filename, dtype_str, 1, elems, databuf, true);
-}
-
+// This is an entry function for reading 8-/16-/32-bit npy file.
+template <>
NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int32_t* databuf)
{
- const char dtype_str_uint8[] = "'|u1'";
- const char dtype_str_int8[] = "'|i1'";
- const char dtype_str_uint16[] = "'<u2'";
- const char dtype_str_int16[] = "'<i2'";
- const char dtype_str_int32[] = "'<i4'";
-
FILE* infile = nullptr;
NPError rc = HEADER_PARSE_ERROR;
assert(filename);
@@ -49,91 +40,58 @@ NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, co
return FILE_NOT_FOUND;
}
- bool is_signed = false;
- int bit_length;
+ bool is_signed = false;
+ int length_per_byte = 0;
char byte_order;
- rc = getHeader(infile, is_signed, bit_length, byte_order);
+ rc = getHeader(infile, is_signed, length_per_byte, byte_order);
if (rc != NO_ERROR)
return rc;
- switch (bit_length)
+ switch (length_per_byte)
{
- case 1: // 8-bit
+ case 1:
if (is_signed)
{
- // int8
- int8_t* i8databuf = nullptr;
- i8databuf = (int8_t*)calloc(sizeof(i8databuf), elems);
-
- rc = readFromNpyFileCommon(filename, dtype_str_int8, sizeof(int8_t), elems, i8databuf, false);
-
- for (unsigned i = 0; i < elems; ++i)
- {
- databuf[i] = (int32_t)i8databuf[i];
- }
- free(i8databuf);
-
- return rc;
+ int8_t* tmp_buf = new int8_t[elems];
+ rc = readFromNpyFile<int8_t>(filename, elems, tmp_buf);
+ copyBufferByElement(databuf, tmp_buf, elems);
+ free(tmp_buf);
}
else
{
- // uint8
- uint8_t* ui8databuf = nullptr;
- ui8databuf = (uint8_t*)calloc(sizeof(ui8databuf), elems);
-
- rc = readFromNpyFileCommon(filename, dtype_str_uint8, sizeof(uint8_t), elems, ui8databuf, false);
-
- for (unsigned i = 0; i < elems; ++i)
- {
- databuf[i] = (int32_t)ui8databuf[i];
- }
- free(ui8databuf);
+ uint8_t* tmp_buf = new uint8_t[elems];
+ rc = readFromNpyFile<uint8_t>(filename, elems, tmp_buf);
+ copyBufferByElement(databuf, tmp_buf, elems);
+ free(tmp_buf);
}
break;
- case 2: // 16-bit
+ case 2:
if (is_signed)
{
- // int16
- int16_t* i16databuf = nullptr;
- i16databuf = (int16_t*)calloc(sizeof(i16databuf), elems);
-
- rc = readFromNpyFileCommon(filename, dtype_str_int16, sizeof(int16_t), elems, i16databuf, false);
-
- for (unsigned i = 0; i < elems; ++i)
- {
- databuf[i] = (int32_t)i16databuf[i];
- }
- free(i16databuf);
-
- return rc;
+ int16_t* tmp_buf = new int16_t[elems];
+ rc = readFromNpyFile<int16_t>(filename, elems, tmp_buf);
+ copyBufferByElement(databuf, tmp_buf, elems);
+ free(tmp_buf);
}
else
{
- // uint16
- uint16_t* ui16databuf = nullptr;
- ui16databuf = (uint16_t*)calloc(sizeof(ui16databuf), elems);
-
- rc = readFromNpyFileCommon(filename, dtype_str_uint16, sizeof(uint16_t), elems, ui16databuf, false);
-
- for (unsigned i = 0; i < elems; ++i)
- {
- databuf[i] = (int32_t)ui16databuf[i];
- }
- free(ui16databuf);
-
- return rc;
+ uint16_t* tmp_buf = new uint16_t[elems];
+ rc = readFromNpyFile<uint16_t>(filename, elems, tmp_buf);
+ copyBufferByElement(databuf, tmp_buf, elems);
+ free(tmp_buf);
}
break;
- case 4: // 32-bit
+ case 4:
if (is_signed)
{
- // int32
- return readFromNpyFileCommon(filename, dtype_str_int32, sizeof(int32_t), elems, databuf, false);
+ bool is_bool;
+ const char* dtype_str = getDTypeString<int32_t>(is_bool);
+ rc = readFromNpyFileCommon(filename, dtype_str, sizeof(int32_t), elems, databuf, is_bool);
}
else
{
// uint32, not supported
- return DATA_TYPE_NOT_SUPPORTED;
+ rc = DATA_TYPE_NOT_SUPPORTED;
}
break;
default:
@@ -144,31 +102,6 @@ NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, co
return rc;
}
-NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int64_t* databuf)
-{
- const char dtype_str[] = "'<i8'";
- return readFromNpyFileCommon(filename, dtype_str, sizeof(int64_t), elems, databuf, false);
-}
-
-NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, float* databuf)
-{
- const char dtype_str[] = "'<f4'";
- return readFromNpyFileCommon(filename, dtype_str, sizeof(float), elems, databuf, false);
-}
-
-NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, double* databuf)
-{
- const char dtype_str[] = "'<f8'";
- return readFromNpyFileCommon(filename, dtype_str, sizeof(double), elems, databuf, false);
-}
-
-NumpyUtilities::NPError
- NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, half_float::half* databuf)
-{
- const char dtype_str[] = "'<f2'";
- return readFromNpyFileCommon(filename, dtype_str, sizeof(half_float::half), elems, databuf, false);
-}
-
NumpyUtilities::NPError NumpyUtilities::readFromNpyFileCommon(const char* filename,
const char* dtype_str,
const size_t elementsize,
@@ -418,138 +351,6 @@ NumpyUtilities::NPError NumpyUtilities::checkNpyHeader(FILE* infile, const uint3
return rc;
}
-NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const bool* databuf)
-{
- std::vector<int32_t> shape = { (int32_t)elems };
- return writeToNpyFile(filename, shape, databuf);
-}
-
-NumpyUtilities::NPError
- NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const bool* databuf)
-{
- const char dtype_str[] = "'|b1'";
- return writeToNpyFileCommon(filename, dtype_str, 1, shape, databuf, true); // bools written as size 1
-}
-
-NumpyUtilities::NPError
- NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const uint8_t* databuf)
-{
- std::vector<int32_t> shape = { (int32_t)elems };
- return writeToNpyFile(filename, shape, databuf);
-}
-
-NumpyUtilities::NPError
- NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const uint8_t* databuf)
-{
- const char dtype_str[] = "'|u1'";
- return writeToNpyFileCommon(filename, dtype_str, sizeof(uint8_t), shape, databuf, false);
-}
-
-NumpyUtilities::NPError
- NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int8_t* databuf)
-{
- std::vector<int32_t> shape = { (int32_t)elems };
- return writeToNpyFile(filename, shape, databuf);
-}
-
-NumpyUtilities::NPError
- NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int8_t* databuf)
-{
- const char dtype_str[] = "'|i1'";
- return writeToNpyFileCommon(filename, dtype_str, sizeof(int8_t), shape, databuf, false);
-}
-
-NumpyUtilities::NPError
- NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const uint16_t* databuf)
-{
- std::vector<int32_t> shape = { (int32_t)elems };
- return writeToNpyFile(filename, shape, databuf);
-}
-
-NumpyUtilities::NPError
- NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const uint16_t* databuf)
-{
- const char dtype_str[] = "'<u2'";
- return writeToNpyFileCommon(filename, dtype_str, sizeof(uint16_t), shape, databuf, false);
-}
-
-NumpyUtilities::NPError
- NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int16_t* databuf)
-{
- std::vector<int32_t> shape = { (int32_t)elems };
- return writeToNpyFile(filename, shape, databuf);
-}
-
-NumpyUtilities::NPError
- NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int16_t* databuf)
-{
- const char dtype_str[] = "'<i2'";
- return writeToNpyFileCommon(filename, dtype_str, sizeof(int16_t), shape, databuf, false);
-}
-
-NumpyUtilities::NPError
- NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int32_t* databuf)
-{
- std::vector<int32_t> shape = { (int32_t)elems };
- return writeToNpyFile(filename, shape, databuf);
-}
-
-NumpyUtilities::NPError
- NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int32_t* databuf)
-{
- const char dtype_str[] = "'<i4'";
- return writeToNpyFileCommon(filename, dtype_str, sizeof(int32_t), shape, databuf, false);
-}
-
-NumpyUtilities::NPError
- NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int64_t* databuf)
-{
- std::vector<int32_t> shape = { (int32_t)elems };
- return writeToNpyFile(filename, shape, databuf);
-}
-
-NumpyUtilities::NPError
- NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int64_t* databuf)
-{
- const char dtype_str[] = "'<i8'";
- return writeToNpyFileCommon(filename, dtype_str, sizeof(int64_t), shape, databuf, false);
-}
-
-NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const float* databuf)
-{
- std::vector<int32_t> shape = { (int32_t)elems };
- return writeToNpyFile(filename, shape, databuf);
-}
-
-NumpyUtilities::NPError
- NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const float* databuf)
-{
- const char dtype_str[] = "'<f4'";
- return writeToNpyFileCommon(filename, dtype_str, sizeof(float), shape, databuf, false);
-}
-
-NumpyUtilities::NPError
- NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const double* databuf)
-{
- std::vector<int32_t> shape = { (int32_t)elems };
- return writeToNpyFile(filename, shape, databuf);
-}
-
-NumpyUtilities::NPError
- NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const double* databuf)
-{
- const char dtype_str[] = "'<f8'";
- return writeToNpyFileCommon(filename, dtype_str, sizeof(double), shape, databuf, false);
-}
-
-NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename,
- const std::vector<int32_t>& shape,
- const half_float::half* databuf)
-{
- const char dtype_str[] = "'<f2'";
- return writeToNpyFileCommon(filename, dtype_str, sizeof(half_float::half), shape, databuf, false);
-}
-
NumpyUtilities::NPError NumpyUtilities::writeToNpyFileCommon(const char* filename,
const char* dtype_str,
const size_t elementsize,