aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Jerry Ge <jerry.ge@arm.com> 2023-07-03 16:36:41 +0000
committer Jerry Ge <jerry.ge@arm.com> 2023-07-05 20:09:41 +0000
commit 13a329156f5ae51a0ffffcb0083d807da9e5b19f (patch)
tree 90dd3dad7eb256a16905d7b55433ab6505e4015b
parent 3acb1cbfdd492ab4f657799ed0c2279a5e390248 (diff)
download serialization_lib-13a329156f5ae51a0ffffcb0083d807da9e5b19f.tar.gz
Support reading anydtype into a 32-bit buffer
Signed-off-by: Jerry Ge <jerry.ge@arm.com>
Change-Id: Ic6b43539fcb2d75c5614d3addccd24a06e9f2a31
-rw-r--r-- include/numpy_utils.h 12
-rw-r--r-- src/numpy_utils.cpp 180
2 files changed, 158 insertions, 34 deletions
diff --git a/include/numpy_utils.h b/include/numpy_utils.h
index 29d7e11..e9c4bb4 100644
--- a/include/numpy_utils.h
+++ b/include/numpy_utils.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2021, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -37,6 +37,7 @@ public:
FILE_TYPE_MISMATCH,
HEADER_PARSE_ERROR,
BUFFER_SIZE_MISMATCH,
+ DATA_TYPE_NOT_SUPPORTED,
};
static NPError readFromNpyFile(const char* filename, const uint32_t elems, float* databuf);
@@ -45,14 +46,6 @@ public:
static NPError readFromNpyFile(const char* filename, const uint32_t elems, half_float::half* databuf);
- static NPError readFromNpyFile(const char* filename, const uint32_t elems, uint8_t* databuf);
-
- static NPError readFromNpyFile(const char* filename, const uint32_t elems, int8_t* databuf);
-
- static NPError readFromNpyFile(const char* filename, const uint32_t elems, uint16_t* databuf);
-
- static NPError readFromNpyFile(const char* filename, const uint32_t elems, int16_t* databuf);
-
static NPError readFromNpyFile(const char* filename, const uint32_t elems, int32_t* databuf);
static NPError readFromNpyFile(const char* filename, const uint32_t elems, int64_t* databuf);
@@ -112,6 +105,7 @@ private:
void* databuf,
bool bool_translate);
static NPError checkNpyHeader(FILE* infile, const uint32_t elems, const char* dtype_str);
+ static NPError getHeader(FILE* infile, bool& is_signed, int& bit_length, char& byte_order);
static NPError writeNpyHeader(FILE* outfile, const std::vector<int32_t>& shape, const char* dtype_str);
};
diff --git a/src/numpy_utils.cpp b/src/numpy_utils.cpp
index 65d76e3..d31ec1c 100644
--- a/src/numpy_utils.cpp
+++ b/src/numpy_utils.cpp
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2021, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -15,12 +15,14 @@
#include "numpy_utils.h"
#include "half.hpp"
-
+#include <algorithm>
// Magic NUMPY header
static const char NUMPY_HEADER_STR[] = "\x93NUMPY\x1\x0\x76\x0{";
static const int NUMPY_HEADER_SZ = 128;
// Maximum shape dimensions supported
static const int NUMPY_MAX_DIMS_SUPPORTED = 10;
+// Offset for NUMPY header desc dictionary string
+static const int NUMPY_HEADER_DESC_OFFSET = 8;
NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, bool* databuf)
{
@@ -28,34 +30,118 @@ NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, co
return readFromNpyFileCommon(filename, dtype_str, 1, elems, databuf, true);
}
-NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, uint8_t* databuf)
+NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int32_t* databuf) // Reads an int8/uint8/int16/uint16/int32 NPY file into the int32_t databuf, widening narrower elements one by one
{
- const char dtype_str[] = "'|u1'";
- return readFromNpyFileCommon(filename, dtype_str, sizeof(uint8_t), elems, databuf, false);
-}
+ const char dtype_str_uint8[] = "'|u1'"; // dtype strings handed to readFromNpyFileCommon, one per accepted on-disk type
+ const char dtype_str_int8[] = "'|i1'";
+ const char dtype_str_uint16[] = "'<u2'";
+ const char dtype_str_int16[] = "'<i2'";
+ const char dtype_str_int32[] = "'<i4'";
-NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int8_t* databuf)
-{
- const char dtype_str[] = "'|i1'";
- return readFromNpyFileCommon(filename, dtype_str, sizeof(int8_t), elems, databuf, false);
-}
+ FILE* infile = nullptr;
+ NPError rc = HEADER_PARSE_ERROR;
+ assert(filename);
+ assert(databuf);
-NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, uint16_t* databuf)
-{
- const char dtype_str[] = "'<u2'";
- return readFromNpyFileCommon(filename, dtype_str, sizeof(uint16_t), elems, databuf, false);
-}
+ infile = fopen(filename, "rb"); // NOTE(review): no fclose(infile) appears on any return path of this function - looks like a handle leak, confirm
+ if (!infile)
+ {
+ return FILE_NOT_FOUND;
+ }
-NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int16_t* databuf)
-{
- const char dtype_str[] = "'<i2'";
- return readFromNpyFileCommon(filename, dtype_str, sizeof(int16_t), elems, databuf, false);
-}
+ bool is_signed = false;
+ int bit_length;
+ char byte_order; // filled by getHeader but never read in this function
+ rc = getHeader(infile, is_signed, bit_length, byte_order); // decodes the dtype code of the file header (e.g. "<i4") into the three variables above
+ if (rc != NO_ERROR)
+ return rc;
-NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int32_t* databuf)
-{
- const char dtype_str[] = "'<i4'";
- return readFromNpyFileCommon(filename, dtype_str, sizeof(int32_t), elems, databuf, false);
+ switch (bit_length) // despite the name, this is the size digit of the dtype code: 1, 2 or 4
+ {
+ case 1: // 8-bit
+ if (is_signed)
+ {
+ // int8
+ int8_t* i8databuf = nullptr;
+ i8databuf = (int8_t*)calloc(sizeof(i8databuf), elems); // NOTE(review): sizeof(i8databuf) is the size of a pointer, not of int8_t, and the result is not null-checked
+
+ rc = readFromNpyFileCommon(filename, dtype_str_int8, sizeof(int8_t), elems, i8databuf, false); // takes the file name, not infile - presumably opens the file again; verify
+
+ for (unsigned i = 0; i < elems; ++i) // widening copy; runs regardless of the value of rc
+ {
+ databuf[i] = (int32_t)i8databuf[i];
+ }
+ free(i8databuf);
+
+ return rc;
+ }
+ else
+ {
+ // uint8
+ uint8_t* ui8databuf = nullptr;
+ ui8databuf = (uint8_t*)calloc(sizeof(ui8databuf), elems); // NOTE(review): same pointer-sized allocation without a null check as the int8 branch
+
+ rc = readFromNpyFileCommon(filename, dtype_str_uint8, sizeof(uint8_t), elems, ui8databuf, false);
+
+ for (unsigned i = 0; i < elems; ++i)
+ {
+ databuf[i] = (int32_t)ui8databuf[i];
+ }
+ free(ui8databuf);
+ }
+ break; // only the uint8 branch exits through here; every other branch returns directly
+ case 2: // 16-bit
+ if (is_signed)
+ {
+ // int16
+ int16_t* i16databuf = nullptr;
+ i16databuf = (int16_t*)calloc(sizeof(i16databuf), elems); // NOTE(review): same pointer-sized allocation without a null check as the int8 branch
+
+ rc = readFromNpyFileCommon(filename, dtype_str_int16, sizeof(int16_t), elems, i16databuf, false);
+
+ for (unsigned i = 0; i < elems; ++i)
+ {
+ databuf[i] = (int32_t)i16databuf[i];
+ }
+ free(i16databuf);
+
+ return rc;
+ }
+ else
+ {
+ // uint16
+ uint16_t* ui16databuf = nullptr;
+ ui16databuf = (uint16_t*)calloc(sizeof(ui16databuf), elems); // NOTE(review): same pointer-sized allocation without a null check as the int8 branch
+
+ rc = readFromNpyFileCommon(filename, dtype_str_uint16, sizeof(uint16_t), elems, ui16databuf, false);
+
+ for (unsigned i = 0; i < elems; ++i)
+ {
+ databuf[i] = (int32_t)ui16databuf[i];
+ }
+ free(ui16databuf);
+
+ return rc;
+ }
+ break;
+ case 4: // 32-bit
+ if (is_signed)
+ {
+ // int32
+ return readFromNpyFileCommon(filename, dtype_str_int32, sizeof(int32_t), elems, databuf, false); // element size already matches, so the data goes straight into databuf
+ }
+ else
+ {
+ // uint32, not supported
+ return DATA_TYPE_NOT_SUPPORTED;
+ }
+ break;
+ default:
+ return DATA_TYPE_NOT_SUPPORTED;
+ break;
+ }
+
+ return rc; // reached only from the uint8 branch above
}
NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int64_t* databuf)
@@ -139,6 +225,50 @@ NumpyUtilities::NPError NumpyUtilities::readFromNpyFileCommon(const char* filena
return rc;
}
+// Reads the NUMPY_HEADER_SZ-byte header from infile and decodes the three-character dtype code found
+// NUMPY_HEADER_DESC_OFFSET characters after "descr" (e.g. "<i4"). byte_order gets the first character,
+// is_signed is false only when the second is 'u', and bit_length gets the third as a digit (the 32-bit
+// reader treats 1/2/4 as 8/16/32-bit). The stream is rewound after a successful read. Returns
+// HEADER_PARSE_ERROR if the header cannot be read or holds no complete dtype code, else NO_ERROR.
+NumpyUtilities::NPError NumpyUtilities::getHeader(FILE* infile, bool& is_signed, int& bit_length, char& byte_order)
+{
+    char buf[NUMPY_HEADER_SZ + 1];
+    assert(infile);
+
+    if (fread(buf, NUMPY_HEADER_SZ, 1, infile) != 1)
+    {
+        return HEADER_PARSE_ERROR;
+    }
+    // fread() does not terminate the buffer; do it here so the std::string below cannot read past buf
+    buf[NUMPY_HEADER_SZ] = '\0';
+    rewind(infile);
+
+    // Skip the magic prefix (it ends with '{'); what follows is the header dictionary text
+    std::string dic_string(buf + sizeof(NUMPY_HEADER_STR) - 1);
+
+    // Reference: https://en.cppreference.com/w/cpp/algorithm/remove
+    // remove all the white spaces for the following offset NUMPY_HEADER_DESC_OFFSET to work
+    dic_string.erase(
+        std::remove_if(dic_string.begin(), dic_string.end(), [](unsigned char x) { return std::isspace(x); }),
+        dic_string.end());
+
+    // The key is looked up only after stripping, so its index refers to the stripped text
+    const auto descr_loc = dic_string.find("descr");
+    if (descr_loc == std::string::npos || dic_string.size() < descr_loc + NUMPY_HEADER_DESC_OFFSET + 3)
+    {
+        return HEADER_PARSE_ERROR;
+    }
+
+    // The stripped text reads descr':'<code>, so the code starts NUMPY_HEADER_DESC_OFFSET past the 'd'
+    const std::string::size_type dtype_loc = descr_loc + NUMPY_HEADER_DESC_OFFSET;
+
+    byte_order = dic_string[dtype_loc];
+    is_signed  = dic_string[dtype_loc + 1] != 'u';
+    bit_length = static_cast<int>(dic_string[dtype_loc + 2] - '0');
+
+    return NO_ERROR;
+}
+
NumpyUtilities::NPError NumpyUtilities::checkNpyHeader(FILE* infile, const uint32_t elems, const char* dtype_str)
{
char buf[NUMPY_HEADER_SZ + 1];