diff options
Diffstat (limited to 'reference_model/src/subgraph_traverser.cc')
-rw-r--r-- | reference_model/src/subgraph_traverser.cc | 19 |
1 files changed, 17 insertions, 2 deletions
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc index 745213e..fae0b30 100644 --- a/reference_model/src/subgraph_traverser.cc +++ b/reference_model/src/subgraph_traverser.cc @@ -510,6 +510,12 @@ int SubgraphTraverser::allocateTensor(std::string name) FATAL_ERROR("Failed to allocate tensor %s", tensor->getName().c_str()); } + // set valid for constant tensors: + if ((ts->GetShape().empty() && ts->GetDtype() == DType_SHAPE)) + { + // corner case: const_shape {} has no data + tensor->setIsValid(); + } if (!ts->GetData().empty()) { if (ts->GetVariable() && g_func_config.initialize_variable_tensor_from_numpy) @@ -545,13 +551,18 @@ int SubgraphTraverser::allocateTensor(std::string name) tensor->setTensorValueInt32(i32_data.size(), i32_data.data()); } break; - case DType_INT48: - case DType_SHAPE: { + case DType_INT48: { std::vector<int64_t> i64_data; TosaSerializationHandler::ConvertU8toI48(ts->GetData(), tensor->getElementCount(), i64_data); tensor->setTensorValueInt64(i64_data.size(), i64_data.data()); } break; + case DType_SHAPE: { + std::vector<int64_t> i64_data; + TosaSerializationHandler::ConvertU8toI64(ts->GetData(), tensor->getElementCount(), i64_data); + tensor->setTensorValueInt64(i64_data.size(), i64_data.data()); + } + break; case DType_FP16: { // Interpret f16 data as float std::vector<float> f16_data; @@ -617,6 +628,10 @@ int SubgraphTraverser::allocateTensor(std::string name) EnumNameDType(ts->GetDtype())); } tensor->setIsValid(); + } + + if (tensor->getIsValid()) + { // Push ready consumers to the next node list for (auto gn : tensor->getConsumers()) { |