diff options
author | Tai Ly <tai.ly@arm.com> | 2023-09-07 20:49:09 +0000 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2023-09-15 18:10:01 +0000 |
commit | cf84bc9cccbd5dc2fceae1a81c579e41be3c9a06 (patch) | |
tree | aff6bab02c36c095a62381ac8f68d185bdccbe73 /reference_model/src/subgraph_traverser.cc | |
parent | 00f55bf46fe36bebe44e1365becbeb1e0d9e90c9 (diff) | |
download | reference_model-cf84bc9cccbd5dc2fceae1a81c579e41be3c9a06.tar.gz |
[reference_model] Support StatefulOps and the tests for CallOnceOp
Signed-off-by: Jerry Ge <jerry.ge@arm.com>
Change-Id: I03cb878736ccd7e1f5e1f780d7171949a19a9de2
Diffstat (limited to 'reference_model/src/subgraph_traverser.cc')
-rw-r--r-- | reference_model/src/subgraph_traverser.cc | 130 |
1 files changed, 118 insertions, 12 deletions
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc index 186cb8b..745213e 100644 --- a/reference_model/src/subgraph_traverser.cc +++ b/reference_model/src/subgraph_traverser.cc @@ -64,6 +64,11 @@ SubgraphTraverser::~SubgraphTraverser() for (TosaReference::Tensor* t : tensors) { + if (t->getIsVariable() && parent_sgt) + { + // variable tensors are owned by top level sgt + continue; + } if (t->is_allocated()) { t->deallocate(); @@ -119,6 +124,51 @@ TosaReference::Tensor* SubgraphTraverser::getOutputTensorByName(const std::strin return nullptr; } +int SubgraphTraverser::getNumVariableTensors() const +{ + return variableTensors.size(); +} + +TosaReference::Tensor* SubgraphTraverser::getVariableTensor(const unsigned int idx) const +{ + return variableTensors[idx]; +} + +// find variable tensor by name in top level sgt's @a variableTensors +TosaReference::Tensor* SubgraphTraverser::getVariableTensorByName(const std::string name) const +{ + // variable tensors are owned by top level sgt + if (parent_sgt) + { + return parent_sgt->getVariableTensorByName(name); + } + + for (auto t : variableTensors) + { + if (t->getName() == name) + { + return t; + } + } + + return nullptr; +} + +// add variable tensor to top level sgt's @a variableTensors +int SubgraphTraverser::registerVariableTensor(Tensor* tensor) +{ + SUBGRAPH_ERROR_IF(!tensor->getIsVariable(), + "SubgraphTraverser::registerVariableTensor(): tensor %s is not a variable", + tensor->getName().c_str()); + // variable tensors are owned by top level sgt + if (parent_sgt) + { + return parent_sgt->registerVariableTensor(tensor); + } + variableTensors.push_back(tensor); + return 0; +} + int SubgraphTraverser::initializeGraph() { int idx = 0; @@ -321,19 +371,18 @@ int SubgraphTraverser::initializeGraph() non_const_node_vec.push_back(node); } + // Bug fix: add the ready node in main block for evaluation + if (node->hasAllInputsReady() && !node->getOnNextNodeList() && !node->getEvaluated()) + { + addToNextNodeList(node); + } + idx++; } for (auto ts : block->GetTensors()) { - DEBUG_INFO(GT, "Creating tensor %s", ts->GetName().c_str()); - TosaReference::Tensor* tensor = - TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetShape(), ts->GetShape().size()); - - SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::initializeGraph(): Unsupported tensor name=%s, type=%s, rank=%d", - ts->GetName().c_str(), EnumNameDType(ts->GetDtype()), (int)ts->GetShape().size()); - - addTensor(tensor); + addTensor(ts); } DEBUG_INFO(GT, "Enumerating block %s graph inputs", block->GetName().c_str()); @@ -406,6 +455,22 @@ int SubgraphTraverser::allocateInputTensors() this->allocateTensor(input_tensor_name); } + // allocate variable tensors if not already allocated + for (auto ts : block->GetTensors()) + { + if (ts->GetVariable()) + { + TosaReference::Tensor* tensor = findTensorByName(ts->GetName()); + SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::allocateInputTensors(): can't find tensor %s.", + ts->GetName().c_str()); + if (!tensor->is_allocated()) + { + DEBUG_INFO(GT, "Is a VariableTensor %s", ts->GetName().c_str()); + this->allocateTensor(ts->GetName()); + } + } + } + return 0; } @@ -447,6 +512,8 @@ int SubgraphTraverser::allocateTensor(std::string name) if (!ts->GetData().empty()) { + if (ts->GetVariable() && g_func_config.initialize_variable_tensor_from_numpy) + return 0; DEBUG_INFO(GT, "Setting data for tensor %s", tensor->getName().c_str()); auto serialization_dtype = ts->GetDtype(); switch (serialization_dtype) @@ -549,8 +616,16 @@ int SubgraphTraverser::allocateTensor(std::string name) SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Unsupported tensor type %s.", EnumNameDType(ts->GetDtype())); } + tensor->setIsValid(); + // Push ready consumers to the next node list + for (auto gn : tensor->getConsumers()) + { + if (gn->hasAllInputsReady() && !gn->getOnNextNodeList() && !gn->getEvaluated()) + { + addToNextNodeList(gn); + } + } } - return 0; } @@ -619,6 +694,8 @@ int SubgraphTraverser::evaluateNextNode() return 1; } + currNode->setEvaluated(); + // free input tensor if all of its consumers have all of their outputs ready and it's not block's output for (auto tensor : currNode->getInputs()) { @@ -631,6 +708,12 @@ int SubgraphTraverser::evaluateNextNode() continue; } + if (tensor->getIsVariable()) + { + // if tensor is a Variable, we cannot free it + continue; + } + for (auto node : tensor->getConsumers()) { // If the node is inside a loop, the input tensor is still needed @@ -660,7 +743,7 @@ int SubgraphTraverser::evaluateNextNode() { for (GraphNode* node : tensor->getConsumers()) { - if (!node->getOnNextNodeList() && node->hasAllInputsReady()) + if (!node->getOnNextNodeList() && node->hasAllInputsReady() && !node->getEvaluated()) { addToNextNodeList(node); } @@ -716,8 +799,31 @@ int SubgraphTraverser::clearAllNodeMarkings() return false; } -int SubgraphTraverser::addTensor(TosaReference::Tensor* tensor) +int SubgraphTraverser::addTensor(const TosaSerializationTensor* ts) { + TosaReference::Tensor* tensor = nullptr; + + // variable tensors are shared: make new tensor only if not found + if (ts->GetVariable()) + { + tensor = getVariableTensorByName(ts->GetName()); + } + + if (!tensor) + { + DEBUG_INFO(GT, "Creating tensor %s", ts->GetName().c_str()); + tensor = TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetShape(), ts->GetShape().size()); + + SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::initializeGraph(): Unsupported tensor name=%s, type=%s, rank=%d", + ts->GetName().c_str(), EnumNameDType(ts->GetDtype()), (int)ts->GetShape().size()); + + if (ts->GetVariable()) + { + tensor->setIsVariable(); + registerVariableTensor(tensor); + } + } + // Enforce no duplicate tensors/tensor names // O(N), but the number of tensors is small for (TosaReference::Tensor* currTensor : tensors) @@ -751,7 +857,7 @@ int SubgraphTraverser::addNode(GraphNode* newNode) { if (currNode == newNode) { - FATAL_ERROR("SubgraphTraverser::addTensor(): duplicate node being added to graph"); + FATAL_ERROR("SubgraphTraverser::addNode(): duplicate node being added to graph"); return 1; } } |