From 3ce563449c1e607b016b82c5dbb6e33883f846a5 Mon Sep 17 00:00:00 2001 From: Kevin Cheng Date: Wed, 28 Jul 2021 13:42:29 -0700 Subject: Support I4 weights pack/unpack. Signed-off-by: Kevin Cheng Change-Id: Ia7d2bfaa473c8a92c71f075c86aca0a275245f83 --- include/tosa_serialization_handler.h | 2 ++ src/tosa_serialization_handler.cpp | 67 +++++++++++++++++++++++++++++++++++- 2 files changed, 68 insertions(+), 1 deletion(-) diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h index db9481b..73254dd 100644 --- a/include/tosa_serialization_handler.h +++ b/include/tosa_serialization_handler.h @@ -291,6 +291,7 @@ public: static tosa_err_t ConvertI32toU8(const std::vector& in, std::vector& out); static tosa_err_t ConvertI16toU8(const std::vector& in, std::vector& out); static tosa_err_t ConvertI8toU8(const std::vector& in, std::vector& out); + static tosa_err_t ConvertI4toU8(const std::vector& in, std::vector& out); static tosa_err_t ConvertBooltoU8(const std::vector& in, std::vector& out); static tosa_err_t ConvertU8toF32(const std::vector& in, uint32_t out_size, std::vector& out); @@ -298,6 +299,7 @@ public: static tosa_err_t ConvertU8toI32(const std::vector& in, uint32_t out_size, std::vector& out); static tosa_err_t ConvertU8toI16(const std::vector& in, uint32_t out_size, std::vector& out); static tosa_err_t ConvertU8toI8(const std::vector& in, uint32_t out_size, std::vector& out); + static tosa_err_t ConvertU8toI4(const std::vector& in, uint32_t out_size, std::vector& out); static tosa_err_t ConvertU8toBool(const std::vector& in, uint32_t out_size, std::vector& out); // version diff --git a/src/tosa_serialization_handler.cpp b/src/tosa_serialization_handler.cpp index d153dc5..4d69396 100644 --- a/src/tosa_serialization_handler.cpp +++ b/src/tosa_serialization_handler.cpp @@ -838,6 +838,37 @@ tosa_err_t TosaSerializationHandler::ConvertI8toU8(const std::vector& in return TOSA_OK; } +// Two int4 values are packed into one byte out. +// For given input value val_0 = in[2*i], and val_1 = in[2*i+1], +// they'll be packed as out[3:0] = val_0, and out[7:4] = val_1 +tosa_err_t TosaSerializationHandler::ConvertI4toU8(const std::vector& in, std::vector& out) +{ + out.clear(); + uint32_t in_size = in.size(); + uint32_t out_size = (in_size % 2 == 0) ? (in_size / 2) : ((in_size + 1) / 2); + for (int i = 0; i < out_size; i++) + { + int8_t val_0 = in[2 * i]; + int8_t val_1 = 0; + if (2 * i + 1 < in_size) + { + val_1 = in[2 * i + 1]; + } + // In TOSA spec, int4 ranges [-7, 7] + if (val_0 < -7 || val_0 > 7 || val_1 < -7 || val_1 > 7) + { + printf("TosaSerializationHandler::ConvertI4toU8(): element in input array (%d or %d) exceeds int4 range.\n", + val_0, val_1); + return TOSA_USER_ERROR; + } + int8_t val_packed = (val_0 & 0xF) | ((val_1 & 0xF) << 4); + uint8_t val_u8 = static_cast(val_packed); + out.push_back(val_u8); + } + zero_pad(out); + return TOSA_OK; +} + tosa_err_t TosaSerializationHandler::ConvertBooltoU8(const std::vector& in, std::vector& out) { out.clear(); @@ -958,7 +989,7 @@ tosa_err_t if (in.size() < out_size * sizeof(int8_t)) { printf("TosaSerializationHandler::ConvertU8toI8(): uint8 buffer size %ld must >= target size %ld\n", in.size(), - out_size * sizeof(bool)); + out_size * sizeof(int8_t)); return TOSA_USER_ERROR; } for (int i = 0; i < out_size; i++) @@ -970,6 +1001,40 @@ tosa_err_t return TOSA_OK; } +tosa_err_t + TosaSerializationHandler::ConvertU8toI4(const std::vector& in, uint32_t out_size, std::vector& out) +{ + out.clear(); + if (out_size > in.size() * 2) + { + printf("TosaSerializationHandler::ConvertU8toI4(): output size %u must <= uint8 buffer size %ld x 2.\n", + out_size, in.size()); + return TOSA_USER_ERROR; + } + for (int i = 0; i < in.size(); i++) + { + uint8_t val_u8 = in[i]; + uint8_t val_0_u4 = val_u8 & 0xF; + uint8_t val_1_u4 = val_u8 >> 4; + uint8_t val_0_u8_sext = (val_0_u4 & 0x08) ? (val_0_u4 | 0xF0) : val_0_u4; + uint8_t val_1_u8_sext = (val_1_u4 & 0x08) ? (val_1_u4 | 0xF0) : val_1_u4; + int8_t val_0 = static_cast(val_0_u8_sext); + int8_t val_1 = static_cast(val_1_u8_sext); + // In TOSA spec, int4 ranges [-7, 7] + if (val_0 < -7 || val_0 > 7 || val_1 < -7 || val_1 > 7) + { + printf( + "TosaSerializationHandler::ConvertU8toI4(): element in output array (%d or %d) exceeds int4 range.\n", + val_0, val_1); + return TOSA_USER_ERROR; + } + out.push_back(val_0); + if (2 * i + 1 < out_size) + out.push_back(val_1); + } + return TOSA_OK; +} + tosa_err_t TosaSerializationHandler::ConvertU8toBool(const std::vector& in, uint32_t out_size, std::vector& out) { -- cgit v1.2.1