aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/graph/Graph.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/graph/Graph.h')
-rw-r--r--arm_compute/graph/Graph.h13
1 files changed, 6 insertions, 7 deletions
diff --git a/arm_compute/graph/Graph.h b/arm_compute/graph/Graph.h
index 16f5f97986..2a776826e5 100644
--- a/arm_compute/graph/Graph.h
+++ b/arm_compute/graph/Graph.h
@@ -72,7 +72,7 @@ public:
* @tparam NT Node operation
* @tparam Ts Arguments to operation
*
- * @param args Node arguments
+ * @param[in] args Node arguments
*
* @return ID of the node
*/
@@ -114,9 +114,11 @@ public:
GraphID id() const;
/** Returns graph input nodes
*
- * @return vector containing the graph inputs
+ * @param[in] type Type of nodes to return
+ *
+ * @return vector containing the graph node of given type
*/
- const std::vector<NodeID> &inputs();
+ const std::vector<NodeID> &nodes(NodeType type);
/** Returns nodes of graph
*
* @warning Nodes can be nullptr if they have been removed during the mutation steps of the graph
@@ -238,10 +240,7 @@ inline NodeID Graph::add_node(Ts &&... args)
node->set_id(nid);
// Keep track of input nodes
- if(node->type() == NodeType::Input)
- {
- _tagged_nodes[NodeType::Input].push_back(nid);
- }
+ _tagged_nodes[node->type()].push_back(nid);
// Associate a new tensor with each output
for(auto &output : node->_outputs)