diff options
Diffstat (limited to 'src/numpy_utils.cpp')
-rw-r--r-- | src/numpy_utils.cpp | 197 |
1 files changed, 119 insertions, 78 deletions
diff --git a/src/numpy_utils.cpp b/src/numpy_utils.cpp index 80c680f..e4171d7 100644 --- a/src/numpy_utils.cpp +++ b/src/numpy_utils.cpp @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2021, ARM Limited. +// Copyright (c) 2020-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,6 +14,9 @@ // limitations under the License. #include "numpy_utils.h" +#include "half.hpp" +#include <algorithm> +#include <memory> // Magic NUMPY header static const char NUMPY_HEADER_STR[] = "\x93NUMPY\x1\x0\x76\x0{"; @@ -21,28 +24,81 @@ static const int NUMPY_HEADER_SZ = 128; // Maximum shape dimensions supported static const int NUMPY_MAX_DIMS_SUPPORTED = 10; -NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, bool* databuf) -{ - const char dtype_str[] = "'|b1'"; - return readFromNpyFileCommon(filename, dtype_str, 1, elems, databuf, true); -} - +// This is an entry function for reading 8-/16-/32-bit npy file. +template <> NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int32_t* databuf) { - const char dtype_str[] = "'<i4'"; - return readFromNpyFileCommon(filename, dtype_str, sizeof(int32_t), elems, databuf, false); -} + FILE* infile = nullptr; + NPError rc = HEADER_PARSE_ERROR; + assert(filename); + assert(databuf); -NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int64_t* databuf) -{ - const char dtype_str[] = "'<i8'"; - return readFromNpyFileCommon(filename, dtype_str, sizeof(int64_t), elems, databuf, false); -} + infile = fopen(filename, "rb"); + if (!infile) + { + return FILE_NOT_FOUND; + } -NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, float* databuf) -{ - const char dtype_str[] = "'<f4'"; - return readFromNpyFileCommon(filename, dtype_str, sizeof(float), elems, databuf, false); + bool is_signed = false; + int length_per_byte = 0; + char byte_order; + rc = getHeader(infile, is_signed, length_per_byte, byte_order); + if (rc != NO_ERROR) + return rc; + + switch (length_per_byte) + { + case 1: + if (is_signed) + { + int8_t* tmp_buf = new int8_t[elems]; + rc = readFromNpyFile<int8_t>(filename, elems, tmp_buf); + copyBufferByElement(databuf, tmp_buf, elems); + delete[] tmp_buf; + } + else + { + uint8_t* tmp_buf = new uint8_t[elems]; + rc = readFromNpyFile<uint8_t>(filename, elems, tmp_buf); + copyBufferByElement(databuf, tmp_buf, elems); + delete[] tmp_buf; + } + break; + case 2: + if (is_signed) + { + int16_t* tmp_buf = new int16_t[elems]; + rc = readFromNpyFile<int16_t>(filename, elems, tmp_buf); + copyBufferByElement(databuf, tmp_buf, elems); + delete[] tmp_buf; + } + else + { + uint16_t* tmp_buf = new uint16_t[elems]; + rc = readFromNpyFile<uint16_t>(filename, elems, tmp_buf); + copyBufferByElement(databuf, tmp_buf, elems); + delete[] tmp_buf; + } + break; + case 4: + if (is_signed) + { + bool is_bool; + const char* dtype_str = getDTypeString<int32_t>(is_bool); + rc = readFromNpyFileCommon(filename, dtype_str, sizeof(int32_t), elems, databuf, is_bool); + } + else + { + // uint32, not supported + rc = DATA_TYPE_NOT_SUPPORTED; + } + break; + default: + return DATA_TYPE_NOT_SUPPORTED; + break; + } + + return rc; } NumpyUtilities::NPError NumpyUtilities::readFromNpyFileCommon(const char* filename, @@ -101,6 +157,46 @@ NumpyUtilities::NPError NumpyUtilities::readFromNpyFileCommon(const char* filena return rc; } +NumpyUtilities::NPError NumpyUtilities::getHeader(FILE* infile, bool& is_signed, int& bit_length, char& byte_order) +{ + char buf[NUMPY_HEADER_SZ + 1]; + NPError rc = NO_ERROR; + assert(infile); + + if (fread(buf, NUMPY_HEADER_SZ, 1, infile) != 1) + { + return HEADER_PARSE_ERROR; + } + + // Validate the numpy magic number + if (memcmp(buf, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1)) + { + return HEADER_PARSE_ERROR; + } + + std::string dic_string(buf, NUMPY_HEADER_SZ); + + std::string desc_str("descr':"); + size_t offset = dic_string.find(desc_str); + if (offset == std::string::npos) + return HEADER_PARSE_ERROR; + + offset += desc_str.size() + 1; + // Skip whitespace and the opening ' + while (offset < dic_string.size() && (std::isspace(dic_string[offset]) || dic_string[offset] == '\'')) + offset++; + // Check for overflow + if (offset + 2 > dic_string.size()) + return HEADER_PARSE_ERROR; + + byte_order = dic_string[offset]; + is_signed = dic_string[offset + 1] == 'u' ? false : true; + bit_length = (int)dic_string[offset + 2] - '0'; + + rewind(infile); + return rc; +} + NumpyUtilities::NPError NumpyUtilities::checkNpyHeader(FILE* infile, const uint32_t elems, const char* dtype_str) { char buf[NUMPY_HEADER_SZ + 1]; @@ -253,60 +349,6 @@ NumpyUtilities::NPError NumpyUtilities::checkNpyHeader(FILE* infile, const uint3 return rc; } -NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const bool* databuf) -{ - std::vector<int32_t> shape = { (int32_t)elems }; - return writeToNpyFile(filename, shape, databuf); -} - -NumpyUtilities::NPError - NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const bool* databuf) -{ - const char dtype_str[] = "'|b1'"; - return writeToNpyFileCommon(filename, dtype_str, 1, shape, databuf, true); // bools written as size 1 -} - -NumpyUtilities::NPError - NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int32_t* databuf) -{ - std::vector<int32_t> shape = { (int32_t)elems }; - return writeToNpyFile(filename, shape, databuf); -} - -NumpyUtilities::NPError - NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int32_t* databuf) -{ - const char dtype_str[] = "'<i4'"; - return writeToNpyFileCommon(filename, dtype_str, sizeof(int32_t), shape, databuf, false); -} - -NumpyUtilities::NPError - NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int64_t* databuf) -{ - std::vector<int32_t> shape = { (int32_t)elems }; - return writeToNpyFile(filename, shape, databuf); -} - -NumpyUtilities::NPError - NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int64_t* databuf) -{ - const char dtype_str[] = "'<i8'"; - return writeToNpyFileCommon(filename, dtype_str, sizeof(int64_t), shape, databuf, false); -} - -NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const float* databuf) -{ - std::vector<int32_t> shape = { (int32_t)elems }; - return writeToNpyFile(filename, shape, databuf); -} - -NumpyUtilities::NPError - NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const float* databuf) -{ - const char dtype_str[] = "'<f4'"; - return writeToNpyFileCommon(filename, dtype_str, sizeof(float), shape, databuf, false); -} - NumpyUtilities::NPError NumpyUtilities::writeToNpyFileCommon(const char* filename, const char* dtype_str, const size_t elementsize, @@ -390,12 +432,11 @@ NumpyUtilities::NPError // Output the format dictionary // Hard-coded for I32 for now - headerPos += - snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "'descr': %s, 'fortran_order': False, 'shape': (%d,", - dtype_str, shape.empty() ? 1 : shape[0]); + headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, + "'descr': %s, 'fortran_order': False, 'shape': (", dtype_str); - // Remainder of shape array - for (i = 1; i < shape.size(); i++) + // Add shape contents (if any - as this will be empty for rank 0) + for (i = 0; i < shape.size(); i++) { headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, " %d,", shape[i]); } |