diff options
-rw-r--r-- | include/tosa_serialization_handler.h | 2 | ||||
-rw-r--r-- | src/tosa_serialization_handler.cpp | 67 |
2 files changed, 68 insertions, 1 deletions
diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h index db9481b..73254dd 100644 --- a/include/tosa_serialization_handler.h +++ b/include/tosa_serialization_handler.h @@ -291,6 +291,7 @@ public: static tosa_err_t ConvertI32toU8(const std::vector<int32_t>& in, std::vector<uint8_t>& out); static tosa_err_t ConvertI16toU8(const std::vector<int16_t>& in, std::vector<uint8_t>& out); static tosa_err_t ConvertI8toU8(const std::vector<int8_t>& in, std::vector<uint8_t>& out); + static tosa_err_t ConvertI4toU8(const std::vector<int8_t>& in, std::vector<uint8_t>& out); static tosa_err_t ConvertBooltoU8(const std::vector<bool>& in, std::vector<uint8_t>& out); static tosa_err_t ConvertU8toF32(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out); @@ -298,6 +299,7 @@ public: static tosa_err_t ConvertU8toI32(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int32_t>& out); static tosa_err_t ConvertU8toI16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int16_t>& out); static tosa_err_t ConvertU8toI8(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int8_t>& out); + static tosa_err_t ConvertU8toI4(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int8_t>& out); static tosa_err_t ConvertU8toBool(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<bool>& out); // version diff --git a/src/tosa_serialization_handler.cpp b/src/tosa_serialization_handler.cpp index d153dc5..4d69396 100644 --- a/src/tosa_serialization_handler.cpp +++ b/src/tosa_serialization_handler.cpp @@ -838,6 +838,37 @@ tosa_err_t TosaSerializationHandler::ConvertI8toU8(const std::vector<int8_t>& in return TOSA_OK; } +// Two int4 values are packed into one byte out. +// For given input value val_0 = in[2*i], and val_1 = in[2*i+1], +// they'll be packed as out[3:0] = val_0, and out[7:4] = val_1 +tosa_err_t TosaSerializationHandler::ConvertI4toU8(const std::vector<int8_t>& in, std::vector<uint8_t>& out) +{ + out.clear(); + uint32_t in_size = in.size(); + uint32_t out_size = (in_size % 2 == 0) ? (in_size / 2) : ((in_size + 1) / 2); + for (int i = 0; i < out_size; i++) + { + int8_t val_0 = in[2 * i]; + int8_t val_1 = 0; + if (2 * i + 1 < in_size) + { + val_1 = in[2 * i + 1]; + } + // In TOSA spec, int4 ranges [-7, 7] + if (val_0 < -7 || val_0 > 7 || val_1 < -7 || val_1 > 7) + { + printf("TosaSerializationHandler::ConvertI4toU8(): element in input array (%d or %d) exceeds int4 range.\n", + val_0, val_1); + return TOSA_USER_ERROR; + } + int8_t val_packed = (val_0 & 0xF) | ((val_1 & 0xF) << 4); + uint8_t val_u8 = static_cast<uint8_t>(val_packed); + out.push_back(val_u8); + } + zero_pad(out); + return TOSA_OK; +} + tosa_err_t TosaSerializationHandler::ConvertBooltoU8(const std::vector<bool>& in, std::vector<uint8_t>& out) { out.clear(); @@ -958,7 +989,7 @@ tosa_err_t if (in.size() < out_size * sizeof(int8_t)) { printf("TosaSerializationHandler::ConvertU8toI8(): uint8 buffer size %ld must >= target size %ld\n", in.size(), - out_size * sizeof(bool)); + out_size * sizeof(int8_t)); return TOSA_USER_ERROR; } for (int i = 0; i < out_size; i++) @@ -971,6 +1002,40 @@ tosa_err_t } tosa_err_t + TosaSerializationHandler::ConvertU8toI4(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int8_t>& out) +{ + out.clear(); + if (out_size > in.size() * 2) + { + printf("TosaSerializationHandler::ConvertU8toI4(): output size %u must <= uint8 buffer size %ld x 2.\n", + out_size, in.size()); + return TOSA_USER_ERROR; + } + for (int i = 0; i < in.size(); i++) + { + uint8_t val_u8 = in[i]; + uint8_t val_0_u4 = val_u8 & 0xF; + uint8_t val_1_u4 = val_u8 >> 4; + uint8_t val_0_u8_sext = (val_0_u4 & 0x08) ? (val_0_u4 | 0xF0) : val_0_u4; + uint8_t val_1_u8_sext = (val_1_u4 & 0x08) ? (val_1_u4 | 0xF0) : val_1_u4; + int8_t val_0 = static_cast<int8_t>(val_0_u8_sext); + int8_t val_1 = static_cast<int8_t>(val_1_u8_sext); + // In TOSA spec, int4 ranges [-7, 7] + if (val_0 < -7 || val_0 > 7 || val_1 < -7 || val_1 > 7) + { + printf( + "TosaSerializationHandler::ConvertU8toI4(): element in output array (%d or %d) exceeds int4 range.\n", + val_0, val_1); + return TOSA_USER_ERROR; + } + out.push_back(val_0); + if (2 * i + 1 < out_size) + out.push_back(val_1); + } + return TOSA_OK; +} + +tosa_err_t TosaSerializationHandler::ConvertU8toBool(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<bool>& out) { out.clear(); |