From 9e94af8f10f0a21a117b3bc7ea42004844fdc3bb Mon Sep 17 00:00:00 2001 From: Jerry Ge Date: Thu, 27 Oct 2022 09:57:00 -0700 Subject: Reference model update for control flow operators support Rationale for making this change: - In the original design, for control flow operators like WhileOp, child blocks couldn't read the tensor variables (global consts) in the root level block, this patch added the machanism for child blocks to access their parent level block's tensors. - This change also relies on another serialization change on adding another layer of abtraction called Region: - Serialization patch: [region] Add TosaSerializationRegion to serialization_lib - Updated the corresponding python version of the serialization code: TosaSerializerRegion to python version of serialization_lib - This change also relies on the TOSA MLIR Translator change: Add RegionBuilder to TOSA MLIR Translator - Added the WhileOp related test cases: While, LSTM, GRU, RNN - Other related fixes Signed-off-by: Jerry Ge Change-Id: I13ae33628ad07e41d248e88652ce1328654694ab --- reference_model/src/subgraph_traverser.cc | 138 +++++++++++++++++++++++------- 1 file changed, 109 insertions(+), 29 deletions(-) (limited to 'reference_model/src/subgraph_traverser.cc') diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc index 112e641..8867ada 100644 --- a/reference_model/src/subgraph_traverser.cc +++ b/reference_model/src/subgraph_traverser.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -37,13 +37,14 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; -SubgraphTraverser::SubgraphTraverser(TosaSerializationBasicBlock* _block, TosaSerializationHandler* _tsh) +SubgraphTraverser::SubgraphTraverser(TosaSerializationBasicBlock* _block, TosaSerializationHandler* _tsh, SubgraphTraverser* _parent_sgt) { - graph_status = GraphStatus::TOSA_VALID; + graph_status = GraphStatus::TOSA_VALID; block = _block; - tsh = _tsh; + tsh = _tsh; + parent_sgt = _parent_sgt; tensors.clear(); nodes.clear(); nextNodeList.clear(); @@ -120,6 +121,17 @@ int SubgraphTraverser::initializeGraph() { int idx = 0; + std::vector ser_tensor_vec; + // Get all the serialized tensors from TosaSerializationHandler. + for (auto block: tsh->GetMainRegion()->GetBlocks()) + { + for (auto ser_tensor : block->GetTensors()) + { + ser_tensor_vec.push_back(ser_tensor); + } + } + + std::vector non_const_node_vec; for (auto op : block->GetOperators()) { // translated TosaSerializationOperator to GraphNode @@ -159,7 +171,13 @@ int SubgraphTraverser::initializeGraph() EnumNamesOp()[op->GetOp()], input_index); std::string input_name = op->GetInputTensorNames()[input_index]; - TosaSerializationTensor* input_tensor = block->GetTensorByName(input_name); + TosaSerializationTensor* input_tensor = nullptr; + for (auto ser_tensor : ser_tensor_vec) { + if (ser_tensor->GetName() == input_name) { + input_tensor = ser_tensor; + } + } + SUBGRAPH_ERROR_IF( !input_tensor, "SubgraphTraverser::initializeGraph(): fail to get input tensor %s from TosaSerializationHandler", @@ -175,7 +193,13 @@ int SubgraphTraverser::initializeGraph() "SubgraphTraverser::initializeGraph(): Op=%s, weight_index %d must be within [0, num_input - 1]", EnumNamesOp()[op->GetOp()], weight_index); std::string weight_name = op->GetInputTensorNames()[weight_index]; - TosaSerializationTensor* weight_tensor = block->GetTensorByName(weight_name); + TosaSerializationTensor* weight_tensor = nullptr; + for (auto ser_tensor : ser_tensor_vec) { + if (ser_tensor->GetName() == weight_name) { + weight_tensor = ser_tensor; + } + } + SUBGRAPH_ERROR_IF( !weight_tensor, "SubgraphTraverser::initializeGraph(): fail to get weight tensor %s from TosaSerializationHandler", @@ -199,8 +223,19 @@ int SubgraphTraverser::initializeGraph() DEBUG_INFO(GT, "Creating operator id_%03u, %8s, %lu input tensors, %lu output tensors", idx, EnumNamesOp()[op->GetOp()], op->GetInputTensorNames().size(), op->GetOutputTensorNames().size()); - GraphNode* node = OpFactory::newOp(this, tsh, op->GetOp(), op->GetAttribute(), idx, input_dtype, + GraphNode* node = nullptr; + if (this->parent_sgt) { + node = OpFactory::newOp(this->parent_sgt, tsh, op->GetOp(), op->GetAttribute(), idx, input_dtype, + input_rank, output_dtype, output_rank, weight_dtype, weight_rank); + node->setInMainBlock(false); + } else { + node = OpFactory::newOp(this, tsh, op->GetOp(), op->GetAttribute(), idx, input_dtype, input_rank, output_dtype, output_rank, weight_dtype, weight_rank); + if (node) { + node->setInMainBlock(true); + } + } + if (!node) { if (weight_index == -1) @@ -257,6 +292,8 @@ int SubgraphTraverser::initializeGraph() if (op->GetInputTensorNames().empty() && !node->getOnNextNodeList()) { addToNextNodeList(node); + } else if (!node->getInMainBlock()) { + non_const_node_vec.push_back(node); } idx++; @@ -271,7 +308,6 @@ int SubgraphTraverser::initializeGraph() SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::initializeGraph(): Unsupported tensor name=%s, type=%s, rank=%d", ts->GetName().c_str(), EnumNamesDType()[ts->GetDtype()], (int)ts->GetShape().size()); - // update this->tensors addTensor(tensor); } @@ -296,7 +332,7 @@ int SubgraphTraverser::initializeGraph() for (auto& output_name : block->GetOutputs()) { TosaReference::Tensor* tensor = findTensorByName(output_name); - DEBUG_INFO(GT, "output tensor name=%s\n", output_name.c_str()); + DEBUG_INFO(GT, "output tensor name=%s", output_name.c_str()); if (tensor) { tensor->setIsSubgraphOutput(); @@ -314,6 +350,22 @@ int SubgraphTraverser::initializeGraph() dumpNextNodeList(g_func_debug.func_debug_file); } + // If the node is not in mainblock and not const + for (auto node : non_const_node_vec) { + bool all_inputs_from_parent = true; + for (std::string& name : node->getInputNames()) + { + TosaReference::Tensor* t = findTensorByName(name); + if (!t->getIsParentGraphOutput()) { + all_inputs_from_parent = false; + } + } + // In the children block, when a node has all its inputs from parent + // block, we have to manually add this node to the evaluation list + if (all_inputs_from_parent && !node->getOnNextNodeList()) { + addToNextNodeList(node); + } + } return 0; } @@ -510,29 +562,40 @@ int SubgraphTraverser::evaluateNextNode() } // free input tensor if all of its consumers have all of their outputs ready and it's not block's output - for (auto tensor : currNode->getInputs()) - { - bool in_use = false; - for (auto node : tensor->getConsumers()) + if (!currNode->getInMainBlock()) { // we don't free it if the node is in main block and has nested blocks + for (auto tensor : currNode->getInputs()) { - if (!node->hasAllOutputsReady()) + bool in_use = false; + + auto tensor_check = findTensorByName(tensor->getName()); + if (tensor_check->getIsParentGraphOutput()) { + // if it's parent's block output tensor, we can't free it + continue; + } + + for (auto node : tensor->getConsumers()) { - in_use = true; + // If the node is inside a loop, the input tensor is still needed + if (!node->hasAllOutputsReady()) + { + in_use = true; + } + } - } - for (auto name : block->GetOutputs()) - { - if (name == tensor->getName()) + for (auto name : block->GetOutputs()) { - in_use = true; + if (name == tensor->getName()) + { + in_use = true; + } + } + + if (!in_use) + { + tensor->deallocate(); } - } - if (!in_use) - { - tensor->deallocate(); } } - // Search the output tensors of this node to see if // there are now new ready nodes available from completing this node for (TosaReference::Tensor* tensor : currNode->getOutputs()) @@ -642,17 +705,35 @@ int SubgraphTraverser::addNode(GraphNode* newNode) TosaReference::Tensor* SubgraphTraverser::findTensorByName(const std::string& name) const { + TosaReference::Tensor* res_tensor = nullptr; + for (TosaReference::Tensor* currTensor : tensors) { if (currTensor->getName() == name) { - return currTensor; + res_tensor = currTensor; + return res_tensor; } } - WARNING("SubgraphTraverser::findTensorByName(): Unable to find tensor with name: %s\n", name.c_str()); + if (parent_sgt) + { + for (TosaReference::Tensor* currTensor : parent_sgt->tensors) + { + if (currTensor->getName() == name) + { + res_tensor = currTensor; + res_tensor->setIsParentGraphOutput(); + } + } + } - return nullptr; + if (!res_tensor) + { + WARNING("SubgraphTraverser::findTensorByName(): Unable to find tensor with name: %s\n", name.c_str()); + return nullptr; + } + return res_tensor; } int SubgraphTraverser::linkTensorsAndNodes() @@ -704,7 +785,6 @@ int SubgraphTraverser::validateGraph() for (TosaReference::Tensor* currTensor : tensors) { - // It's okay for block input tensor not being consumed by operators. // This is common in control flow op execution. if (!currTensor->getIsSubgraphInput()) -- cgit v1.2.1