aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/subgraph_traverser.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/subgraph_traverser.cc')
-rw-r--r--reference_model/src/subgraph_traverser.cc138
1 files changed, 109 insertions, 29 deletions
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc
index 112e641..8867ada 100644
--- a/reference_model/src/subgraph_traverser.cc
+++ b/reference_model/src/subgraph_traverser.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -37,13 +37,14 @@ using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
-SubgraphTraverser::SubgraphTraverser(TosaSerializationBasicBlock* _block, TosaSerializationHandler* _tsh)
+SubgraphTraverser::SubgraphTraverser(TosaSerializationBasicBlock* _block, TosaSerializationHandler* _tsh, SubgraphTraverser* _parent_sgt)
{
- graph_status = GraphStatus::TOSA_VALID;
+ graph_status = GraphStatus::TOSA_VALID;
block = _block;
- tsh = _tsh;
+ tsh = _tsh;
+ parent_sgt = _parent_sgt;
tensors.clear();
nodes.clear();
nextNodeList.clear();
@@ -120,6 +121,17 @@ int SubgraphTraverser::initializeGraph()
{
int idx = 0;
+ std::vector<TosaSerializationTensor*> ser_tensor_vec;
+ // Get all the serialized tensors from TosaSerializationHandler.
+ for (auto block: tsh->GetMainRegion()->GetBlocks())
+ {
+ for (auto ser_tensor : block->GetTensors())
+ {
+ ser_tensor_vec.push_back(ser_tensor);
+ }
+ }
+
+ std::vector<GraphNode*> non_const_node_vec;
for (auto op : block->GetOperators())
{
// translated TosaSerializationOperator to GraphNode
@@ -159,7 +171,13 @@ int SubgraphTraverser::initializeGraph()
EnumNamesOp()[op->GetOp()], input_index);
std::string input_name = op->GetInputTensorNames()[input_index];
- TosaSerializationTensor* input_tensor = block->GetTensorByName(input_name);
+ TosaSerializationTensor* input_tensor = nullptr;
+ for (auto ser_tensor : ser_tensor_vec) {
+ if (ser_tensor->GetName() == input_name) {
+ input_tensor = ser_tensor;
+ }
+ }
+
SUBGRAPH_ERROR_IF(
!input_tensor,
"SubgraphTraverser::initializeGraph(): fail to get input tensor %s from TosaSerializationHandler",
@@ -175,7 +193,13 @@ int SubgraphTraverser::initializeGraph()
"SubgraphTraverser::initializeGraph(): Op=%s, weight_index %d must be within [0, num_input - 1]",
EnumNamesOp()[op->GetOp()], weight_index);
std::string weight_name = op->GetInputTensorNames()[weight_index];
- TosaSerializationTensor* weight_tensor = block->GetTensorByName(weight_name);
+ TosaSerializationTensor* weight_tensor = nullptr;
+ for (auto ser_tensor : ser_tensor_vec) {
+ if (ser_tensor->GetName() == weight_name) {
+ weight_tensor = ser_tensor;
+ }
+ }
+
SUBGRAPH_ERROR_IF(
!weight_tensor,
"SubgraphTraverser::initializeGraph(): fail to get weight tensor %s from TosaSerializationHandler",
@@ -199,8 +223,19 @@ int SubgraphTraverser::initializeGraph()
DEBUG_INFO(GT, "Creating operator id_%03u, %8s, %lu input tensors, %lu output tensors", idx,
EnumNamesOp()[op->GetOp()], op->GetInputTensorNames().size(), op->GetOutputTensorNames().size());
- GraphNode* node = OpFactory::newOp(this, tsh, op->GetOp(), op->GetAttribute(), idx, input_dtype,
+ GraphNode* node = nullptr;
+ if (this->parent_sgt) {
+ node = OpFactory::newOp(this->parent_sgt, tsh, op->GetOp(), op->GetAttribute(), idx, input_dtype,
+ input_rank, output_dtype, output_rank, weight_dtype, weight_rank);
+ node->setInMainBlock(false);
+ } else {
+ node = OpFactory::newOp(this, tsh, op->GetOp(), op->GetAttribute(), idx, input_dtype,
input_rank, output_dtype, output_rank, weight_dtype, weight_rank);
+ if (node) {
+ node->setInMainBlock(true);
+ }
+ }
+
if (!node)
{
if (weight_index == -1)
@@ -257,6 +292,8 @@ int SubgraphTraverser::initializeGraph()
if (op->GetInputTensorNames().empty() && !node->getOnNextNodeList())
{
addToNextNodeList(node);
+ } else if (!node->getInMainBlock()) {
+ non_const_node_vec.push_back(node);
}
idx++;
@@ -271,7 +308,6 @@ int SubgraphTraverser::initializeGraph()
SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::initializeGraph(): Unsupported tensor name=%s, type=%s, rank=%d",
ts->GetName().c_str(), EnumNamesDType()[ts->GetDtype()], (int)ts->GetShape().size());
- // update this->tensors
addTensor(tensor);
}
@@ -296,7 +332,7 @@ int SubgraphTraverser::initializeGraph()
for (auto& output_name : block->GetOutputs())
{
TosaReference::Tensor* tensor = findTensorByName(output_name);
- DEBUG_INFO(GT, "output tensor name=%s\n", output_name.c_str());
+ DEBUG_INFO(GT, "output tensor name=%s", output_name.c_str());
if (tensor)
{
tensor->setIsSubgraphOutput();
@@ -314,6 +350,22 @@ int SubgraphTraverser::initializeGraph()
dumpNextNodeList(g_func_debug.func_debug_file);
}
+ // If the node is not in mainblock and not const
+ for (auto node : non_const_node_vec) {
+ bool all_inputs_from_parent = true;
+ for (std::string& name : node->getInputNames())
+ {
+ TosaReference::Tensor* t = findTensorByName(name);
+ if (!t->getIsParentGraphOutput()) {
+ all_inputs_from_parent = false;
+ }
+ }
+ // In the children block, when a node has all its inputs from parent
+ // block, we have to manually add this node to the evaluation list
+ if (all_inputs_from_parent && !node->getOnNextNodeList()) {
+ addToNextNodeList(node);
+ }
+ }
return 0;
}
@@ -510,29 +562,40 @@ int SubgraphTraverser::evaluateNextNode()
}
// free input tensor if all of its consumers have all of their outputs ready and it's not block's output
- for (auto tensor : currNode->getInputs())
- {
- bool in_use = false;
- for (auto node : tensor->getConsumers())
+ if (!currNode->getInMainBlock()) { // we don't free it if the node is in main block and has nested blocks
+ for (auto tensor : currNode->getInputs())
{
- if (!node->hasAllOutputsReady())
+ bool in_use = false;
+
+ auto tensor_check = findTensorByName(tensor->getName());
+ if (tensor_check->getIsParentGraphOutput()) {
+ // if it's parent's block output tensor, we can't free it
+ continue;
+ }
+
+ for (auto node : tensor->getConsumers())
{
- in_use = true;
+ // If the node is inside a loop, the input tensor is still needed
+ if (!node->hasAllOutputsReady())
+ {
+ in_use = true;
+ }
+
}
- }
- for (auto name : block->GetOutputs())
- {
- if (name == tensor->getName())
+ for (auto name : block->GetOutputs())
{
- in_use = true;
+ if (name == tensor->getName())
+ {
+ in_use = true;
+ }
+ }
+
+ if (!in_use)
+ {
+ tensor->deallocate();
}
- }
- if (!in_use)
- {
- tensor->deallocate();
}
}
-
// Search the output tensors of this node to see if
// there are now new ready nodes available from completing this node
for (TosaReference::Tensor* tensor : currNode->getOutputs())
@@ -642,17 +705,35 @@ int SubgraphTraverser::addNode(GraphNode* newNode)
TosaReference::Tensor* SubgraphTraverser::findTensorByName(const std::string& name) const
{
+ TosaReference::Tensor* res_tensor = nullptr;
+
for (TosaReference::Tensor* currTensor : tensors)
{
if (currTensor->getName() == name)
{
- return currTensor;
+ res_tensor = currTensor;
+ return res_tensor;
}
}
- WARNING("SubgraphTraverser::findTensorByName(): Unable to find tensor with name: %s\n", name.c_str());
+ if (parent_sgt)
+ {
+ for (TosaReference::Tensor* currTensor : parent_sgt->tensors)
+ {
+ if (currTensor->getName() == name)
+ {
+ res_tensor = currTensor;
+ res_tensor->setIsParentGraphOutput();
+ }
+ }
+ }
- return nullptr;
+ if (!res_tensor)
+ {
+ WARNING("SubgraphTraverser::findTensorByName(): Unable to find tensor with name: %s\n", name.c_str());
+ return nullptr;
+ }
+ return res_tensor;
}
int SubgraphTraverser::linkTensorsAndNodes()
@@ -704,7 +785,6 @@ int SubgraphTraverser::validateGraph()
for (TosaReference::Tensor* currTensor : tensors)
{
-
// It's okay for block input tensor not being consumed by operators.
// This is common in control flow op execution.
if (!currTensor->getIsSubgraphInput())