From 9c9c8dafe8f9a32bd70aee268cd537b93865a3ba Mon Sep 17 00:00:00 2001 From: Jerry Ge Date: Wed, 19 Jul 2023 23:08:16 +0000 Subject: Run clang-format and update copyright - Also added run clang-format to pre-commit runs Signed-off-by: Jerry Ge Change-Id: I4e59ac0afbaa30dce0773aa63d92a1a3b119e2f3 --- reference_model/src/subgraph_traverser.cc | 71 +++++++++++++++++++------------ 1 file changed, 44 insertions(+), 27 deletions(-) (limited to 'reference_model/src/subgraph_traverser.cc') diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc index 5f6fd01..a7ef5e9 100644 --- a/reference_model/src/subgraph_traverser.cc +++ b/reference_model/src/subgraph_traverser.cc @@ -14,8 +14,8 @@ // limitations under the License. #include "subgraph_traverser.h" -#include "tosa_model_types.h" #include "arith_util.h" +#include "tosa_model_types.h" #ifndef SUBGRAPH_ERROR_IF #define SUBGRAPH_ERROR_IF(COND, fmt, ...) \ @@ -37,13 +37,15 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; -SubgraphTraverser::SubgraphTraverser(TosaSerializationBasicBlock* _block, TosaSerializationHandler* _tsh, SubgraphTraverser* _parent_sgt) +SubgraphTraverser::SubgraphTraverser(TosaSerializationBasicBlock* _block, + TosaSerializationHandler* _tsh, + SubgraphTraverser* _parent_sgt) { graph_status = GraphStatus::TOSA_VALID; - block = _block; + block = _block; - tsh = _tsh; + tsh = _tsh; parent_sgt = _parent_sgt; tensors.clear(); nodes.clear(); @@ -151,11 +153,11 @@ int SubgraphTraverser::initializeGraph() TOSA_REF_TYPE input_dtype = TOSA_REF_TYPE_UNKNOWN; TOSA_REF_TYPE output_dtype = TOSA_REF_TYPE_UNKNOWN; TOSA_REF_TYPE weight_dtype = TOSA_REF_TYPE_UNKNOWN; - uint32_t input_rank = 0; - uint32_t output_rank = 0; - uint32_t weight_rank = 0; - int32_t input_index = -1; - int32_t weight_index = -1; + uint32_t input_rank = 0; + uint32_t output_rank = 0; + uint32_t weight_rank = 0; + int32_t input_index = -1; + int32_t weight_index = -1; switch (op->GetOp()) { @@ -185,8 +187,10 @@ int SubgraphTraverser::initializeGraph() std::string input_name = op->GetInputTensorNames()[input_index]; TosaSerializationTensor* input_tensor = nullptr; - for (auto ser_tensor : ser_tensor_vec) { - if (ser_tensor->GetName() == input_name) { + for (auto ser_tensor : ser_tensor_vec) + { + if (ser_tensor->GetName() == input_name) + { input_tensor = ser_tensor; } } @@ -207,8 +211,10 @@ int SubgraphTraverser::initializeGraph() EnumNamesOp()[op->GetOp()], weight_index); std::string weight_name = op->GetInputTensorNames()[weight_index]; TosaSerializationTensor* weight_tensor = nullptr; - for (auto ser_tensor : ser_tensor_vec) { - if (ser_tensor->GetName() == weight_name) { + for (auto ser_tensor : ser_tensor_vec) + { + if (ser_tensor->GetName() == weight_name) + { weight_tensor = ser_tensor; } } @@ -237,14 +243,18 @@ int SubgraphTraverser::initializeGraph() EnumNamesOp()[op->GetOp()], op->GetInputTensorNames().size(), op->GetOutputTensorNames().size()); GraphNode* node = nullptr; - if (this->parent_sgt) { + if (this->parent_sgt) + { node = OpFactory::newOp(this->parent_sgt, tsh, op->GetOp(), op->GetAttribute(), idx, input_dtype, - input_rank, output_dtype, output_rank, weight_dtype, weight_rank); + input_rank, output_dtype, output_rank, weight_dtype, weight_rank); node->setInMainBlock(false); - } else { - node = OpFactory::newOp(this, tsh, op->GetOp(), op->GetAttribute(), idx, input_dtype, - input_rank, output_dtype, output_rank, weight_dtype, weight_rank); - if (node) { + } + else + { + node = OpFactory::newOp(this, tsh, op->GetOp(), op->GetAttribute(), idx, input_dtype, input_rank, + output_dtype, output_rank, weight_dtype, weight_rank); + if (node) + { node->setInMainBlock(true); } } @@ -305,7 +315,9 @@ int SubgraphTraverser::initializeGraph() if (op->GetInputTensorNames().empty() && !node->getOnNextNodeList()) { addToNextNodeList(node); - } else if (!node->getInMainBlock()) { + } + else if (!node->getInMainBlock()) + { non_const_node_vec.push_back(node); } @@ -364,18 +376,21 @@ int SubgraphTraverser::initializeGraph() } // If the node is not in mainblock and not const - for (auto node : non_const_node_vec) { + for (auto node : non_const_node_vec) + { bool all_inputs_from_parent = true; for (std::string& name : node->getInputNames()) { TosaReference::Tensor* t = findTensorByName(name); - if (!t->getIsParentGraphOutput()) { + if (!t->getIsParentGraphOutput()) + { all_inputs_from_parent = false; } } // In the children block, when a node has all its inputs from parent // block, we have to manually add this node to the evaluation list - if (all_inputs_from_parent && !node->getOnNextNodeList()) { + if (all_inputs_from_parent && !node->getOnNextNodeList()) + { addToNextNodeList(node); } } @@ -395,7 +410,8 @@ int SubgraphTraverser::allocateTensor() { if (dim <= 0) { - DEBUG_INFO(GT, "Failed to allocate tensor %s with invalid dimension of %d", ts->GetName().c_str(), dim); + DEBUG_INFO(GT, "Failed to allocate tensor %s with invalid dimension of %d", ts->GetName().c_str(), + dim); this->setGraphStatus(GraphStatus::TOSA_UNPREDICTABLE); return 1; } @@ -591,13 +607,15 @@ int SubgraphTraverser::evaluateNextNode() } // free input tensor if all of its consumers have all of their outputs ready and it's not block's output - if (!currNode->getInMainBlock()) { // we don't free it if the node is in main block and has nested blocks + if (!currNode->getInMainBlock()) + { // we don't free it if the node is in main block and has nested blocks for (auto tensor : currNode->getInputs()) { bool in_use = false; auto tensor_check = findTensorByName(tensor->getName()); - if (tensor_check->getIsParentGraphOutput()) { + if (tensor_check->getIsParentGraphOutput()) + { // if it's parent's block output tensor, we can't free it continue; } @@ -609,7 +627,6 @@ int SubgraphTraverser::evaluateNextNode() { in_use = true; } - } for (auto name : block->GetOutputs()) { -- cgit v1.2.1