aboutsummaryrefslogtreecommitdiff
path: root/include/numpy_utils.h
diff options
context:
space:
mode:
Diffstat (limited to 'include/numpy_utils.h')
-rw-r--r--include/numpy_utils.h104
1 files changed, 84 insertions, 20 deletions
diff --git a/include/numpy_utils.h b/include/numpy_utils.h
index c64bc17..60cf77e 100644
--- a/include/numpy_utils.h
+++ b/include/numpy_utils.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2021, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -24,6 +24,8 @@
#include <cstring>
#include <vector>
+#include "half.hpp"
+
class NumpyUtilities
{
public:
@@ -35,31 +37,89 @@ public:
FILE_TYPE_MISMATCH,
HEADER_PARSE_ERROR,
BUFFER_SIZE_MISMATCH,
+ DATA_TYPE_NOT_SUPPORTED,
};
- static NPError readFromNpyFile(const char* filename, const uint32_t elems, float* databuf);
-
- static NPError readFromNpyFile(const char* filename, const uint32_t elems, int32_t* databuf);
-
- static NPError readFromNpyFile(const char* filename, const uint32_t elems, int64_t* databuf);
-
- static NPError readFromNpyFile(const char* filename, const uint32_t elems, bool* databuf);
-
- static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const bool* databuf);
-
- static NPError writeToNpyFile(const char* filename, const uint32_t elems, const bool* databuf);
-
- static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int32_t* databuf);
-
- static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int32_t* databuf);
+ template <typename T>
+ static const char* getDTypeString(bool& is_bool)
+ {
+ is_bool = false;
+ if (std::is_same<T, bool>::value)
+ {
+ is_bool = true;
+ return "'|b1'";
+ }
+ if (std::is_same<T, uint8_t>::value)
+ {
+ return "'|u1'";
+ }
+ if (std::is_same<T, int8_t>::value)
+ {
+ return "'|i1'";
+ }
+ if (std::is_same<T, uint16_t>::value)
+ {
+ return "'<u2'";
+ }
+ if (std::is_same<T, int16_t>::value)
+ {
+ return "'<i2'";
+ }
+ if (std::is_same<T, int32_t>::value)
+ {
+ return "'<i4'";
+ }
+ if (std::is_same<T, int64_t>::value)
+ {
+ return "'<i8'";
+ }
+ if (std::is_same<T, float>::value)
+ {
+ return "'<f4'";
+ }
+ if (std::is_same<T, double>::value)
+ {
+ return "'<f8'";
+ }
+ if (std::is_same<T, half_float::half>::value)
+ {
+ return "'<f2'";
+ }
+ assert(false && "unsupported Dtype");
+ };
- static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int64_t* databuf);
+ template <typename T>
+ static NPError writeToNpyFile(const char* filename, const uint32_t elems, const T* databuf)
+ {
+ std::vector<int32_t> shape = { static_cast<int32_t>(elems) };
+ return writeToNpyFile(filename, shape, databuf);
+ }
- static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int64_t* databuf);
+ template <typename T>
+ static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const T* databuf)
+ {
+ bool is_bool;
+ const char* dtype_str = getDTypeString<T>(is_bool);
+ return writeToNpyFileCommon(filename, dtype_str, sizeof(T), shape, databuf, is_bool);
+ }
- static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const float* databuf);
+ template <typename T>
+ static NPError readFromNpyFile(const char* filename, const uint32_t elems, T* databuf)
+ {
+ bool is_bool;
+ const char* dtype_str = getDTypeString<T>(is_bool);
+ return readFromNpyFileCommon(filename, dtype_str, sizeof(T), elems, databuf, is_bool);
+ }
- static NPError writeToNpyFile(const char* filename, const uint32_t elems, const float* databuf);
+ template <typename D, typename S>
+ static void copyBufferByElement(D* dest_buf, S* src_buf, int num)
+ {
+ static_assert(sizeof(D) >= sizeof(S), "The size of dest_buf must be equal to or larger than that of src_buf");
+ for (int i = 0; i < num; ++i)
+ {
+ dest_buf[i] = src_buf[i];
+ }
+ }
private:
static NPError writeToNpyFileCommon(const char* filename,
@@ -75,7 +135,11 @@ private:
void* databuf,
bool bool_translate);
static NPError checkNpyHeader(FILE* infile, const uint32_t elems, const char* dtype_str);
+ static NPError getHeader(FILE* infile, bool& is_signed, int& bit_length, char& byte_order);
static NPError writeNpyHeader(FILE* outfile, const std::vector<int32_t>& shape, const char* dtype_str);
};
+template <>
+NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int32_t* databuf);
+
#endif // _TOSA_NUMPY_UTILS_H