aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/subgraph_traverser.cc
diff options
context:
space:
mode:
authorJerry Ge <jerry.ge@arm.com>2022-10-27 09:57:00 -0700
committerEric Kunze <eric.kunze@arm.com>2023-01-13 19:09:21 +0000
commit9e94af8f10f0a21a117b3bc7ea42004844fdc3bb (patch)
tree868ab73bb67d4827963a4b43f28d8a8a49f50307 /reference_model/src/subgraph_traverser.cc
parentdd8d9c251db0fece6453d86116052ad7f3e2d697 (diff)
downloadreference_model-9e94af8f10f0a21a117b3bc7ea42004844fdc3bb.tar.gz
Reference model update for control flow operators support
Rationale for making this change: - In the original design, for control flow operators like WhileOp, child blocks couldn't read the tensor variables (global consts) in the root level block, this patch added the machanism for child blocks to access their parent level block's tensors. - This change also relies on another serialization change on adding another layer of abtraction called Region: - Serialization patch: [region] Add TosaSerializationRegion to serialization_lib - Updated the corresponding python version of the serialization code: TosaSerializerRegion to python version of serialization_lib - This change also relies on the TOSA MLIR Translator change: Add RegionBuilder to TOSA MLIR Translator - Added the WhileOp related test cases: While, LSTM, GRU, RNN - Other related fixes Signed-off-by: Jerry Ge <jerry.ge@arm.com> Change-Id: I13ae33628ad07e41d248e88652ce1328654694ab
Diffstat (limited to 'reference_model/src/subgraph_traverser.cc')
-rw-r--r--reference_model/src/subgraph_traverser.cc138
1 files changed, 109 insertions, 29 deletions
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc
index 112e641..8867ada 100644
--- a/reference_model/src/subgraph_traverser.cc
+++ b/reference_model/src/subgraph_traverser.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -37,13 +37,14 @@ using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
-SubgraphTraverser::SubgraphTraverser(TosaSerializationBasicBlock* _block, TosaSerializationHandler* _tsh)
+SubgraphTraverser::SubgraphTraverser(TosaSerializationBasicBlock* _block, TosaSerializationHandler* _tsh, SubgraphTraverser* _parent_sgt)
{
- graph_status = GraphStatus::TOSA_VALID;
+ graph_status = GraphStatus::TOSA_VALID;
block = _block;
- tsh = _tsh;
+ tsh = _tsh;
+ parent_sgt = _parent_sgt;
tensors.clear();
nodes.clear();
nextNodeList.clear();
@@ -120,6 +121,17 @@ int SubgraphTraverser::initializeGraph()
{
int idx = 0;
+ std::vector<TosaSerializationTensor*> ser_tensor_vec;
+ // Get all the serialized tensors from TosaSerializationHandler.
+ for (auto block: tsh->GetMainRegion()->GetBlocks())
+ {
+ for (auto ser_tensor : block->GetTensors())
+ {
+ ser_tensor_vec.push_back(ser_tensor);
+ }
+ }
+
+ std::vector<GraphNode*> non_const_node_vec;
for (auto op : block->GetOperators())
{
// translated TosaSerializationOperator to GraphNode
@@ -159,7 +171,13 @@ int SubgraphTraverser::initializeGraph()
EnumNamesOp()[op->GetOp()], input_index);
std::string input_name = op->GetInputTensorNames()[input_index];
- TosaSerializationTensor* input_tensor = block->GetTensorByName(input_name);
+ TosaSerializationTensor* input_tensor = nullptr;
+ for (auto ser_tensor : ser_tensor_vec) {
+ if (ser_tensor->GetName() == input_name) {
+ input_tensor = ser_tensor;
+ }
+ }
+
SUBGRAPH_ERROR_IF(
!input_tensor,
"SubgraphTraverser::initializeGraph(): fail to get input tensor %s from TosaSerializationHandler",
@@ -175,7 +193,13 @@ int SubgraphTraverser::initializeGraph()
"SubgraphTraverser::initializeGraph(): Op=%s, weight_index %d must be within [0, num_input - 1]",
EnumNamesOp()[op->GetOp()], weight_index);
std::string weight_name = op->GetInputTensorNames()[weight_index];
- TosaSerializationTensor* weight_tensor = block->GetTensorByName(weight_name);
+ TosaSerializationTensor* weight_tensor = nullptr;
+ for (auto ser_tensor : ser_tensor_vec) {
+ if (ser_tensor->GetName() == weight_name) {
+ weight_tensor = ser_tensor;
+ }
+ }
+
SUBGRAPH_ERROR_IF(
!weight_tensor,
"SubgraphTraverser::initializeGraph(): fail to get weight tensor %s from TosaSerializationHandler",
@@ -199,8 +223,19 @@ int SubgraphTraverser::initializeGraph()
DEBUG_INFO(GT, "Creating operator id_%03u, %8s, %lu input tensors, %lu output tensors", idx,
EnumNamesOp()[op->GetOp()], op->GetInputTensorNames().size(), op->GetOutputTensorNames().size());
- GraphNode* node = OpFactory::newOp(this, tsh, op->GetOp(), op->GetAttribute(), idx, input_dtype,
+ GraphNode* node = nullptr;
+ if (this->parent_sgt) {
+ node = OpFactory::newOp(this->parent_sgt, tsh, op->GetOp(), op->GetAttribute(), idx, input_dtype,
+ input_rank, output_dtype, output_rank, weight_dtype, weight_rank);
+ node->setInMainBlock(false);
+ } else {
+ node = OpFactory::newOp(this, tsh, op->GetOp(), op->GetAttribute(), idx, input_dtype,
input_rank, output_dtype, output_rank, weight_dtype, weight_rank);
+ if (node) {
+ node->setInMainBlock(true);
+ }
+ }
+
if (!node)
{
if (weight_index == -1)
@@ -257,6 +292,8 @@ int SubgraphTraverser::initializeGraph()
if (op->GetInputTensorNames().empty() && !node->getOnNextNodeList())
{
addToNextNodeList(node);
+ } else if (!node->getInMainBlock()) {
+ non_const_node_vec.push_back(node);
}
idx++;
@@ -271,7 +308,6 @@ int SubgraphTraverser::initializeGraph()
SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::initializeGraph(): Unsupported tensor name=%s, type=%s, rank=%d",
ts->GetName().c_str(), EnumNamesDType()[ts->GetDtype()], (int)ts->GetShape().size());
- // update this->tensors
addTensor(tensor);
}
@@ -296,7 +332,7 @@ int SubgraphTraverser::initializeGraph()
for (auto& output_name : block->GetOutputs())
{
TosaReference::Tensor* tensor = findTensorByName(output_name);
- DEBUG_INFO(GT, "output tensor name=%s\n", output_name.c_str());
+ DEBUG_INFO(GT, "output tensor name=%s", output_name.c_str());
if (tensor)
{
tensor->setIsSubgraphOutput();
@@ -314,6 +350,22 @@ int SubgraphTraverser::initializeGraph()
dumpNextNodeList(g_func_debug.func_debug_file);
}
+ // If the node is not in mainblock and not const
+ for (auto node : non_const_node_vec) {
+ bool all_inputs_from_parent = true;
+ for (std::string& name : node->getInputNames())
+ {
+ TosaReference::Tensor* t = findTensorByName(name);
+ if (!t->getIsParentGraphOutput()) {
+ all_inputs_from_parent = false;
+ }
+ }
+ // In the children block, when a node has all its inputs from parent
+ // block, we have to manually add this node to the evaluation list
+ if (all_inputs_from_parent && !node->getOnNextNodeList()) {
+ addToNextNodeList(node);
+ }
+ }
return 0;
}
@@ -510,29 +562,40 @@ int SubgraphTraverser::evaluateNextNode()
}
// free input tensor if all of its consumers have all of their outputs ready and it's not block's output
- for (auto tensor : currNode->getInputs())
- {
- bool in_use = false;
- for (auto node : tensor->getConsumers())
+ if (!currNode->getInMainBlock()) { // we don't free it if the node is in main block and has nested blocks
+ for (auto tensor : currNode->getInputs())
{
- if (!node->hasAllOutputsReady())
+ bool in_use = false;
+
+ auto tensor_check = findTensorByName(tensor->getName());
+ if (tensor_check->getIsParentGraphOutput()) {
+ // if it's parent's block output tensor, we can't free it
+ continue;
+ }
+
+ for (auto node : tensor->getConsumers())
{
- in_use = true;
+ // If the node is inside a loop, the input tensor is still needed
+ if (!node->hasAllOutputsReady())
+ {
+ in_use = true;
+ }
+
}
- }
- for (auto name : block->GetOutputs())
- {
- if (name == tensor->getName())
+ for (auto name : block->GetOutputs())
{
- in_use = true;
+ if (name == tensor->getName())
+ {
+ in_use = true;
+ }
+ }
+
+ if (!in_use)
+ {
+ tensor->deallocate();
}
- }
- if (!in_use)
- {
- tensor->deallocate();
}
}
-
// Search the output tensors of this node to see if
// there are now new ready nodes available from completing this node
for (TosaReference::Tensor* tensor : currNode->getOutputs())
@@ -642,17 +705,35 @@ int SubgraphTraverser::addNode(GraphNode* newNode)
TosaReference::Tensor* SubgraphTraverser::findTensorByName(const std::string& name) const
{
+ TosaReference::Tensor* res_tensor = nullptr;
+
for (TosaReference::Tensor* currTensor : tensors)
{
if (currTensor->getName() == name)
{
- return currTensor;
+ res_tensor = currTensor;
+ return res_tensor;
}
}
- WARNING("SubgraphTraverser::findTensorByName(): Unable to find tensor with name: %s\n", name.c_str());
+ if (parent_sgt)
+ {
+ for (TosaReference::Tensor* currTensor : parent_sgt->tensors)
+ {
+ if (currTensor->getName() == name)
+ {
+ res_tensor = currTensor;
+ res_tensor->setIsParentGraphOutput();
+ }
+ }
+ }
- return nullptr;
+ if (!res_tensor)
+ {
+ WARNING("SubgraphTraverser::findTensorByName(): Unable to find tensor with name: %s\n", name.c_str());
+ return nullptr;
+ }
+ return res_tensor;
}
int SubgraphTraverser::linkTensorsAndNodes()
@@ -704,7 +785,6 @@ int SubgraphTraverser::validateGraph()
for (TosaReference::Tensor* currTensor : tensors)
{
-
// It's okay for block input tensor not being consumed by operators.
// This is common in control flow op execution.
if (!currTensor->getIsSubgraphInput())