diff options
Diffstat (limited to 'src/armnnUtils/GraphTopologicalSort.hpp')
-rw-r--r-- | src/armnnUtils/GraphTopologicalSort.hpp | 90 |
1 files changed, 90 insertions, 0 deletions
diff --git a/src/armnnUtils/GraphTopologicalSort.hpp b/src/armnnUtils/GraphTopologicalSort.hpp new file mode 100644 index 0000000000..f455289567 --- /dev/null +++ b/src/armnnUtils/GraphTopologicalSort.hpp @@ -0,0 +1,90 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// +#pragma once + +#include <boost/assert.hpp> + +#include <functional> +#include <map> +#include <vector> + +namespace armnnUtils +{ + +namespace +{ + +enum class NodeState +{ + Visiting, + Visited, +}; + +template<typename TNodeId> +bool Visit( + TNodeId current, + std::function<std::vector<TNodeId>(TNodeId)> getIncomingEdges, + std::vector<TNodeId>& outSorted, + std::map<TNodeId, NodeState>& nodeStates) +{ + auto currentStateIt = nodeStates.find(current); + if (currentStateIt != nodeStates.end()) + { + if (currentStateIt->second == NodeState::Visited) + { + return true; + } + if (currentStateIt->second == NodeState::Visiting) + { + return false; + } + else + { + BOOST_ASSERT(false); + } + } + + nodeStates[current] = NodeState::Visiting; + + for (TNodeId inputNode : getIncomingEdges(current)) + { + Visit(inputNode, getIncomingEdges, outSorted, nodeStates); + } + + nodeStates[current] = NodeState::Visited; + + outSorted.push_back(current); + return true; +} + +} + +// Sorts an directed acyclic graph (DAG) into a flat list such that all inputs to a node are before the node itself. +// Returns true if successful or false if there is an error in the graph structure (e.g. it contains a cycle). +// The graph is defined entirely by the "getIncomingEdges" function which the user provides. For a given node, +// it must return the list of nodes which are required to come before it. +// "targetNodes" is the list of nodes where the search begins - i.e. the nodes that you want to evaluate. +// The implementation is based on https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search +template<typename TNodeId, typename TTargetNodes> +bool GraphTopologicalSort( + const TTargetNodes& targetNodes, + std::function<std::vector<TNodeId>(TNodeId)> getIncomingEdges, + std::vector<TNodeId>& outSorted) +{ + outSorted.clear(); + std::map<TNodeId, NodeState> nodeStates; + + for (TNodeId targetNode : targetNodes) + { + if (!Visit(targetNode, getIncomingEdges, outSorted, nodeStates)) + { + return false; + } + } + + return true; +} + +}
\ No newline at end of file |