aboutsummaryrefslogtreecommitdiff
path: root/src/armnnUtils/GraphTopologicalSort.hpp
diff options
context:
space:
mode:
authortelsoa01 <telmo.soares@arm.com>2018-08-31 09:22:23 +0100
committertelsoa01 <telmo.soares@arm.com>2018-08-31 09:22:23 +0100
commitc577f2c6a3b4ddb6ba87a882723c53a248afbeba (patch)
treebd7d4c148df27f8be6649d313efb24f536b7cf34 /src/armnnUtils/GraphTopologicalSort.hpp
parent4c7098bfeab1ffe1cdc77f6c15548d3e73274746 (diff)
downloadarmnn-c577f2c6a3b4ddb6ba87a882723c53a248afbeba.tar.gz
Release 18.08
Diffstat (limited to 'src/armnnUtils/GraphTopologicalSort.hpp')
-rw-r--r--src/armnnUtils/GraphTopologicalSort.hpp86
1 files changed, 63 insertions, 23 deletions
diff --git a/src/armnnUtils/GraphTopologicalSort.hpp b/src/armnnUtils/GraphTopologicalSort.hpp
index f455289567..86eb4cc030 100644
--- a/src/armnnUtils/GraphTopologicalSort.hpp
+++ b/src/armnnUtils/GraphTopologicalSort.hpp
@@ -5,11 +5,14 @@
#pragma once
#include <boost/assert.hpp>
+#include <boost/optional.hpp>
#include <functional>
#include <map>
+#include <stack>
#include <vector>
+
namespace armnnUtils
{
@@ -22,51 +25,88 @@ enum class NodeState
Visited,
};
-template<typename TNodeId>
-bool Visit(
- TNodeId current,
- std::function<std::vector<TNodeId>(TNodeId)> getIncomingEdges,
- std::vector<TNodeId>& outSorted,
- std::map<TNodeId, NodeState>& nodeStates)
+
+template <typename TNodeId>
+boost::optional<TNodeId> GetNextChild(TNodeId node,
+ std::function<std::vector<TNodeId>(TNodeId)> getIncomingEdges,
+ std::map<TNodeId, NodeState>& nodeStates)
{
- auto currentStateIt = nodeStates.find(current);
- if (currentStateIt != nodeStates.end())
+ for (TNodeId childNode : getIncomingEdges(node))
{
- if (currentStateIt->second == NodeState::Visited)
- {
- return true;
- }
- if (currentStateIt->second == NodeState::Visiting)
+ if (nodeStates.find(childNode) == nodeStates.end())
{
- return false;
+ return childNode;
}
else
{
- BOOST_ASSERT(false);
+ if (nodeStates.find(childNode)->second == NodeState::Visiting)
+ {
+ return childNode;
+ }
}
}
- nodeStates[current] = NodeState::Visiting;
+ return {};
+}
- for (TNodeId inputNode : getIncomingEdges(current))
+template<typename TNodeId>
+bool TopologicallySort(
+ TNodeId initialNode,
+ std::function<std::vector<TNodeId>(TNodeId)> getIncomingEdges,
+ std::vector<TNodeId>& outSorted,
+ std::map<TNodeId, NodeState>& nodeStates)
+{
+ std::stack<TNodeId> nodeStack;
+
+ // If the node is never visited we should search it
+ if (nodeStates.find(initialNode) == nodeStates.end())
{
- Visit(inputNode, getIncomingEdges, outSorted, nodeStates);
+ nodeStack.push(initialNode);
}
- nodeStates[current] = NodeState::Visited;
+ while (!nodeStack.empty())
+ {
+ TNodeId current = nodeStack.top();
+
+ nodeStates[current] = NodeState::Visiting;
+
+ boost::optional<TNodeId> nextChildOfCurrent = GetNextChild(current, getIncomingEdges, nodeStates);
+
+ if (nextChildOfCurrent)
+ {
+ TNodeId nextChild = nextChildOfCurrent.get();
+
+ // If the child has not been searched, add to the stack and iterate over this node
+ if (nodeStates.find(nextChild) == nodeStates.end())
+ {
+ nodeStack.push(nextChild);
+ continue;
+ }
+
+ // If we re-encounter a node being visited there is a cycle
+ if (nodeStates[nextChild] == NodeState::Visiting)
+ {
+ return false;
+ }
+ }
+
+ nodeStack.pop();
+
+ nodeStates[current] = NodeState::Visited;
+ outSorted.push_back(current);
+ }
- outSorted.push_back(current);
return true;
}
}
-// Sorts an directed acyclic graph (DAG) into a flat list such that all inputs to a node are before the node itself.
+// Sorts a directed acyclic graph (DAG) into a flat list such that all inputs to a node are before the node itself.
// Returns true if successful or false if there is an error in the graph structure (e.g. it contains a cycle).
// The graph is defined entirely by the "getIncomingEdges" function which the user provides. For a given node,
// it must return the list of nodes which are required to come before it.
// "targetNodes" is the list of nodes where the search begins - i.e. the nodes that you want to evaluate.
-// The implementation is based on https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search
+// This is an iterative implementation based on https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search
template<typename TNodeId, typename TTargetNodes>
bool GraphTopologicalSort(
const TTargetNodes& targetNodes,
@@ -78,7 +118,7 @@ bool GraphTopologicalSort(
for (TNodeId targetNode : targetNodes)
{
- if (!Visit(targetNode, getIncomingEdges, outSorted, nodeStates))
+ if (!TopologicallySort(targetNode, getIncomingEdges, outSorted, nodeStates))
{
return false;
}