aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKevin Cheng <kevin.cheng@arm.com>2021-07-28 13:42:29 -0700
committerKevin Cheng <kevin.cheng@arm.com>2021-08-02 11:11:08 -0700
commit3ce563449c1e607b016b82c5dbb6e33883f846a5 (patch)
tree1af52ab625cebe46632c27631005d423d79b7529
parent82dbb32980a58889bef28b7ad653c30694364195 (diff)
downloadserialization_lib-3ce563449c1e607b016b82c5dbb6e33883f846a5.tar.gz
Support I4 weights pack/unpack.
Signed-off-by: Kevin Cheng <kevin.cheng@arm.com> Change-Id: Ia7d2bfaa473c8a92c71f075c86aca0a275245f83
-rw-r--r--include/tosa_serialization_handler.h2
-rw-r--r--src/tosa_serialization_handler.cpp67
2 files changed, 68 insertions, 1 deletions
diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h
index db9481b..73254dd 100644
--- a/include/tosa_serialization_handler.h
+++ b/include/tosa_serialization_handler.h
@@ -291,6 +291,7 @@ public:
static tosa_err_t ConvertI32toU8(const std::vector<int32_t>& in, std::vector<uint8_t>& out);
static tosa_err_t ConvertI16toU8(const std::vector<int16_t>& in, std::vector<uint8_t>& out);
static tosa_err_t ConvertI8toU8(const std::vector<int8_t>& in, std::vector<uint8_t>& out);
+ static tosa_err_t ConvertI4toU8(const std::vector<int8_t>& in, std::vector<uint8_t>& out);
static tosa_err_t ConvertBooltoU8(const std::vector<bool>& in, std::vector<uint8_t>& out);
static tosa_err_t ConvertU8toF32(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out);
@@ -298,6 +299,7 @@ public:
static tosa_err_t ConvertU8toI32(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int32_t>& out);
static tosa_err_t ConvertU8toI16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int16_t>& out);
static tosa_err_t ConvertU8toI8(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int8_t>& out);
+ static tosa_err_t ConvertU8toI4(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int8_t>& out);
static tosa_err_t ConvertU8toBool(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<bool>& out);
// version
diff --git a/src/tosa_serialization_handler.cpp b/src/tosa_serialization_handler.cpp
index d153dc5..4d69396 100644
--- a/src/tosa_serialization_handler.cpp
+++ b/src/tosa_serialization_handler.cpp
@@ -838,6 +838,37 @@ tosa_err_t TosaSerializationHandler::ConvertI8toU8(const std::vector<int8_t>& in
return TOSA_OK;
}
+// Two int4 values are packed into one byte out.
+// For given input value val_0 = in[2*i], and val_1 = in[2*i+1],
+// they'll be packed as out[3:0] = val_0, and out[7:4] = val_1
+tosa_err_t TosaSerializationHandler::ConvertI4toU8(const std::vector<int8_t>& in, std::vector<uint8_t>& out)
+{
+ out.clear();
+ uint32_t in_size = in.size();
+ uint32_t out_size = (in_size % 2 == 0) ? (in_size / 2) : ((in_size + 1) / 2);
+ for (int i = 0; i < out_size; i++)
+ {
+ int8_t val_0 = in[2 * i];
+ int8_t val_1 = 0;
+ if (2 * i + 1 < in_size)
+ {
+ val_1 = in[2 * i + 1];
+ }
+ // In TOSA spec, int4 ranges [-7, 7]
+ if (val_0 < -7 || val_0 > 7 || val_1 < -7 || val_1 > 7)
+ {
+ printf("TosaSerializationHandler::ConvertI4toU8(): element in input array (%d or %d) exceeds int4 range.\n",
+ val_0, val_1);
+ return TOSA_USER_ERROR;
+ }
+ int8_t val_packed = (val_0 & 0xF) | ((val_1 & 0xF) << 4);
+ uint8_t val_u8 = static_cast<uint8_t>(val_packed);
+ out.push_back(val_u8);
+ }
+ zero_pad(out);
+ return TOSA_OK;
+}
+
tosa_err_t TosaSerializationHandler::ConvertBooltoU8(const std::vector<bool>& in, std::vector<uint8_t>& out)
{
out.clear();
@@ -958,7 +989,7 @@ tosa_err_t
if (in.size() < out_size * sizeof(int8_t))
{
printf("TosaSerializationHandler::ConvertU8toI8(): uint8 buffer size %ld must >= target size %ld\n", in.size(),
- out_size * sizeof(bool));
+ out_size * sizeof(int8_t));
return TOSA_USER_ERROR;
}
for (int i = 0; i < out_size; i++)
@@ -971,6 +1002,40 @@ tosa_err_t
}
tosa_err_t
+ TosaSerializationHandler::ConvertU8toI4(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int8_t>& out)
+{
+ out.clear();
+ if (out_size > in.size() * 2)
+ {
+ printf("TosaSerializationHandler::ConvertU8toI4(): output size %u must <= uint8 buffer size %ld x 2.\n",
+ out_size, in.size());
+ return TOSA_USER_ERROR;
+ }
+ for (int i = 0; i < in.size(); i++)
+ {
+ uint8_t val_u8 = in[i];
+ uint8_t val_0_u4 = val_u8 & 0xF;
+ uint8_t val_1_u4 = val_u8 >> 4;
+ uint8_t val_0_u8_sext = (val_0_u4 & 0x08) ? (val_0_u4 | 0xF0) : val_0_u4;
+ uint8_t val_1_u8_sext = (val_1_u4 & 0x08) ? (val_1_u4 | 0xF0) : val_1_u4;
+ int8_t val_0 = static_cast<int8_t>(val_0_u8_sext);
+ int8_t val_1 = static_cast<int8_t>(val_1_u8_sext);
+ // In TOSA spec, int4 ranges [-7, 7]
+ if (val_0 < -7 || val_0 > 7 || val_1 < -7 || val_1 > 7)
+ {
+ printf(
+ "TosaSerializationHandler::ConvertU8toI4(): element in output array (%d or %d) exceeds int4 range.\n",
+ val_0, val_1);
+ return TOSA_USER_ERROR;
+ }
+ out.push_back(val_0);
+ if (2 * i + 1 < out_size)
+ out.push_back(val_1);
+ }
+ return TOSA_OK;
+}
+
+tosa_err_t
TosaSerializationHandler::ConvertU8toBool(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<bool>& out)
{
out.clear();