aboutsummaryrefslogtreecommitdiff
path: root/src/graph/Graph.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/graph/Graph.cpp')
-rw-r--r--src/graph/Graph.cpp38
1 files changed, 21 insertions, 17 deletions
diff --git a/src/graph/Graph.cpp b/src/graph/Graph.cpp
index 4ce53589d4..3ae83f2e80 100644
--- a/src/graph/Graph.cpp
+++ b/src/graph/Graph.cpp
@@ -34,24 +34,24 @@ Graph::Graph(GraphID id, std::string name)
bool Graph::remove_node(NodeID nid)
{
- if(nid >= _nodes.size())
+ if (nid >= _nodes.size())
{
return false;
}
std::unique_ptr<INode> &node = _nodes[nid];
- if(node)
+ if (node)
{
// Remove input connections
- for(auto &input_eid : node->_input_edges)
+ for (auto &input_eid : node->_input_edges)
{
remove_connection(input_eid);
}
// Remove output connections
std::set<EdgeID> output_edges_copy = node->output_edges();
- for(auto &output_eid : output_edges_copy)
+ for (auto &output_eid : output_edges_copy)
{
remove_connection(output_eid);
}
@@ -71,8 +71,10 @@ EdgeID Graph::add_connection(NodeID source, size_t source_idx, NodeID sink, size
arm_compute::lock_guard<arm_compute::Mutex> lock(_mtx);
// Check if node index is valid, if node exists and finally if the connection index is valid
- ARM_COMPUTE_ERROR_ON((source >= _nodes.size()) || (_nodes[source] == nullptr) || (source_idx >= _nodes[source]->num_outputs()));
- ARM_COMPUTE_ERROR_ON((sink >= _nodes.size()) || (_nodes[sink] == nullptr) || (sink_idx >= _nodes[sink]->num_inputs()));
+ ARM_COMPUTE_ERROR_ON((source >= _nodes.size()) || (_nodes[source] == nullptr) ||
+ (source_idx >= _nodes[source]->num_outputs()));
+ ARM_COMPUTE_ERROR_ON((sink >= _nodes.size()) || (_nodes[sink] == nullptr) ||
+ (sink_idx >= _nodes[sink]->num_inputs()));
// Get nodes
std::unique_ptr<INode> &source_node = _nodes[source];
@@ -80,23 +82,25 @@ EdgeID Graph::add_connection(NodeID source, size_t source_idx, NodeID sink, size
// Check for duplicate connections (Check only sink node)
Edge *sink_node_edge = sink_node->input_edge(sink_idx);
- if((sink_node_edge != nullptr) && (sink_node_edge->producer_id() == source) && (sink_node_edge->producer_idx() == source_idx)
- && (sink_node_edge->consumer_id() == sink) && (sink_node_edge->consumer_idx() == sink_idx))
+ if ((sink_node_edge != nullptr) && (sink_node_edge->producer_id() == source) &&
+ (sink_node_edge->producer_idx() == source_idx) && (sink_node_edge->consumer_id() == sink) &&
+ (sink_node_edge->consumer_idx() == sink_idx))
{
return sink_node_edge->id();
}
// Check if there is already a tensor associated with output if not create one
TensorID tid = source_node->output_id(source_idx);
- if(tid == NullTensorID)
+ if (tid == NullTensorID)
{
tid = create_tensor();
}
std::unique_ptr<Tensor> &tensor = _tensors[tid];
// Create connections
- EdgeID eid = _edges.size();
- auto connection = std::make_unique<Edge>(eid, source_node.get(), source_idx, sink_node.get(), sink_idx, tensor.get());
+ EdgeID eid = _edges.size();
+ auto connection =
+ std::make_unique<Edge>(eid, source_node.get(), source_idx, sink_node.get(), sink_idx, tensor.get());
_edges.push_back(std::move(connection));
// Add connections to source and sink nodes
@@ -117,7 +121,7 @@ EdgeID Graph::add_connection(NodeID source, size_t source_idx, NodeID sink, size
bool Graph::remove_connection(EdgeID eid)
{
- if(eid >= _edges.size())
+ if (eid >= _edges.size())
{
return false;
}
@@ -125,22 +129,22 @@ bool Graph::remove_connection(EdgeID eid)
std::unique_ptr<Edge> &edge = _edges[eid];
// Remove node connections
- if(edge != nullptr)
+ if (edge != nullptr)
{
// Get tensor bound to the edge
- if(edge->tensor() != nullptr)
+ if (edge->tensor() != nullptr)
{
edge->tensor()->unbind_edge(eid);
}
// Remove edges from source node
- if(edge->producer() != nullptr)
+ if (edge->producer() != nullptr)
{
edge->producer()->_output_edges.erase(eid);
}
// Remove edges from sink node
- if((edge->consumer() != nullptr) && (edge->consumer_idx() < edge->consumer()->_input_edges.size()))
+ if ((edge->consumer() != nullptr) && (edge->consumer_idx() < edge->consumer()->_input_edges.size()))
{
edge->consumer()->_input_edges[edge->consumer_idx()] = EmptyEdgeID;
}
@@ -231,4 +235,4 @@ Tensor *Graph::tensor(TensorID id)
return (id >= _tensors.size()) ? nullptr : _tensors[id].get();
}
} // namespace graph
-} // namespace arm_compute \ No newline at end of file
+} // namespace arm_compute