diff options
Diffstat (limited to 'arm_compute/graph/Graph.h')
-rw-r--r-- | arm_compute/graph/Graph.h | 13 |
1 files changed, 6 insertions, 7 deletions
diff --git a/arm_compute/graph/Graph.h b/arm_compute/graph/Graph.h index 16f5f97986..2a776826e5 100644 --- a/arm_compute/graph/Graph.h +++ b/arm_compute/graph/Graph.h @@ -72,7 +72,7 @@ public: * @tparam NT Node operation * @tparam Ts Arguments to operation * - * @param args Node arguments + * @param[in] args Node arguments * * @return ID of the node */ @@ -114,9 +114,11 @@ public: GraphID id() const; /** Returns graph input nodes * - * @return vector containing the graph inputs + * @param[in] type Type of nodes to return + * + * @return vector containing the graph node of given type */ - const std::vector<NodeID> &inputs(); + const std::vector<NodeID> &nodes(NodeType type); /** Returns nodes of graph * * @warning Nodes can be nullptr if they have been removed during the mutation steps of the graph @@ -238,10 +240,7 @@ inline NodeID Graph::add_node(Ts &&... args) node->set_id(nid); // Keep track of input nodes - if(node->type() == NodeType::Input) - { - _tagged_nodes[NodeType::Input].push_back(nid); - } + _tagged_nodes[node->type()].push_back(nid); // Associate a new tensor with each output for(auto &output : node->_outputs) |