aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/subgraph_traverser.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/subgraph_traverser.cc')
-rw-r--r--reference_model/src/subgraph_traverser.cc19
1 files changed, 17 insertions, 2 deletions
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc
index 745213e..fae0b30 100644
--- a/reference_model/src/subgraph_traverser.cc
+++ b/reference_model/src/subgraph_traverser.cc
@@ -510,6 +510,12 @@ int SubgraphTraverser::allocateTensor(std::string name)
FATAL_ERROR("Failed to allocate tensor %s", tensor->getName().c_str());
}
+ // set valid for constant tensors:
+ if ((ts->GetShape().empty() && ts->GetDtype() == DType_SHAPE))
+ {
+ // corner case: const_shape {} has no data
+ tensor->setIsValid();
+ }
if (!ts->GetData().empty())
{
if (ts->GetVariable() && g_func_config.initialize_variable_tensor_from_numpy)
@@ -545,13 +551,18 @@ int SubgraphTraverser::allocateTensor(std::string name)
tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
}
break;
- case DType_INT48:
- case DType_SHAPE: {
+ case DType_INT48: {
std::vector<int64_t> i64_data;
TosaSerializationHandler::ConvertU8toI48(ts->GetData(), tensor->getElementCount(), i64_data);
tensor->setTensorValueInt64(i64_data.size(), i64_data.data());
}
break;
+ case DType_SHAPE: {
+ std::vector<int64_t> i64_data;
+ TosaSerializationHandler::ConvertU8toI64(ts->GetData(), tensor->getElementCount(), i64_data);
+ tensor->setTensorValueInt64(i64_data.size(), i64_data.data());
+ }
+ break;
case DType_FP16: {
// Interpret f16 data as float
std::vector<float> f16_data;
@@ -617,6 +628,10 @@ int SubgraphTraverser::allocateTensor(std::string name)
EnumNameDType(ts->GetDtype()));
}
tensor->setIsValid();
+ }
+
+ if (tensor->getIsValid())
+ {
// Push ready consumers to the next node list
for (auto gn : tensor->getConsumers())
{