From 8690a0873fac28acccbb0acb15c16e8337e14df1 Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Mon, 18 Dec 2023 20:40:24 +0000 Subject: [reference model] Add shape operators - fixed up reshape conformance tests to use shape input instead of attribute - fixed up tile conformance tests to use shape input instead of attribute - fixed output and output rank of dim op - allow rank 0 and rank 1 tensors for tosa.shape values (for shape = {}) - added initialization of rank 0 const_shape tensors (for shape = {}) - Update conformance tests to use new rescale attributes Signed-off-by: Tai Ly Signed-off-by: Won Jeon Change-Id: I6cce0d2a9ab066fe20a2abf9d2cfde3eb3d8c18b --- reference_model/src/subgraph_traverser.cc | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) (limited to 'reference_model/src/subgraph_traverser.cc') diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc index 745213e..fae0b30 100644 --- a/reference_model/src/subgraph_traverser.cc +++ b/reference_model/src/subgraph_traverser.cc @@ -510,6 +510,12 @@ int SubgraphTraverser::allocateTensor(std::string name) FATAL_ERROR("Failed to allocate tensor %s", tensor->getName().c_str()); } + // set valid for constant tensors: + if ((ts->GetShape().empty() && ts->GetDtype() == DType_SHAPE)) + { + // corner case: const_shape {} has no data + tensor->setIsValid(); + } if (!ts->GetData().empty()) { if (ts->GetVariable() && g_func_config.initialize_variable_tensor_from_numpy) @@ -545,13 +551,18 @@ int SubgraphTraverser::allocateTensor(std::string name) tensor->setTensorValueInt32(i32_data.size(), i32_data.data()); } break; - case DType_INT48: - case DType_SHAPE: { + case DType_INT48: { std::vector i64_data; TosaSerializationHandler::ConvertU8toI48(ts->GetData(), tensor->getElementCount(), i64_data); tensor->setTensorValueInt64(i64_data.size(), i64_data.data()); } break; + case DType_SHAPE: { + std::vector i64_data; + TosaSerializationHandler::ConvertU8toI64(ts->GetData(), tensor->getElementCount(), i64_data); + tensor->setTensorValueInt64(i64_data.size(), i64_data.data()); + } + break; case DType_FP16: { // Interpret f16 data as float std::vector f16_data; @@ -617,6 +628,10 @@ int SubgraphTraverser::allocateTensor(std::string name) EnumNameDType(ts->GetDtype())); } tensor->setIsValid(); + } + + if (tensor->getIsValid()) + { // Push ready consumers to the next node list for (auto gn : tensor->getConsumers()) { -- cgit v1.2.1