aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/subgraph_traverser.cc
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2023-09-07 20:49:09 +0000
committerEric Kunze <eric.kunze@arm.com>2023-09-15 18:10:01 +0000
commitcf84bc9cccbd5dc2fceae1a81c579e41be3c9a06 (patch)
treeaff6bab02c36c095a62381ac8f68d185bdccbe73 /reference_model/src/subgraph_traverser.cc
parent00f55bf46fe36bebe44e1365becbeb1e0d9e90c9 (diff)
downloadreference_model-cf84bc9cccbd5dc2fceae1a81c579e41be3c9a06.tar.gz
[reference_model] Support StatefulOps and the tests for CallOnceOp
Signed-off-by: Jerry Ge <jerry.ge@arm.com> Change-Id: I03cb878736ccd7e1f5e1f780d7171949a19a9de2
Diffstat (limited to 'reference_model/src/subgraph_traverser.cc')
-rw-r--r--reference_model/src/subgraph_traverser.cc130
1 files changed, 118 insertions, 12 deletions
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc
index 186cb8b..745213e 100644
--- a/reference_model/src/subgraph_traverser.cc
+++ b/reference_model/src/subgraph_traverser.cc
@@ -64,6 +64,11 @@ SubgraphTraverser::~SubgraphTraverser()
for (TosaReference::Tensor* t : tensors)
{
+ if (t->getIsVariable() && parent_sgt)
+ {
+ // variable tensors are owned by top level sgt
+ continue;
+ }
if (t->is_allocated())
{
t->deallocate();
@@ -119,6 +124,51 @@ TosaReference::Tensor* SubgraphTraverser::getOutputTensorByName(const std::strin
return nullptr;
}
+int SubgraphTraverser::getNumVariableTensors() const
+{
+ return variableTensors.size();
+}
+
+TosaReference::Tensor* SubgraphTraverser::getVariableTensor(const unsigned int idx) const
+{
+ return variableTensors[idx];
+}
+
+// find variable tensor by name in top level sgt's @a variableTensors
+TosaReference::Tensor* SubgraphTraverser::getVariableTensorByName(const std::string name) const
+{
+ // variable tensors are owned by top level sgt
+ if (parent_sgt)
+ {
+ return parent_sgt->getVariableTensorByName(name);
+ }
+
+ for (auto t : variableTensors)
+ {
+ if (t->getName() == name)
+ {
+ return t;
+ }
+ }
+
+ return nullptr;
+}
+
+// add variable tensor to top level sgt's @a variableTensors
+int SubgraphTraverser::registerVariableTensor(Tensor* tensor)
+{
+ SUBGRAPH_ERROR_IF(!tensor->getIsVariable(),
+ "SubgraphTraverser::registerVariableTensor(): tensor %s is not a variable",
+ tensor->getName().c_str());
+ // variable tensors are owned by top level sgt
+ if (parent_sgt)
+ {
+ return parent_sgt->registerVariableTensor(tensor);
+ }
+ variableTensors.push_back(tensor);
+ return 0;
+}
+
int SubgraphTraverser::initializeGraph()
{
int idx = 0;
@@ -321,19 +371,18 @@ int SubgraphTraverser::initializeGraph()
non_const_node_vec.push_back(node);
}
+ // Bug fix: add the ready node in main block for evaluation
+ if (node->hasAllInputsReady() && !node->getOnNextNodeList() && !node->getEvaluated())
+ {
+ addToNextNodeList(node);
+ }
+
idx++;
}
for (auto ts : block->GetTensors())
{
- DEBUG_INFO(GT, "Creating tensor %s", ts->GetName().c_str());
- TosaReference::Tensor* tensor =
- TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetShape(), ts->GetShape().size());
-
- SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::initializeGraph(): Unsupported tensor name=%s, type=%s, rank=%d",
- ts->GetName().c_str(), EnumNameDType(ts->GetDtype()), (int)ts->GetShape().size());
-
- addTensor(tensor);
+ addTensor(ts);
}
DEBUG_INFO(GT, "Enumerating block %s graph inputs", block->GetName().c_str());
@@ -406,6 +455,22 @@ int SubgraphTraverser::allocateInputTensors()
this->allocateTensor(input_tensor_name);
}
+ // allocate variable tensors if not already allocated
+ for (auto ts : block->GetTensors())
+ {
+ if (ts->GetVariable())
+ {
+ TosaReference::Tensor* tensor = findTensorByName(ts->GetName());
+ SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::allocateInputTensors(): can't find tensor %s.",
+ ts->GetName().c_str());
+ if (!tensor->is_allocated())
+ {
+ DEBUG_INFO(GT, "Is a VariableTensor %s", ts->GetName().c_str());
+ this->allocateTensor(ts->GetName());
+ }
+ }
+ }
+
return 0;
}
@@ -447,6 +512,8 @@ int SubgraphTraverser::allocateTensor(std::string name)
if (!ts->GetData().empty())
{
+ if (ts->GetVariable() && g_func_config.initialize_variable_tensor_from_numpy)
+ return 0;
DEBUG_INFO(GT, "Setting data for tensor %s", tensor->getName().c_str());
auto serialization_dtype = ts->GetDtype();
switch (serialization_dtype)
@@ -549,8 +616,16 @@ int SubgraphTraverser::allocateTensor(std::string name)
SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Unsupported tensor type %s.",
EnumNameDType(ts->GetDtype()));
}
+ tensor->setIsValid();
+ // Push ready consumers to the next node list
+ for (auto gn : tensor->getConsumers())
+ {
+ if (gn->hasAllInputsReady() && !gn->getOnNextNodeList() && !gn->getEvaluated())
+ {
+ addToNextNodeList(gn);
+ }
+ }
}
-
return 0;
}
@@ -619,6 +694,8 @@ int SubgraphTraverser::evaluateNextNode()
return 1;
}
+ currNode->setEvaluated();
+
// free input tensor if all of its consumers have all of their outputs ready and it's not block's output
for (auto tensor : currNode->getInputs())
{
@@ -631,6 +708,12 @@ int SubgraphTraverser::evaluateNextNode()
continue;
}
+ if (tensor->getIsVariable())
+ {
+ // if tensor is a Variable, we cannot free it
+ continue;
+ }
+
for (auto node : tensor->getConsumers())
{
// If the node is inside a loop, the input tensor is still needed
@@ -660,7 +743,7 @@ int SubgraphTraverser::evaluateNextNode()
{
for (GraphNode* node : tensor->getConsumers())
{
- if (!node->getOnNextNodeList() && node->hasAllInputsReady())
+ if (!node->getOnNextNodeList() && node->hasAllInputsReady() && !node->getEvaluated())
{
addToNextNodeList(node);
}
@@ -716,8 +799,31 @@ int SubgraphTraverser::clearAllNodeMarkings()
return false;
}
-int SubgraphTraverser::addTensor(TosaReference::Tensor* tensor)
+int SubgraphTraverser::addTensor(const TosaSerializationTensor* ts)
{
+ TosaReference::Tensor* tensor = nullptr;
+
+ // variable tensors are shared: make new tensor only if not found
+ if (ts->GetVariable())
+ {
+ tensor = getVariableTensorByName(ts->GetName());
+ }
+
+ if (!tensor)
+ {
+ DEBUG_INFO(GT, "Creating tensor %s", ts->GetName().c_str());
+ tensor = TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetShape(), ts->GetShape().size());
+
+ SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::initializeGraph(): Unsupported tensor name=%s, type=%s, rank=%d",
+ ts->GetName().c_str(), EnumNameDType(ts->GetDtype()), (int)ts->GetShape().size());
+
+ if (ts->GetVariable())
+ {
+ tensor->setIsVariable();
+ registerVariableTensor(tensor);
+ }
+ }
+
// Enforce no duplicate tensors/tensor names
// O(N), but the number of tensors is small
for (TosaReference::Tensor* currTensor : tensors)
@@ -751,7 +857,7 @@ int SubgraphTraverser::addNode(GraphNode* newNode)
{
if (currNode == newNode)
{
- FATAL_ERROR("SubgraphTraverser::addTensor(): duplicate node being added to graph");
+ FATAL_ERROR("SubgraphTraverser::addNode(): duplicate node being added to graph");
return 1;
}
}