From a9017401461224b9bc81e7b1c770ca6091e0e3fb Mon Sep 17 00:00:00 2001 From: Kevin Cheng Date: Wed, 28 Jul 2021 17:19:23 -0700 Subject: Support int4 weights read. Added conv2d int8xint4 in test generation. Signed-off-by: Kevin Cheng Change-Id: I61620f160c7dad6aac5fcc3da0a6e97f3bae5b40 --- reference_model/src/subgraph_traverser.cc | 9 ++++++++- thirdparty/serialization_lib | 2 +- verif/tosa_serializer.py | 12 ++++++++++++ verif/tosa_test_gen.py | 7 +++++-- 4 files changed, 26 insertions(+), 4 deletions(-) diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc index d64cb38..bdf6fbc 100644 --- a/reference_model/src/subgraph_traverser.cc +++ b/reference_model/src/subgraph_traverser.cc @@ -225,7 +225,6 @@ int SubgraphTraverser::initializeGraph() DEBUG_INFO(GT, "Creating tensor %s", ts->GetName().c_str()); TosaReference::Tensor* tensor = TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetShape(), ts->GetShape().size()); - if (!ts->GetData().empty()) { if (tensor->allocate()) @@ -236,6 +235,14 @@ int SubgraphTraverser::initializeGraph() switch (ts->GetDtype()) { + case DType_INT4: + { + std::vector i4_data; + TosaSerializationHandler::ConvertU8toI4(ts->GetData(), tensor->getElementCount(), i4_data); + std::vector i32_data(i4_data.begin(), i4_data.end()); + tensor->setTensorValueInt32(i32_data.size(), i32_data.data()); + } + break; case DType_INT8: { std::vector i8_data; diff --git a/thirdparty/serialization_lib b/thirdparty/serialization_lib index 82dbb32..3ce5634 160000 --- a/thirdparty/serialization_lib +++ b/thirdparty/serialization_lib @@ -1 +1 @@ -Subproject commit 82dbb32980a58889bef28b7ad653c30694364195 +Subproject commit 3ce563449c1e607b016b82c5dbb6e33883f846a5 diff --git a/verif/tosa_serializer.py b/verif/tosa_serializer.py index c4de2a2..b4daaad 100644 --- a/verif/tosa_serializer.py +++ b/verif/tosa_serializer.py @@ -386,6 +386,18 @@ class TosaSerializerTensor: for val in self.data: val_u8 = np.uint8(val) u8_data.append(val_u8) + elif self.dtype == DType.INT4: + in_size = len(self.data) + out_size = (in_size + 1) // 2 + for i in range(out_size): + val_0 = self.data[2 * i] + if (2 * i + 1) < in_size: + val_1 = self.data[2 * i + 1] + else: + val_1 = 0 + val_i8 = (val_0 & 0xF) | ((val_1 & 0xF) << 4) + val_u8 = np.uint8(val_i8) + u8_data.append(val_u8) elif self.dtype == DType.INT8: for val in self.data: val_u8 = np.uint8(val) diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py index e375a2a..5c25f8e 100644 --- a/verif/tosa_test_gen.py +++ b/verif/tosa_test_gen.py @@ -923,8 +923,9 @@ class TosaTestGen: if dtype == DType.BOOL: np_dt = np.bool return np.bool_(self.rng.choice(a=[False, True], size=shape)) + # TOSA specific INT4 weight range from -7 to 7 elif dtype == DType.INT4: - return np.int32(self.rng.integers(low=-8, high=8, size=shape)) + return np.int32(self.rng.integers(low=-7, high=8, size=shape)) elif dtype == DType.INT8: return np.int32(self.rng.integers(low=-128, high=128, size=shape)) elif dtype == DType.UINT8: @@ -988,8 +989,9 @@ class TosaTestGen: return self.rng.random() elif dtype == DType.BOOL: return self.rng.choice([False, True]) + # TOSA specific INT4 weight range from -7 to 7 elif dtype == DType.INT4: - low, high = (-8, 8) + low, high = (-7, 8) elif dtype == DType.INT8: low, high = (-128, 128) elif dtype == DType.INT16: @@ -1977,6 +1979,7 @@ class TosaTestGen: TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT] TYPE_CONV2D = [ + [DType.INT8, DType.INT4, DType.INT32], [DType.INT8, DType.INT8, DType.INT32], [DType.INT16, DType.INT8, DType.INT48], DType.FLOAT, -- cgit v1.2.1