From c577f2c6a3b4ddb6ba87a882723c53a248afbeba Mon Sep 17 00:00:00 2001 From: telsoa01 Date: Fri, 31 Aug 2018 09:22:23 +0100 Subject: Release 18.08 --- src/armnnUtils/GraphTopologicalSort.hpp | 86 ++++++++++++++++++++++++--------- 1 file changed, 63 insertions(+), 23 deletions(-) (limited to 'src/armnnUtils/GraphTopologicalSort.hpp') diff --git a/src/armnnUtils/GraphTopologicalSort.hpp b/src/armnnUtils/GraphTopologicalSort.hpp index f455289567..86eb4cc030 100644 --- a/src/armnnUtils/GraphTopologicalSort.hpp +++ b/src/armnnUtils/GraphTopologicalSort.hpp @@ -5,11 +5,14 @@ #pragma once #include +#include #include #include +#include #include + namespace armnnUtils { @@ -22,51 +25,88 @@ enum class NodeState Visited, }; -template -bool Visit( - TNodeId current, - std::function(TNodeId)> getIncomingEdges, - std::vector& outSorted, - std::map& nodeStates) + +template +boost::optional GetNextChild(TNodeId node, + std::function(TNodeId)> getIncomingEdges, + std::map& nodeStates) { - auto currentStateIt = nodeStates.find(current); - if (currentStateIt != nodeStates.end()) + for (TNodeId childNode : getIncomingEdges(node)) { - if (currentStateIt->second == NodeState::Visited) - { - return true; - } - if (currentStateIt->second == NodeState::Visiting) + if (nodeStates.find(childNode) == nodeStates.end()) { - return false; + return childNode; } else { - BOOST_ASSERT(false); + if (nodeStates.find(childNode)->second == NodeState::Visiting) + { + return childNode; + } } } - nodeStates[current] = NodeState::Visiting; + return {}; +} - for (TNodeId inputNode : getIncomingEdges(current)) +template +bool TopologicallySort( + TNodeId initialNode, + std::function(TNodeId)> getIncomingEdges, + std::vector& outSorted, + std::map& nodeStates) +{ + std::stack nodeStack; + + // If the node is never visited we should search it + if (nodeStates.find(initialNode) == nodeStates.end()) { - Visit(inputNode, getIncomingEdges, outSorted, nodeStates); + nodeStack.push(initialNode); } - nodeStates[current] = NodeState::Visited; + while (!nodeStack.empty()) + { + TNodeId current = nodeStack.top(); + + nodeStates[current] = NodeState::Visiting; + + boost::optional nextChildOfCurrent = GetNextChild(current, getIncomingEdges, nodeStates); + + if (nextChildOfCurrent) + { + TNodeId nextChild = nextChildOfCurrent.get(); + + // If the child has not been searched, add to the stack and iterate over this node + if (nodeStates.find(nextChild) == nodeStates.end()) + { + nodeStack.push(nextChild); + continue; + } + + // If we re-encounter a node being visited there is a cycle + if (nodeStates[nextChild] == NodeState::Visiting) + { + return false; + } + } + + nodeStack.pop(); + + nodeStates[current] = NodeState::Visited; + outSorted.push_back(current); + } - outSorted.push_back(current); return true; } } -// Sorts an directed acyclic graph (DAG) into a flat list such that all inputs to a node are before the node itself. +// Sorts a directed acyclic graph (DAG) into a flat list such that all inputs to a node are before the node itself. // Returns true if successful or false if there is an error in the graph structure (e.g. it contains a cycle). // The graph is defined entirely by the "getIncomingEdges" function which the user provides. For a given node, // it must return the list of nodes which are required to come before it. // "targetNodes" is the list of nodes where the search begins - i.e. the nodes that you want to evaluate. -// The implementation is based on https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search +// This is an iterative implementation based on https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search template bool GraphTopologicalSort( const TTargetNodes& targetNodes, @@ -78,7 +118,7 @@ bool GraphTopologicalSort( for (TNodeId targetNode : targetNodes) { - if (!Visit(targetNode, getIncomingEdges, outSorted, nodeStates)) + if (!TopologicallySort(targetNode, getIncomingEdges, outSorted, nodeStates)) { return false; } -- cgit v1.2.1