diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/numpy_utils.cpp | 16 | ||||
-rw-r--r-- | src/tosa_serialization_handler.cpp | 44 |
2 files changed, 60 insertions, 0 deletions
diff --git a/src/numpy_utils.cpp b/src/numpy_utils.cpp index 80c680f..c770d45 100644 --- a/src/numpy_utils.cpp +++ b/src/numpy_utils.cpp @@ -14,6 +14,7 @@ // limitations under the License. #include "numpy_utils.h" +#include "half.hpp" // Magic NUMPY header static const char NUMPY_HEADER_STR[] = "\x93NUMPY\x1\x0\x76\x0{"; @@ -45,6 +46,13 @@ NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, co return readFromNpyFileCommon(filename, dtype_str, sizeof(float), elems, databuf, false); } +NumpyUtilities::NPError + NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, half_float::half* databuf) +{ + const char dtype_str[] = "'<f2'"; + return readFromNpyFileCommon(filename, dtype_str, sizeof(half_float::half), elems, databuf, false); +} + NumpyUtilities::NPError NumpyUtilities::readFromNpyFileCommon(const char* filename, const char* dtype_str, const size_t elementsize, @@ -307,6 +315,14 @@ NumpyUtilities::NPError return writeToNpyFileCommon(filename, dtype_str, sizeof(float), shape, databuf, false); } +NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename, + const std::vector<int32_t>& shape, + const half_float::half* databuf) +{ + const char dtype_str[] = "'<f2'"; + return writeToNpyFileCommon(filename, dtype_str, sizeof(half_float::half), shape, databuf, false); +} + NumpyUtilities::NPError NumpyUtilities::writeToNpyFileCommon(const char* filename, const char* dtype_str, const size_t elementsize, diff --git a/src/tosa_serialization_handler.cpp b/src/tosa_serialization_handler.cpp index 3a0ce43..170b313 100644 --- a/src/tosa_serialization_handler.cpp +++ b/src/tosa_serialization_handler.cpp @@ -14,6 +14,7 @@ // limitations under the License. #include "tosa_serialization_handler.h" +#include "half.hpp" #include <iostream> using namespace tosa; @@ -652,6 +653,7 @@ tosa_err_t TosaSerializationHandler::Serialize() #define DEF_ARGS_S_float(NAME, V) DEF_ARGS_S_DEFAULT(NAME, V) #define DEF_ARGS_S_bool(NAME, V) DEF_ARGS_S_DEFAULT(NAME, V) #define DEF_ARGS_S_ResizeMode(NAME, V) DEF_ARGS_S_DEFAULT(NAME, V) +#define DEF_ARGS_S_DType(NAME, V) DEF_ARGS_S_DEFAULT(NAME, V) #define DEF_ARGS_S_string(NAME, V) DEF_ARGS_S_STR(NAME, V) #define DEF_ARGS_S(NAME, T, V) DEF_ARGS_S_##T(NAME, V) @@ -692,6 +694,7 @@ tosa_err_t TosaSerializationHandler::Serialize() #undef DEF_ARGS_S_float #undef DEF_ARGS_S_bool #undef DEF_ARGS_S_ResizeMode +#undef DEF_ARGS_S_DType #undef DEF_ARGS_S_string #undef DEF_ARGS_S_STR #undef DEF_ARGS_S_DEFAULT @@ -746,6 +749,21 @@ void zero_pad(std::vector<uint8_t>& buf) } } +tosa_err_t TosaSerializationHandler::ConvertF16toU8(const std::vector<float>& in, std::vector<uint8_t>& out) +{ + // Note: Converts fp32->fp16 before converting to uint8_t + out.clear(); + for (auto val : in) + { + half_float::half val_f16 = half_float::half_cast<half_float::half, float>(val); + uint16_t* val_u16 = reinterpret_cast<uint16_t*>(&val_f16); + out.push_back(*val_u16 & 0xFF); + out.push_back((*val_u16 >> 8) & 0xFF); + } + zero_pad(out); + return TOSA_OK; +} + tosa_err_t TosaSerializationHandler::ConvertF32toU8(const std::vector<float>& in, std::vector<uint8_t>& out) { out.clear(); @@ -862,6 +880,32 @@ tosa_err_t TosaSerializationHandler::ConvertBooltoU8(const std::vector<bool>& in } tosa_err_t + TosaSerializationHandler::ConvertU8toF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out) +{ + // Note: fp16 values returned in fp32 type + out.clear(); + if (in.size() < out_size * sizeof(int16_t)) + { + printf("TosaSerializationHandler::ConvertU8toF16(): uint8 buffer size %ld must >= target size %ld\n", in.size(), + out_size * sizeof(int16_t)); + return TOSA_USER_ERROR; + } + + for (uint32_t i = 0; i < out_size; i++) + { + uint16_t f16_byte0 = in[i * sizeof(int16_t)]; + uint16_t f16_byte1 = in[i * sizeof(int16_t) + 1]; + uint16_t val_u16 = f16_byte0 + (f16_byte1 << 8); + + // Reinterpret u16 byte as fp16 then convert to fp32 + half_float::half val_f16 = *(half_float::half*)&val_u16; + float val_fp32 = half_float::half_cast<float, half_float::half>(val_f16); + out.push_back(val_fp32); + } + return TOSA_OK; +} + +tosa_err_t TosaSerializationHandler::ConvertU8toF32(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out) { out.clear(); |