ArmNN
 20.02
GraphTopologicalSort.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include <armnn/Optional.hpp>
8 #include <boost/assert.hpp>
9 
10 #include <functional>
11 #include <map>
12 #include <stack>
13 #include <vector>
14 
15 
16 namespace armnnUtils
17 {
18 
19 namespace
20 {
21 
/// Search state recorded for each node reached by the depth-first topological sort.
/// A node with no entry in the state map has not been reached yet.
///
/// NOTE(review): this enum and the helpers next to it are declared inside an unnamed namespace in a
/// header, so every translation unit gets its own copy. Consider a named detail namespace instead.
enum class NodeState
{
    Visiting, ///< Pushed on the DFS stack and not finished yet; its inputs are still being explored.
    Visited   ///< Finished: the node has already been appended to the sorted output.
};
27 
28 
29 template <typename TNodeId>
30 armnn::Optional<TNodeId> GetNextChild(TNodeId node,
31  std::function<std::vector<TNodeId>(TNodeId)> getIncomingEdges,
32  std::map<TNodeId, NodeState>& nodeStates)
33 {
34  for (TNodeId childNode : getIncomingEdges(node))
35  {
36  if (nodeStates.find(childNode) == nodeStates.end())
37  {
38  return childNode;
39  }
40  else
41  {
42  if (nodeStates.find(childNode)->second == NodeState::Visiting)
43  {
44  return childNode;
45  }
46  }
47  }
48 
49  return {};
50 }
51 
52 template<typename TNodeId>
53 bool TopologicallySort(
54  TNodeId initialNode,
55  std::function<std::vector<TNodeId>(TNodeId)> getIncomingEdges,
56  std::vector<TNodeId>& outSorted,
57  std::map<TNodeId, NodeState>& nodeStates)
58 {
59  std::stack<TNodeId> nodeStack;
60 
61  // If the node is never visited we should search it
62  if (nodeStates.find(initialNode) == nodeStates.end())
63  {
64  nodeStack.push(initialNode);
65  }
66 
67  while (!nodeStack.empty())
68  {
69  TNodeId current = nodeStack.top();
70 
71  nodeStates[current] = NodeState::Visiting;
72 
73  auto nextChildOfCurrent = GetNextChild(current, getIncomingEdges, nodeStates);
74 
75  if (nextChildOfCurrent)
76  {
77  TNodeId nextChild = nextChildOfCurrent.value();
78 
79  // If the child has not been searched, add to the stack and iterate over this node
80  if (nodeStates.find(nextChild) == nodeStates.end())
81  {
82  nodeStack.push(nextChild);
83  continue;
84  }
85 
86  // If we re-encounter a node being visited there is a cycle
87  if (nodeStates[nextChild] == NodeState::Visiting)
88  {
89  return false;
90  }
91  }
92 
93  nodeStack.pop();
94 
95  nodeStates[current] = NodeState::Visited;
96  outSorted.push_back(current);
97  }
98 
99  return true;
100 }
101 
102 }
103 
104 // Sorts a directed acyclic graph (DAG) into a flat list such that all inputs to a node are before the node itself.
105 // Returns true if successful or false if there is an error in the graph structure (e.g. it contains a cycle).
106 // The graph is defined entirely by the "getIncomingEdges" function which the user provides. For a given node,
107 // it must return the list of nodes which are required to come before it.
108 // "targetNodes" is the list of nodes where the search begins - i.e. the nodes that you want to evaluate.
109 // This is an iterative implementation based on https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search
110 template<typename TNodeId, typename TTargetNodes>
112  const TTargetNodes& targetNodes,
113  std::function<std::vector<TNodeId>(TNodeId)> getIncomingEdges,
114  std::vector<TNodeId>& outSorted)
115 {
116  outSorted.clear();
117  std::map<TNodeId, NodeState> nodeStates;
118 
119  for (TNodeId targetNode : targetNodes)
120  {
121  if (!TopologicallySort(targetNode, getIncomingEdges, outSorted, nodeStates))
122  {
123  return false;
124  }
125  }
126 
127  return true;
128 }
129 
130 }
bool GraphTopologicalSort(const TTargetNodes &targetNodes, std::function< std::vector< TNodeId >(TNodeId)> getIncomingEdges, std::vector< TNodeId > &outSorted)