aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKevin Cheng <kevin.cheng@arm.com>2021-07-28 17:19:23 -0700
committerKevin Cheng <kevin.cheng@arm.com>2021-08-12 15:27:03 -0700
commita9017401461224b9bc81e7b1c770ca6091e0e3fb (patch)
treea9eba569a9ed3862f459b3190787cdbac57da85b
parent7ffccce9b9a7eeecbad0c0525f1545c2714f8556 (diff)
downloadreference_model-a9017401461224b9bc81e7b1c770ca6091e0e3fb.tar.gz
Support int4 weights read. Added conv2d int8xint4 in test generation.
Signed-off-by: Kevin Cheng <kevin.cheng@arm.com> Change-Id: I61620f160c7dad6aac5fcc3da0a6e97f3bae5b40
-rw-r--r--reference_model/src/subgraph_traverser.cc9
m---------thirdparty/serialization_lib0
-rw-r--r--verif/tosa_serializer.py12
-rw-r--r--verif/tosa_test_gen.py7
4 files changed, 25 insertions, 3 deletions
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc
index d64cb38..bdf6fbc 100644
--- a/reference_model/src/subgraph_traverser.cc
+++ b/reference_model/src/subgraph_traverser.cc
@@ -225,7 +225,6 @@ int SubgraphTraverser::initializeGraph()
DEBUG_INFO(GT, "Creating tensor %s", ts->GetName().c_str());
TosaReference::Tensor* tensor =
TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetShape(), ts->GetShape().size());
-
if (!ts->GetData().empty())
{
if (tensor->allocate())
@@ -236,6 +235,14 @@ int SubgraphTraverser::initializeGraph()
switch (ts->GetDtype())
{
+ case DType_INT4:
+ {
+ std::vector<int8_t> i4_data;
+ TosaSerializationHandler::ConvertU8toI4(ts->GetData(), tensor->getElementCount(), i4_data);
+ std::vector<int32_t> i32_data(i4_data.begin(), i4_data.end());
+ tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
+ }
+ break;
case DType_INT8:
{
std::vector<int8_t> i8_data;
diff --git a/thirdparty/serialization_lib b/thirdparty/serialization_lib
-Subproject 82dbb32980a58889bef28b7ad653c3069436419
+Subproject 3ce563449c1e607b016b82c5dbb6e33883f846a
diff --git a/verif/tosa_serializer.py b/verif/tosa_serializer.py
index c4de2a2..b4daaad 100644
--- a/verif/tosa_serializer.py
+++ b/verif/tosa_serializer.py
@@ -386,6 +386,18 @@ class TosaSerializerTensor:
for val in self.data:
val_u8 = np.uint8(val)
u8_data.append(val_u8)
+ elif self.dtype == DType.INT4:
+ in_size = len(self.data)
+ out_size = (in_size + 1) // 2
+ for i in range(out_size):
+ val_0 = self.data[2 * i]
+ if (2 * i + 1) < in_size:
+ val_1 = self.data[2 * i + 1]
+ else:
+ val_1 = 0
+ val_i8 = (val_0 & 0xF) | ((val_1 & 0xF) << 4)
+ val_u8 = np.uint8(val_i8)
+ u8_data.append(val_u8)
elif self.dtype == DType.INT8:
for val in self.data:
val_u8 = np.uint8(val)
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index e375a2a..5c25f8e 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -923,8 +923,9 @@ class TosaTestGen:
if dtype == DType.BOOL:
np_dt = np.bool
return np.bool_(self.rng.choice(a=[False, True], size=shape))
+ # TOSA specific INT4 weight range from -7 to 7
elif dtype == DType.INT4:
- return np.int32(self.rng.integers(low=-8, high=8, size=shape))
+ return np.int32(self.rng.integers(low=-7, high=8, size=shape))
elif dtype == DType.INT8:
return np.int32(self.rng.integers(low=-128, high=128, size=shape))
elif dtype == DType.UINT8:
@@ -988,8 +989,9 @@ class TosaTestGen:
return self.rng.random()
elif dtype == DType.BOOL:
return self.rng.choice([False, True])
+ # TOSA specific INT4 weight range from -7 to 7
elif dtype == DType.INT4:
- low, high = (-8, 8)
+ low, high = (-7, 8)
elif dtype == DType.INT8:
low, high = (-128, 128)
elif dtype == DType.INT16:
@@ -1977,6 +1979,7 @@ class TosaTestGen:
TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]
TYPE_CONV2D = [
+ [DType.INT8, DType.INT4, DType.INT32],
[DType.INT8, DType.INT8, DType.INT32],
[DType.INT16, DType.INT8, DType.INT48],
DType.FLOAT,