diff options
Diffstat (limited to 'src/armnnUtils/GraphTopologicalSort.hpp')
-rw-r--r-- | src/armnnUtils/GraphTopologicalSort.hpp | 86 |
1 files changed, 63 insertions, 23 deletions
diff --git a/src/armnnUtils/GraphTopologicalSort.hpp b/src/armnnUtils/GraphTopologicalSort.hpp index f455289567..86eb4cc030 100644 --- a/src/armnnUtils/GraphTopologicalSort.hpp +++ b/src/armnnUtils/GraphTopologicalSort.hpp @@ -5,11 +5,14 @@ #pragma once #include <boost/assert.hpp> +#include <boost/optional.hpp> #include <functional> #include <map> +#include <stack> #include <vector> + namespace armnnUtils { @@ -22,51 +25,88 @@ enum class NodeState Visited, }; -template<typename TNodeId> -bool Visit( - TNodeId current, - std::function<std::vector<TNodeId>(TNodeId)> getIncomingEdges, - std::vector<TNodeId>& outSorted, - std::map<TNodeId, NodeState>& nodeStates) + +template <typename TNodeId> +boost::optional<TNodeId> GetNextChild(TNodeId node, + std::function<std::vector<TNodeId>(TNodeId)> getIncomingEdges, + std::map<TNodeId, NodeState>& nodeStates) { - auto currentStateIt = nodeStates.find(current); - if (currentStateIt != nodeStates.end()) + for (TNodeId childNode : getIncomingEdges(node)) { - if (currentStateIt->second == NodeState::Visited) - { - return true; - } - if (currentStateIt->second == NodeState::Visiting) + if (nodeStates.find(childNode) == nodeStates.end()) { - return false; + return childNode; } else { - BOOST_ASSERT(false); + if (nodeStates.find(childNode)->second == NodeState::Visiting) + { + return childNode; + } } } - nodeStates[current] = NodeState::Visiting; + return {}; +} - for (TNodeId inputNode : getIncomingEdges(current)) +template<typename TNodeId> +bool TopologicallySort( + TNodeId initialNode, + std::function<std::vector<TNodeId>(TNodeId)> getIncomingEdges, + std::vector<TNodeId>& outSorted, + std::map<TNodeId, NodeState>& nodeStates) +{ + std::stack<TNodeId> nodeStack; + + // If the node is never visited we should search it + if (nodeStates.find(initialNode) == nodeStates.end()) { - Visit(inputNode, getIncomingEdges, outSorted, nodeStates); + nodeStack.push(initialNode); } - nodeStates[current] = NodeState::Visited; + while (!nodeStack.empty()) + { + TNodeId current = nodeStack.top(); + + nodeStates[current] = NodeState::Visiting; + + boost::optional<TNodeId> nextChildOfCurrent = GetNextChild(current, getIncomingEdges, nodeStates); + + if (nextChildOfCurrent) + { + TNodeId nextChild = nextChildOfCurrent.get(); + + // If the child has not been searched, add to the stack and iterate over this node + if (nodeStates.find(nextChild) == nodeStates.end()) + { + nodeStack.push(nextChild); + continue; + } + + // If we re-encounter a node being visited there is a cycle + if (nodeStates[nextChild] == NodeState::Visiting) + { + return false; + } + } + + nodeStack.pop(); + + nodeStates[current] = NodeState::Visited; + outSorted.push_back(current); + } - outSorted.push_back(current); return true; } } -// Sorts an directed acyclic graph (DAG) into a flat list such that all inputs to a node are before the node itself. +// Sorts a directed acyclic graph (DAG) into a flat list such that all inputs to a node are before the node itself. // Returns true if successful or false if there is an error in the graph structure (e.g. it contains a cycle). // The graph is defined entirely by the "getIncomingEdges" function which the user provides. For a given node, // it must return the list of nodes which are required to come before it. // "targetNodes" is the list of nodes where the search begins - i.e. the nodes that you want to evaluate. -// The implementation is based on https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search +// This is an iterative implementation based on https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search template<typename TNodeId, typename TTargetNodes> bool GraphTopologicalSort( const TTargetNodes& targetNodes, @@ -78,7 +118,7 @@ bool GraphTopologicalSort( for (TNodeId targetNode : targetNodes) { - if (!Visit(targetNode, getIncomingEdges, outSorted, nodeStates)) + if (!TopologicallySort(targetNode, getIncomingEdges, outSorted, nodeStates)) { return false; } |