From cf84bc9cccbd5dc2fceae1a81c579e41be3c9a06 Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Thu, 7 Sep 2023 20:49:09 +0000 Subject: [reference_model] Support StatefulOps and the tests for CallOnceOp Signed-off-by: Jerry Ge Change-Id: I03cb878736ccd7e1f5e1f780d7171949a19a9de2 --- reference_model/src/subgraph_traverser.cc | 130 +++++++++++++++++++++++++++--- 1 file changed, 118 insertions(+), 12 deletions(-) (limited to 'reference_model/src/subgraph_traverser.cc') diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc index 186cb8b..745213e 100644 --- a/reference_model/src/subgraph_traverser.cc +++ b/reference_model/src/subgraph_traverser.cc @@ -64,6 +64,11 @@ SubgraphTraverser::~SubgraphTraverser() for (TosaReference::Tensor* t : tensors) { + if (t->getIsVariable() && parent_sgt) + { + // variable tensors are owned by top level sgt + continue; + } if (t->is_allocated()) { t->deallocate(); @@ -119,6 +124,51 @@ TosaReference::Tensor* SubgraphTraverser::getOutputTensorByName(const std::strin return nullptr; } +int SubgraphTraverser::getNumVariableTensors() const +{ + return variableTensors.size(); +} + +TosaReference::Tensor* SubgraphTraverser::getVariableTensor(const unsigned int idx) const +{ + return variableTensors[idx]; +} + +// find variable tensor by name in top level sgt's @a variableTensors +TosaReference::Tensor* SubgraphTraverser::getVariableTensorByName(const std::string name) const +{ + // variable tensors are owned by top level sgt + if (parent_sgt) + { + return parent_sgt->getVariableTensorByName(name); + } + + for (auto t : variableTensors) + { + if (t->getName() == name) + { + return t; + } + } + + return nullptr; +} + +// add variable tensor to top level sgt's @a variableTensors +int SubgraphTraverser::registerVariableTensor(Tensor* tensor) +{ + SUBGRAPH_ERROR_IF(!tensor->getIsVariable(), + "SubgraphTraverser::registerVariableTensor(): tensor %s is not a variable", + tensor->getName().c_str()); + // variable tensors are owned by top level sgt + if (parent_sgt) + { + return parent_sgt->registerVariableTensor(tensor); + } + variableTensors.push_back(tensor); + return 0; +} + int SubgraphTraverser::initializeGraph() { int idx = 0; @@ -321,19 +371,18 @@ int SubgraphTraverser::initializeGraph() non_const_node_vec.push_back(node); } + // Bug fix: add the ready node in main block for evaluation + if (node->hasAllInputsReady() && !node->getOnNextNodeList() && !node->getEvaluated()) + { + addToNextNodeList(node); + } + idx++; } for (auto ts : block->GetTensors()) { - DEBUG_INFO(GT, "Creating tensor %s", ts->GetName().c_str()); - TosaReference::Tensor* tensor = - TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetShape(), ts->GetShape().size()); - - SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::initializeGraph(): Unsupported tensor name=%s, type=%s, rank=%d", - ts->GetName().c_str(), EnumNameDType(ts->GetDtype()), (int)ts->GetShape().size()); - - addTensor(tensor); + addTensor(ts); } DEBUG_INFO(GT, "Enumerating block %s graph inputs", block->GetName().c_str()); @@ -406,6 +455,22 @@ int SubgraphTraverser::allocateInputTensors() this->allocateTensor(input_tensor_name); } + // allocate variable tensors if not already allocated + for (auto ts : block->GetTensors()) + { + if (ts->GetVariable()) + { + TosaReference::Tensor* tensor = findTensorByName(ts->GetName()); + SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::allocateInputTensors(): can't find tensor %s.", + ts->GetName().c_str()); + if (!tensor->is_allocated()) + { + DEBUG_INFO(GT, "Is a VariableTensor %s", ts->GetName().c_str()); + this->allocateTensor(ts->GetName()); + } + } + } + return 0; } @@ -447,6 +512,8 @@ int SubgraphTraverser::allocateTensor(std::string name) if (!ts->GetData().empty()) { + if (ts->GetVariable() && g_func_config.initialize_variable_tensor_from_numpy) + return 0; DEBUG_INFO(GT, "Setting data for tensor %s", tensor->getName().c_str()); auto serialization_dtype = ts->GetDtype(); switch (serialization_dtype) @@ -549,8 +616,16 @@ int SubgraphTraverser::allocateTensor(std::string name) SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Unsupported tensor type %s.", EnumNameDType(ts->GetDtype())); } + tensor->setIsValid(); + // Push ready consumers to the next node list + for (auto gn : tensor->getConsumers()) + { + if (gn->hasAllInputsReady() && !gn->getOnNextNodeList() && !gn->getEvaluated()) + { + addToNextNodeList(gn); + } + } } - return 0; } @@ -619,6 +694,8 @@ int SubgraphTraverser::evaluateNextNode() return 1; } + currNode->setEvaluated(); + // free input tensor if all of its consumers have all of their outputs ready and it's not block's output for (auto tensor : currNode->getInputs()) { @@ -631,6 +708,12 @@ int SubgraphTraverser::evaluateNextNode() continue; } + if (tensor->getIsVariable()) + { + // if tensor is a Variable, we cannot free it + continue; + } + for (auto node : tensor->getConsumers()) { // If the node is inside a loop, the input tensor is still needed @@ -660,7 +743,7 @@ int SubgraphTraverser::evaluateNextNode() { for (GraphNode* node : tensor->getConsumers()) { - if (!node->getOnNextNodeList() && node->hasAllInputsReady()) + if (!node->getOnNextNodeList() && node->hasAllInputsReady() && !node->getEvaluated()) { addToNextNodeList(node); } @@ -716,8 +799,31 @@ int SubgraphTraverser::clearAllNodeMarkings() return false; } -int SubgraphTraverser::addTensor(TosaReference::Tensor* tensor) +int SubgraphTraverser::addTensor(const TosaSerializationTensor* ts) { + TosaReference::Tensor* tensor = nullptr; + + // variable tensors are shared: make new tensor only if not found + if (ts->GetVariable()) + { + tensor = getVariableTensorByName(ts->GetName()); + } + + if (!tensor) + { + DEBUG_INFO(GT, "Creating tensor %s", ts->GetName().c_str()); + tensor = TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetShape(), ts->GetShape().size()); + + SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::initializeGraph(): Unsupported tensor name=%s, type=%s, rank=%d", + ts->GetName().c_str(), EnumNameDType(ts->GetDtype()), (int)ts->GetShape().size()); + + if (ts->GetVariable()) + { + tensor->setIsVariable(); + registerVariableTensor(tensor); + } + } + // Enforce no duplicate tensors/tensor names // O(N), but the number of tensors is small for (TosaReference::Tensor* currTensor : tensors) @@ -751,7 +857,7 @@ int SubgraphTraverser::addNode(GraphNode* newNode) { if (currNode == newNode) { - FATAL_ERROR("SubgraphTraverser::addTensor(): duplicate node being added to graph"); + FATAL_ERROR("SubgraphTraverser::addNode(): duplicate node being added to graph"); return 1; } } -- cgit v1.2.1