From a9017401461224b9bc81e7b1c770ca6091e0e3fb Mon Sep 17 00:00:00 2001 From: Kevin Cheng Date: Wed, 28 Jul 2021 17:19:23 -0700 Subject: Support int4 weights read. Added conv2d int8xint4 in test generation. Signed-off-by: Kevin Cheng Change-Id: I61620f160c7dad6aac5fcc3da0a6e97f3bae5b40 --- reference_model/src/subgraph_traverser.cc | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) (limited to 'reference_model/src/subgraph_traverser.cc') diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc index d64cb38..bdf6fbc 100644 --- a/reference_model/src/subgraph_traverser.cc +++ b/reference_model/src/subgraph_traverser.cc @@ -225,7 +225,6 @@ int SubgraphTraverser::initializeGraph() DEBUG_INFO(GT, "Creating tensor %s", ts->GetName().c_str()); TosaReference::Tensor* tensor = TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetShape(), ts->GetShape().size()); - if (!ts->GetData().empty()) { if (tensor->allocate()) @@ -236,6 +235,14 @@ int SubgraphTraverser::initializeGraph() switch (ts->GetDtype()) { + case DType_INT4: + { + std::vector i4_data; + TosaSerializationHandler::ConvertU8toI4(ts->GetData(), tensor->getElementCount(), i4_data); + std::vector i32_data(i4_data.begin(), i4_data.end()); + tensor->setTensorValueInt32(i32_data.size(), i32_data.data()); + } + break; case DType_INT8: { std::vector i8_data; -- cgit v1.2.1