//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//
#pragma once

#include <functional>
#include <map>
#include <vector>

namespace armnnUtils
{

namespace
{

// Depth-first-search marker for a node.
// A node with no entry in the state map has not been reached yet.
enum class NodeState
{
    Visiting, // The node is on the current DFS path; its inputs are still being processed.
    Visited,  // The node and all of its inputs have already been appended to the output.
};

// Recursive depth-first visit of "current" and everything it depends on.
//
// current          - The node to process.
// getIncomingEdges - Returns the nodes that must be ordered before the node it is given.
// outSorted        - Nodes are appended here once all of their inputs have been appended.
// nodeStates       - DFS state shared across the whole sort; updated in place.
//
// Returns true if "current" and all of its transitive inputs were ordered successfully,
// or false as soon as a cycle is found (a node is reached again while still being visited).
template<typename TNodeId>
bool Visit(
    TNodeId current,
    const std::function<std::vector<TNodeId>(TNodeId)>& getIncomingEdges,
    std::vector<TNodeId>& outSorted,
    std::map<TNodeId, NodeState>& nodeStates)
{
    // Mark the node as in-progress, or fetch its existing state, with a single lookup.
    const auto insertion = nodeStates.emplace(current, NodeState::Visiting);
    const auto stateIt = insertion.first;
    if (!insertion.second)
    {
        // Already known: Visited means it has been emitted (nothing to do),
        // Visiting means we have looped back onto the current DFS path, i.e. a cycle.
        return stateIt->second == NodeState::Visited;
    }

    for (const TNodeId& inputNode : getIncomingEdges(current))
    {
        // A failure anywhere below must abort the whole sort, so propagate it upwards.
        if (!Visit(inputNode, getIncomingEdges, outSorted, nodeStates))
        {
            return false;
        }
    }

    // std::map iterators stay valid across the insertions made by the recursive calls.
    stateIt->second = NodeState::Visited;
    outSorted.push_back(current);
    return true;
}

}

// Sorts a directed acyclic graph (DAG) into a flat list such that all inputs to a node are before the node itself.
// Returns true if successful or false if there is an error in the graph structure (e.g. it contains a cycle).
// The graph is defined entirely by the "getIncomingEdges" function which the user provides. For a given node,
// it must return the list of nodes which are required to come before it.
// "targetNodes" is the list of nodes where the search begins - i.e. the nodes that you want to evaluate.
// "outSorted" is cleared first and then filled with the result. If false is returned it only holds the nodes
// that were ordered before the error was found and must not be treated as a complete ordering.
// The implementation is based on https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search
template<typename TNodeId, typename TTargetNodes>
bool GraphTopologicalSort(
    const TTargetNodes& targetNodes,
    std::function<std::vector<TNodeId>(TNodeId)> getIncomingEdges,
    std::vector<TNodeId>& outSorted)
{
    outSorted.clear();
    std::map<TNodeId, NodeState> nodeStates;

    for (TNodeId targetNode : targetNodes)
    {
        if (!Visit(targetNode, getIncomingEdges, outSorted, nodeStates))
        {
            return false;
        }
    }

    return true;
}

}