1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
|
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
#include <armnn/Optional.hpp>
#include <functional>
#include <map>
#include <stack>
#include <vector>
namespace armnnUtils
{
namespace
{
/// Progress of the depth-first search for one node. A node that has no entry in the state map
/// has not been reached by any search yet.
enum class NodeState
{
    /// The node has been entered and is on the current search path, but it has not been emitted yet.
    Visiting = 0,
    /// The node has been emitted to the sorted output; it never needs to be searched again.
    Visited = 1
};
/// Finds the next predecessor of @p node that the depth-first search has to deal with.
///
/// Scans the nodes returned by @p getIncomingEdges for @p node, in that order. It returns the first one that is
/// either not yet present in @p nodeStates (it still has to be searched) or is marked NodeState::Visiting
/// (it is on the current search path, so reaching it again indicates a cycle). Predecessors already marked
/// NodeState::Visited are skipped.
///
/// @param node             Node whose predecessors are examined.
/// @param getIncomingEdges Callback returning the predecessors of a node. Taken by const reference so the
///                         std::function (and any state it owns) is not copied on every call.
/// @param nodeStates       Search state of the nodes seen so far; only read here.
/// @return The first unvisited or in-progress predecessor, or an empty Optional when every predecessor has
///         already been visited (or @p node has none).
template <typename TNodeId>
armnn::Optional<TNodeId> GetNextChild(TNodeId node,
                                      const std::function<std::vector<TNodeId>(TNodeId)>& getIncomingEdges,
                                      const std::map<TNodeId, NodeState>& nodeStates)
{
    for (TNodeId childNode : getIncomingEdges(node))
    {
        // One lookup answers both questions: is the child unseen, or is it still on the search path?
        const auto stateIt = nodeStates.find(childNode);
        if (stateIt == nodeStates.end() || stateIt->second == NodeState::Visiting)
        {
            return childNode;
        }
    }
    return {};
}
/// Iterative depth-first search from @p initialNode. Each newly reached node is appended to @p outSorted only
/// after every node returned for it by @p getIncomingEdges has been appended (post-order), so inputs always
/// come before the nodes that use them.
///
/// @param initialNode      Node to start from. If it already has an entry in @p nodeStates (e.g. it was reached
///                         by an earlier call sharing the same map), nothing is done.
/// @param getIncomingEdges Callback returning the predecessors of a node. It is called once per newly searched
///                         node, so it is expected to describe a fixed graph.
/// @param outSorted        Newly finished nodes are appended; existing contents are kept.
/// @param nodeStates       Search state shared between calls; nodes are marked Visiting when entered and
///                         Visited when appended to @p outSorted.
/// @return false if a cycle is reachable from @p initialNode (then @p outSorted and @p nodeStates contain a
///         partial result), true otherwise.
template<typename TNodeId>
bool TopologicallySort(
    TNodeId initialNode,
    const std::function<std::vector<TNodeId>(TNodeId)>& getIncomingEdges,
    std::vector<TNodeId>& outSorted,
    std::map<TNodeId, NodeState>& nodeStates)
{
    // A node that was already reached by a previous search is finished; do not search it again.
    if (nodeStates.find(initialNode) != nodeStates.end())
    {
        return true;
    }

    // One frame per node on the current search path. Each frame caches the node's predecessor list and how many
    // of them have been handled. Each edge list is therefore fetched and scanned once, instead of being
    // re-fetched and re-scanned from the start every time the search returns to the node. That re-scan made the
    // previous version quadratic in the in-degree.
    struct Frame
    {
        TNodeId node;
        std::vector<TNodeId> children;
        typename std::vector<TNodeId>::size_type nextChild;
    };
    std::vector<Frame> path;

    // Marks a node as being on the search path and pushes its frame.
    const auto enter = [&](TNodeId node)
    {
        nodeStates[node] = NodeState::Visiting;
        path.push_back(Frame{node, getIncomingEdges(node), 0});
    };

    enter(initialNode);
    while (!path.empty())
    {
        Frame& top = path.back();
        if (top.nextChild < top.children.size())
        {
            // Copy the child out before enter() may reallocate 'path' and invalidate 'top'.
            TNodeId child = top.children[top.nextChild++];
            const auto stateIt = nodeStates.find(child);
            if (stateIt == nodeStates.end())
            {
                // Unseen predecessor: it must be finished before the current node can be emitted.
                enter(child);
            }
            else if (stateIt->second == NodeState::Visiting)
            {
                // The predecessor is on the current search path, so the graph contains a cycle.
                return false;
            }
            // Otherwise the predecessor is already Visited and needs no further work.
            continue;
        }

        // Every predecessor has been emitted, so this node can be emitted too.
        nodeStates[top.node] = NodeState::Visited;
        outSorted.push_back(top.node);
        path.pop_back();
    }
    return true;
}
}
// Sorts a directed acyclic graph (DAG) into a flat list such that all inputs to a node come before the node itself.
// The graph is described entirely by the user-supplied "getIncomingEdges" callback: for a given node it must return
// the list of nodes that are required to come before it.
// "targetNodes" lists the nodes the search starts from, i.e. the nodes you want to evaluate. Only nodes reachable
// from them by following incoming edges end up in "outSorted".
// "outSorted" is cleared first. A node reachable from several targets is still emitted only once.
// Returns true on success, or false if a cycle is reachable from one of the targets. In that case "outSorted" holds
// only a partial result.
// This is an iterative implementation based on https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search
template<typename TNodeId, typename TTargetNodes>
bool GraphTopologicalSort(
    const TTargetNodes& targetNodes,
    std::function<std::vector<TNodeId>(TNodeId)> getIncomingEdges,
    std::vector<TNodeId>& outSorted)
{
    outSorted.clear();

    // A single state map for the whole sort, so nodes finished for one target are skipped for the next ones.
    std::map<TNodeId, NodeState> visitState;

    bool acyclic = true;
    for (TNodeId root : targetNodes)
    {
        acyclic = TopologicallySort(root, getIncomingEdges, outSorted, visitState);
        if (!acyclic)
        {
            // Stop at the first cycle found; later targets are not searched.
            break;
        }
    }
    return acyclic;
}
}
|