aboutsummaryrefslogtreecommitdiff
path: root/src/armnnUtils/GraphTopologicalSort.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnUtils/GraphTopologicalSort.hpp')
-rw-r--r--src/armnnUtils/GraphTopologicalSort.hpp90
1 files changed, 90 insertions, 0 deletions
diff --git a/src/armnnUtils/GraphTopologicalSort.hpp b/src/armnnUtils/GraphTopologicalSort.hpp
new file mode 100644
index 0000000000..f455289567
--- /dev/null
+++ b/src/armnnUtils/GraphTopologicalSort.hpp
@@ -0,0 +1,90 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+#pragma once
+
+#include <boost/assert.hpp>
+
+#include <functional>
+#include <map>
+#include <vector>
+
+namespace armnnUtils
+{
+
+namespace
+{
+
+enum class NodeState
+{
+ Visiting,
+ Visited,
+};
+
+template<typename TNodeId>
+bool Visit(
+ TNodeId current,
+ std::function<std::vector<TNodeId>(TNodeId)> getIncomingEdges,
+ std::vector<TNodeId>& outSorted,
+ std::map<TNodeId, NodeState>& nodeStates)
+{
+ auto currentStateIt = nodeStates.find(current);
+ if (currentStateIt != nodeStates.end())
+ {
+ if (currentStateIt->second == NodeState::Visited)
+ {
+ return true;
+ }
+ if (currentStateIt->second == NodeState::Visiting)
+ {
+ return false;
+ }
+ else
+ {
+ BOOST_ASSERT(false);
+ }
+ }
+
+ nodeStates[current] = NodeState::Visiting;
+
+ for (TNodeId inputNode : getIncomingEdges(current))
+ {
+ Visit(inputNode, getIncomingEdges, outSorted, nodeStates);
+ }
+
+ nodeStates[current] = NodeState::Visited;
+
+ outSorted.push_back(current);
+ return true;
+}
+
+}
+
+// Sorts an directed acyclic graph (DAG) into a flat list such that all inputs to a node are before the node itself.
+// Returns true if successful or false if there is an error in the graph structure (e.g. it contains a cycle).
+// The graph is defined entirely by the "getIncomingEdges" function which the user provides. For a given node,
+// it must return the list of nodes which are required to come before it.
+// "targetNodes" is the list of nodes where the search begins - i.e. the nodes that you want to evaluate.
+// The implementation is based on https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search
+template<typename TNodeId, typename TTargetNodes>
+bool GraphTopologicalSort(
+ const TTargetNodes& targetNodes,
+ std::function<std::vector<TNodeId>(TNodeId)> getIncomingEdges,
+ std::vector<TNodeId>& outSorted)
+{
+ outSorted.clear();
+ std::map<TNodeId, NodeState> nodeStates;
+
+ for (TNodeId targetNode : targetNodes)
+ {
+ if (!Visit(targetNode, getIncomingEdges, outSorted, nodeStates))
+ {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+} \ No newline at end of file