diff options
Diffstat (limited to 'reference_model/src/subgraph_traverser.cc')
-rw-r--r-- | reference_model/src/subgraph_traverser.cc | 138 |
1 files changed, 109 insertions, 29 deletions
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc index 112e641..8867ada 100644 --- a/reference_model/src/subgraph_traverser.cc +++ b/reference_model/src/subgraph_traverser.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -37,13 +37,14 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; -SubgraphTraverser::SubgraphTraverser(TosaSerializationBasicBlock* _block, TosaSerializationHandler* _tsh) +SubgraphTraverser::SubgraphTraverser(TosaSerializationBasicBlock* _block, TosaSerializationHandler* _tsh, SubgraphTraverser* _parent_sgt) { - graph_status = GraphStatus::TOSA_VALID; + graph_status = GraphStatus::TOSA_VALID; block = _block; - tsh = _tsh; + tsh = _tsh; + parent_sgt = _parent_sgt; tensors.clear(); nodes.clear(); nextNodeList.clear(); @@ -120,6 +121,17 @@ int SubgraphTraverser::initializeGraph() { int idx = 0; + std::vector<TosaSerializationTensor*> ser_tensor_vec; + // Get all the serialized tensors from TosaSerializationHandler. + for (auto block: tsh->GetMainRegion()->GetBlocks()) + { + for (auto ser_tensor : block->GetTensors()) + { + ser_tensor_vec.push_back(ser_tensor); + } + } + + std::vector<GraphNode*> non_const_node_vec; for (auto op : block->GetOperators()) { // translated TosaSerializationOperator to GraphNode @@ -159,7 +171,13 @@ int SubgraphTraverser::initializeGraph() EnumNamesOp()[op->GetOp()], input_index); std::string input_name = op->GetInputTensorNames()[input_index]; - TosaSerializationTensor* input_tensor = block->GetTensorByName(input_name); + TosaSerializationTensor* input_tensor = nullptr; + for (auto ser_tensor : ser_tensor_vec) { + if (ser_tensor->GetName() == input_name) { + input_tensor = ser_tensor; + } + } + SUBGRAPH_ERROR_IF( !input_tensor, "SubgraphTraverser::initializeGraph(): fail to get input tensor %s from TosaSerializationHandler", @@ -175,7 +193,13 @@ int SubgraphTraverser::initializeGraph() "SubgraphTraverser::initializeGraph(): Op=%s, weight_index %d must be within [0, num_input - 1]", EnumNamesOp()[op->GetOp()], weight_index); std::string weight_name = op->GetInputTensorNames()[weight_index]; - TosaSerializationTensor* weight_tensor = block->GetTensorByName(weight_name); + TosaSerializationTensor* weight_tensor = nullptr; + for (auto ser_tensor : ser_tensor_vec) { + if (ser_tensor->GetName() == weight_name) { + weight_tensor = ser_tensor; + } + } + SUBGRAPH_ERROR_IF( !weight_tensor, "SubgraphTraverser::initializeGraph(): fail to get weight tensor %s from TosaSerializationHandler", @@ -199,8 +223,19 @@ int SubgraphTraverser::initializeGraph() DEBUG_INFO(GT, "Creating operator id_%03u, %8s, %lu input tensors, %lu output tensors", idx, EnumNamesOp()[op->GetOp()], op->GetInputTensorNames().size(), op->GetOutputTensorNames().size()); - GraphNode* node = OpFactory::newOp(this, tsh, op->GetOp(), op->GetAttribute(), idx, input_dtype, + GraphNode* node = nullptr; + if (this->parent_sgt) { + node = OpFactory::newOp(this->parent_sgt, tsh, op->GetOp(), op->GetAttribute(), idx, input_dtype, + input_rank, output_dtype, output_rank, weight_dtype, weight_rank); + node->setInMainBlock(false); + } else { + node = OpFactory::newOp(this, tsh, op->GetOp(), op->GetAttribute(), idx, input_dtype, input_rank, output_dtype, output_rank, weight_dtype, weight_rank); + if (node) { + node->setInMainBlock(true); + } + } + if (!node) { if (weight_index == -1) @@ -257,6 +292,8 @@ int SubgraphTraverser::initializeGraph() if (op->GetInputTensorNames().empty() && !node->getOnNextNodeList()) { addToNextNodeList(node); + } else if (!node->getInMainBlock()) { + non_const_node_vec.push_back(node); } idx++; @@ -271,7 +308,6 @@ int SubgraphTraverser::initializeGraph() SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::initializeGraph(): Unsupported tensor name=%s, type=%s, rank=%d", ts->GetName().c_str(), EnumNamesDType()[ts->GetDtype()], (int)ts->GetShape().size()); - // update this->tensors addTensor(tensor); } @@ -296,7 +332,7 @@ int SubgraphTraverser::initializeGraph() for (auto& output_name : block->GetOutputs()) { TosaReference::Tensor* tensor = findTensorByName(output_name); - DEBUG_INFO(GT, "output tensor name=%s\n", output_name.c_str()); + DEBUG_INFO(GT, "output tensor name=%s", output_name.c_str()); if (tensor) { tensor->setIsSubgraphOutput(); @@ -314,6 +350,22 @@ int SubgraphTraverser::initializeGraph() dumpNextNodeList(g_func_debug.func_debug_file); } + // If the node is not in mainblock and not const + for (auto node : non_const_node_vec) { + bool all_inputs_from_parent = true; + for (std::string& name : node->getInputNames()) + { + TosaReference::Tensor* t = findTensorByName(name); + if (!t->getIsParentGraphOutput()) { + all_inputs_from_parent = false; + } + } + // In the children block, when a node has all its inputs from parent + // block, we have to manually add this node to the evaluation list + if (all_inputs_from_parent && !node->getOnNextNodeList()) { + addToNextNodeList(node); + } + } return 0; } @@ -510,29 +562,40 @@ int SubgraphTraverser::evaluateNextNode() } // free input tensor if all of its consumers have all of their outputs ready and it's not block's output - for (auto tensor : currNode->getInputs()) - { - bool in_use = false; - for (auto node : tensor->getConsumers()) + if (!currNode->getInMainBlock()) { // we don't free it if the node is in main block and has nested blocks + for (auto tensor : currNode->getInputs()) { - if (!node->hasAllOutputsReady()) + bool in_use = false; + + auto tensor_check = findTensorByName(tensor->getName()); + if (tensor_check->getIsParentGraphOutput()) { + // if it's parent's block output tensor, we can't free it + continue; + } + + for (auto node : tensor->getConsumers()) { - in_use = true; + // If the node is inside a loop, the input tensor is still needed + if (!node->hasAllOutputsReady()) + { + in_use = true; + } + } - } - for (auto name : block->GetOutputs()) - { - if (name == tensor->getName()) + for (auto name : block->GetOutputs()) { - in_use = true; + if (name == tensor->getName()) + { + in_use = true; + } + } + + if (!in_use) + { + tensor->deallocate(); } - } - if (!in_use) - { - tensor->deallocate(); } } - // Search the output tensors of this node to see if // there are now new ready nodes available from completing this node for (TosaReference::Tensor* tensor : currNode->getOutputs()) @@ -642,17 +705,35 @@ int SubgraphTraverser::addNode(GraphNode* newNode) TosaReference::Tensor* SubgraphTraverser::findTensorByName(const std::string& name) const { + TosaReference::Tensor* res_tensor = nullptr; + for (TosaReference::Tensor* currTensor : tensors) { if (currTensor->getName() == name) { - return currTensor; + res_tensor = currTensor; + return res_tensor; } } - WARNING("SubgraphTraverser::findTensorByName(): Unable to find tensor with name: %s\n", name.c_str()); + if (parent_sgt) + { + for (TosaReference::Tensor* currTensor : parent_sgt->tensors) + { + if (currTensor->getName() == name) + { + res_tensor = currTensor; + res_tensor->setIsParentGraphOutput(); + } + } + } - return nullptr; + if (!res_tensor) + { + WARNING("SubgraphTraverser::findTensorByName(): Unable to find tensor with name: %s\n", name.c_str()); + return nullptr; + } + return res_tensor; } int SubgraphTraverser::linkTensorsAndNodes() @@ -704,7 +785,6 @@ int SubgraphTraverser::validateGraph() for (TosaReference::Tensor* currTensor : tensors) { - // It's okay for block input tensor not being consumed by operators. // This is common in control flow op execution. if (!currTensor->getIsSubgraphInput()) |