aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/subgraph_traverser.cc
diff options
context:
space:
mode:
authorJerry Ge <jerry.ge@arm.com>2023-07-19 23:08:16 +0000
committerJerry Ge <jerry.ge@arm.com>2023-07-25 22:49:17 +0000
commit9c9c8dafe8f9a32bd70aee268cd537b93865a3ba (patch)
treee94fc471261b9f72bef86033fbc76022f55d5de8 /reference_model/src/subgraph_traverser.cc
parentc1e13432b4a218781afd6b0171d4afff11730433 (diff)
downloadreference_model-9c9c8dafe8f9a32bd70aee268cd537b93865a3ba.tar.gz
Run clang-format and update copyright
- Also added run clang-format to pre-commit runs Signed-off-by: Jerry Ge <jerry.ge@arm.com> Change-Id: I4e59ac0afbaa30dce0773aa63d92a1a3b119e2f3
Diffstat (limited to 'reference_model/src/subgraph_traverser.cc')
-rw-r--r--reference_model/src/subgraph_traverser.cc71
1 files changed, 44 insertions, 27 deletions
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc
index 5f6fd01..a7ef5e9 100644
--- a/reference_model/src/subgraph_traverser.cc
+++ b/reference_model/src/subgraph_traverser.cc
@@ -14,8 +14,8 @@
// limitations under the License.
#include "subgraph_traverser.h"
-#include "tosa_model_types.h"
#include "arith_util.h"
+#include "tosa_model_types.h"
#ifndef SUBGRAPH_ERROR_IF
#define SUBGRAPH_ERROR_IF(COND, fmt, ...) \
@@ -37,13 +37,15 @@ using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
-SubgraphTraverser::SubgraphTraverser(TosaSerializationBasicBlock* _block, TosaSerializationHandler* _tsh, SubgraphTraverser* _parent_sgt)
+SubgraphTraverser::SubgraphTraverser(TosaSerializationBasicBlock* _block,
+ TosaSerializationHandler* _tsh,
+ SubgraphTraverser* _parent_sgt)
{
graph_status = GraphStatus::TOSA_VALID;
- block = _block;
+ block = _block;
- tsh = _tsh;
+ tsh = _tsh;
parent_sgt = _parent_sgt;
tensors.clear();
nodes.clear();
@@ -151,11 +153,11 @@ int SubgraphTraverser::initializeGraph()
TOSA_REF_TYPE input_dtype = TOSA_REF_TYPE_UNKNOWN;
TOSA_REF_TYPE output_dtype = TOSA_REF_TYPE_UNKNOWN;
TOSA_REF_TYPE weight_dtype = TOSA_REF_TYPE_UNKNOWN;
- uint32_t input_rank = 0;
- uint32_t output_rank = 0;
- uint32_t weight_rank = 0;
- int32_t input_index = -1;
- int32_t weight_index = -1;
+ uint32_t input_rank = 0;
+ uint32_t output_rank = 0;
+ uint32_t weight_rank = 0;
+ int32_t input_index = -1;
+ int32_t weight_index = -1;
switch (op->GetOp())
{
@@ -185,8 +187,10 @@ int SubgraphTraverser::initializeGraph()
std::string input_name = op->GetInputTensorNames()[input_index];
TosaSerializationTensor* input_tensor = nullptr;
- for (auto ser_tensor : ser_tensor_vec) {
- if (ser_tensor->GetName() == input_name) {
+ for (auto ser_tensor : ser_tensor_vec)
+ {
+ if (ser_tensor->GetName() == input_name)
+ {
input_tensor = ser_tensor;
}
}
@@ -207,8 +211,10 @@ int SubgraphTraverser::initializeGraph()
EnumNamesOp()[op->GetOp()], weight_index);
std::string weight_name = op->GetInputTensorNames()[weight_index];
TosaSerializationTensor* weight_tensor = nullptr;
- for (auto ser_tensor : ser_tensor_vec) {
- if (ser_tensor->GetName() == weight_name) {
+ for (auto ser_tensor : ser_tensor_vec)
+ {
+ if (ser_tensor->GetName() == weight_name)
+ {
weight_tensor = ser_tensor;
}
}
@@ -237,14 +243,18 @@ int SubgraphTraverser::initializeGraph()
EnumNamesOp()[op->GetOp()], op->GetInputTensorNames().size(), op->GetOutputTensorNames().size());
GraphNode* node = nullptr;
- if (this->parent_sgt) {
+ if (this->parent_sgt)
+ {
node = OpFactory::newOp(this->parent_sgt, tsh, op->GetOp(), op->GetAttribute(), idx, input_dtype,
- input_rank, output_dtype, output_rank, weight_dtype, weight_rank);
+ input_rank, output_dtype, output_rank, weight_dtype, weight_rank);
node->setInMainBlock(false);
- } else {
- node = OpFactory::newOp(this, tsh, op->GetOp(), op->GetAttribute(), idx, input_dtype,
- input_rank, output_dtype, output_rank, weight_dtype, weight_rank);
- if (node) {
+ }
+ else
+ {
+ node = OpFactory::newOp(this, tsh, op->GetOp(), op->GetAttribute(), idx, input_dtype, input_rank,
+ output_dtype, output_rank, weight_dtype, weight_rank);
+ if (node)
+ {
node->setInMainBlock(true);
}
}
@@ -305,7 +315,9 @@ int SubgraphTraverser::initializeGraph()
if (op->GetInputTensorNames().empty() && !node->getOnNextNodeList())
{
addToNextNodeList(node);
- } else if (!node->getInMainBlock()) {
+ }
+ else if (!node->getInMainBlock())
+ {
non_const_node_vec.push_back(node);
}
@@ -364,18 +376,21 @@ int SubgraphTraverser::initializeGraph()
}
// If the node is not in mainblock and not const
- for (auto node : non_const_node_vec) {
+ for (auto node : non_const_node_vec)
+ {
bool all_inputs_from_parent = true;
for (std::string& name : node->getInputNames())
{
TosaReference::Tensor* t = findTensorByName(name);
- if (!t->getIsParentGraphOutput()) {
+ if (!t->getIsParentGraphOutput())
+ {
all_inputs_from_parent = false;
}
}
// In the children block, when a node has all its inputs from parent
// block, we have to manually add this node to the evaluation list
- if (all_inputs_from_parent && !node->getOnNextNodeList()) {
+ if (all_inputs_from_parent && !node->getOnNextNodeList())
+ {
addToNextNodeList(node);
}
}
@@ -395,7 +410,8 @@ int SubgraphTraverser::allocateTensor()
{
if (dim <= 0)
{
- DEBUG_INFO(GT, "Failed to allocate tensor %s with invalid dimension of %d", ts->GetName().c_str(), dim);
+ DEBUG_INFO(GT, "Failed to allocate tensor %s with invalid dimension of %d", ts->GetName().c_str(),
+ dim);
this->setGraphStatus(GraphStatus::TOSA_UNPREDICTABLE);
return 1;
}
@@ -591,13 +607,15 @@ int SubgraphTraverser::evaluateNextNode()
}
// free input tensor if all of its consumers have all of their outputs ready and it's not block's output
- if (!currNode->getInMainBlock()) { // we don't free it if the node is in main block and has nested blocks
+ if (!currNode->getInMainBlock())
+ { // we don't free it if the node is in main block and has nested blocks
for (auto tensor : currNode->getInputs())
{
bool in_use = false;
auto tensor_check = findTensorByName(tensor->getName());
- if (tensor_check->getIsParentGraphOutput()) {
+ if (tensor_check->getIsParentGraphOutput())
+ {
// if it's parent's block output tensor, we can't free it
continue;
}
@@ -609,7 +627,6 @@ int SubgraphTraverser::evaluateNextNode()
{
in_use = true;
}
-
}
for (auto name : block->GetOutputs())
{