diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/numpy_utils.cpp | 180 |
1 files changed, 155 insertions, 25 deletions
diff --git a/src/numpy_utils.cpp b/src/numpy_utils.cpp index 65d76e3..d31ec1c 100644 --- a/src/numpy_utils.cpp +++ b/src/numpy_utils.cpp @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2021, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,12 +15,14 @@ #include "numpy_utils.h" #include "half.hpp" - +#include <algorithm> // Magic NUMPY header static const char NUMPY_HEADER_STR[] = "\x93NUMPY\x1\x0\x76\x0{"; static const int NUMPY_HEADER_SZ = 128; // Maximum shape dimensions supported static const int NUMPY_MAX_DIMS_SUPPORTED = 10; +// Offset for NUMPY header desc dictionary string +static const int NUMPY_HEADER_DESC_OFFSET = 8; NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, bool* databuf) { @@ -28,34 +30,118 @@ NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, co return readFromNpyFileCommon(filename, dtype_str, 1, elems, databuf, true); } -NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, uint8_t* databuf) +NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int32_t* databuf) { - const char dtype_str[] = "'|u1'"; - return readFromNpyFileCommon(filename, dtype_str, sizeof(uint8_t), elems, databuf, false); -} + const char dtype_str_uint8[] = "'|u1'"; + const char dtype_str_int8[] = "'|i1'"; + const char dtype_str_uint16[] = "'<u2'"; + const char dtype_str_int16[] = "'<i2'"; + const char dtype_str_int32[] = "'<i4'"; -NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int8_t* databuf) -{ - const char dtype_str[] = "'|i1'"; - return readFromNpyFileCommon(filename, dtype_str, sizeof(int8_t), elems, databuf, false); -} + FILE* infile = nullptr; + NPError rc = HEADER_PARSE_ERROR; + assert(filename); + assert(databuf); -NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, uint16_t* databuf) -{ - const char dtype_str[] = "'<u2'"; - return readFromNpyFileCommon(filename, dtype_str, sizeof(uint16_t), elems, databuf, false); -} + infile = fopen(filename, "rb"); + if (!infile) + { + return FILE_NOT_FOUND; + } -NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int16_t* databuf) -{ - const char dtype_str[] = "'<i2'"; - return readFromNpyFileCommon(filename, dtype_str, sizeof(int16_t), elems, databuf, false); -} + bool is_signed = false; + int bit_length; + char byte_order; + rc = getHeader(infile, is_signed, bit_length, byte_order); + if (rc != NO_ERROR) + return rc; -NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int32_t* databuf) -{ - const char dtype_str[] = "'<i4'"; - return readFromNpyFileCommon(filename, dtype_str, sizeof(int32_t), elems, databuf, false); + switch (bit_length) + { + case 1: // 8-bit + if (is_signed) + { + // int8 + int8_t* i8databuf = nullptr; + i8databuf = (int8_t*)calloc(sizeof(i8databuf), elems); + + rc = readFromNpyFileCommon(filename, dtype_str_int8, sizeof(int8_t), elems, i8databuf, false); + + for (unsigned i = 0; i < elems; ++i) + { + databuf[i] = (int32_t)i8databuf[i]; + } + free(i8databuf); + + return rc; + } + else + { + // uint8 + uint8_t* ui8databuf = nullptr; + ui8databuf = (uint8_t*)calloc(sizeof(ui8databuf), elems); + + rc = readFromNpyFileCommon(filename, dtype_str_uint8, sizeof(uint8_t), elems, ui8databuf, false); + + for (unsigned i = 0; i < elems; ++i) + { + databuf[i] = (int32_t)ui8databuf[i]; + } + free(ui8databuf); + } + break; + case 2: // 16-bit + if (is_signed) + { + // int16 + int16_t* i16databuf = nullptr; + i16databuf = (int16_t*)calloc(sizeof(i16databuf), elems); + + rc = readFromNpyFileCommon(filename, dtype_str_int16, sizeof(int16_t), elems, i16databuf, false); + + for (unsigned i = 0; i < elems; ++i) + { + databuf[i] = (int32_t)i16databuf[i]; + } + free(i16databuf); + + return rc; + } + else + { + // uint16 + uint16_t* ui16databuf = nullptr; + ui16databuf = (uint16_t*)calloc(sizeof(ui16databuf), elems); + + rc = readFromNpyFileCommon(filename, dtype_str_uint16, sizeof(uint16_t), elems, ui16databuf, false); + + for (unsigned i = 0; i < elems; ++i) + { + databuf[i] = (int32_t)ui16databuf[i]; + } + free(ui16databuf); + + return rc; + } + break; + case 4: // 32-bit + if (is_signed) + { + // int32 + return readFromNpyFileCommon(filename, dtype_str_int32, sizeof(int32_t), elems, databuf, false); + } + else + { + // uint32, not supported + return DATA_TYPE_NOT_SUPPORTED; + } + break; + default: + return DATA_TYPE_NOT_SUPPORTED; + break; + } + + return rc; } NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int64_t* databuf) @@ -139,6 +225,50 @@ NumpyUtilities::NPError NumpyUtilities::readFromNpyFileCommon(const char* filena return rc; } +NumpyUtilities::NPError NumpyUtilities::getHeader(FILE* infile, bool& is_signed, int& bit_length, char& byte_order) +{ + char buf[NUMPY_HEADER_SZ + 1]; + NPError rc = NO_ERROR; + assert(infile); + + if (fread(buf, NUMPY_HEADER_SZ, 1, infile) != 1) + { + return HEADER_PARSE_ERROR; + } + char* ptr; + ptr = buf + sizeof(NUMPY_HEADER_STR) - 1; + + std::string dic_string(ptr); + auto descr_loc = dic_string.find("descr"); + + // Reference: https://en.cppreference.com/w/cpp/algorithm/remove + // remove all the white spaces for the following offset NUMPY_HEADER_DESC_OFFSET to work + dic_string.erase( + std::remove_if(dic_string.begin(), dic_string.end(), [](unsigned char x) { return std::isspace(x); }), + dic_string.end()); + // The dic_string is constant: descr': ', add a offset of NUMPY_HEADER_DESC_OFFSET + // to the actual dtype string station + dic_string = dic_string.substr(descr_loc + NUMPY_HEADER_DESC_OFFSET, 3); + + // Fill byte_order; + char byte_order_c[1]; + strcpy(byte_order_c, dic_string.substr(0, 1).c_str()); + byte_order = byte_order_c[0]; + + // Fill is_signed + char is_signed_c[1]; + strcpy(is_signed_c, dic_string.substr(1, 1).c_str()); + is_signed = is_signed_c[0] == 'u' ? false : true; + + // Fill bit_length + char bit_length_c[1]; + strcpy(bit_length_c, dic_string.substr(2, 1).c_str()); + bit_length = (int)(bit_length_c[0] - '0'); + + rewind(infile); + return rc; +} + NumpyUtilities::NPError NumpyUtilities::checkNpyHeader(FILE* infile, const uint32_t elems, const char* dtype_str) { char buf[NUMPY_HEADER_SZ + 1]; |