diff options
Diffstat (limited to 'src/armnn/SubgraphViewSelector.cpp')
-rw-r--r-- | src/armnn/SubgraphViewSelector.cpp | 297 |
1 files changed, 229 insertions, 68 deletions
diff --git a/src/armnn/SubgraphViewSelector.cpp b/src/armnn/SubgraphViewSelector.cpp index 4357ec4381..8798b7285d 100644 --- a/src/armnn/SubgraphViewSelector.cpp +++ b/src/armnn/SubgraphViewSelector.cpp @@ -9,6 +9,7 @@ #include <algorithm> #include <map> #include <queue> +#include <unordered_set> namespace armnn { @@ -16,28 +17,170 @@ namespace armnn namespace { +/// Intermediate data-structure to store the subgraph that a layer has been assigned to. +/// This is a "disjoint set" data structure that allows efficient merging of subgraphs, +/// which is a key part of the algorithm. Subgraphs are arranged in singly-linked trees +/// (with each node storing a pointer to its parent). Subgraphs in the same tree are considered +/// to have been merged. Merging subgraphs is performed by attaching one tree to another, +/// which is a simple pointer update. +/// +/// NOTE: Due to the way this is stored, it is almost never correct to directly compare pointers +/// to two PartialSubgraphs to check if two layers belong in the same subgraph. Instead you +/// should use IsMergedWith(). +/// +/// This structure also stores information about the dependencies of each subgraph, which is needed +/// to determine whether certain subgraphs can be merged. Checking whether a subgraph +/// depends on another subgraph is a frequent operation in the algorithm (see AssignSplitId) and so this is optimized +/// in preference to the merging of subgraphs. This leads to an approach where each subgraph stores +/// a set of all the subgraphs it depends on (for a fast lookup). In order to efficiently update this +/// set as subgraphs are merged means we also store a set of subgraphs which *depend on us* (i.e. the +/// complement of our dependencies). +class PartialSubgraph +{ +public: + /// If this subgraph has been merged with another then there is an agreed "representative" for the combined + /// subgraph, which uniquely identifies the subgraph. + PartialSubgraph* GetRepresentative() + { + // Recurse up the tree to find the root node. + if (m_Parent == nullptr) + { + return this; + } + else + { + PartialSubgraph* result = m_Parent->GetRepresentative(); + // Update our parent pointer to point directly to the root in order to speed up future calls to this method. + // This essentially "flattens" the tree. + m_Parent = result; + return result; + } + } + + /// Merges this subgraph with another. + void MergeWith(PartialSubgraph* other) + { + if (m_Parent == nullptr) + { + other = other->GetRepresentative(); + if (this == other) + { + // Already merged - no-op + return; + } + m_Parent = other; + + // Update others' dependency sets to point to the new representative rather than us. + // Keeping these up-to-date means we can rely on these sets containing representatives when + // we perform a lookup in HasAntecedent() and so don't need to resolve the representative for each element + // of the set. See description at the top of this class for more rationale. + for (PartialSubgraph* a : m_Antecedents) + { + size_t numErased = a->m_Dependants.erase(this); + BOOST_ASSERT(numErased == 1); + boost::ignore_unused(numErased); + a->m_Dependants.insert(m_Parent); + } + for (PartialSubgraph* a : m_Dependants) + { + size_t numErased = a->m_Antecedents.erase(this); + BOOST_ASSERT(numErased == 1); + boost::ignore_unused(numErased); + a->m_Antecedents.insert(m_Parent); + } + + // Merge our dependency sets into our new representative. + // We no longer need to maintain our own sets, as requests will always be forwarded to the representative. + m_Parent->m_Antecedents.insert(m_Antecedents.begin(), m_Antecedents.end()); + m_Antecedents.clear(); + m_Parent->m_Dependants.insert(m_Dependants.begin(), m_Dependants.end()); + m_Dependants.clear(); + } + else + { + // Defer request to the representative + GetRepresentative()->MergeWith(other); + } + } + + /// Checks if this subgraph has been merged with the given subgraph. + bool IsMergedWith(PartialSubgraph* other) + { + return GetRepresentative() == other->GetRepresentative(); + } + + /// Marks the given subgraph as a direct antecedent (dependency) of this one. + void AddDirectAntecedent(PartialSubgraph* antecedent) + { + if (m_Parent == nullptr) + { + antecedent = antecedent->GetRepresentative(); + + m_Antecedents.insert(antecedent); + // Also record all of its antecedents, so that we end up with direct and indirect antecedents. + // This makes the lookup in HasAntecedent() faster. + m_Antecedents.insert(antecedent->m_Antecedents.begin(), antecedent->m_Antecedents.end()); + // All of our dependents also need to include the new antecedents + for (PartialSubgraph* d : m_Dependants) + { + d->m_Antecedents.insert(antecedent); + d->m_Antecedents.insert(antecedent->m_Antecedents.begin(), antecedent->m_Antecedents.end()); + } + + // Store reverse dependencies as well, required so that we can efficiently navigate the graph + // when making updates. + antecedent->m_Dependants.insert(this); + antecedent->m_Dependants.insert(m_Dependants.begin(), m_Dependants.end()); + for (PartialSubgraph* a : antecedent->m_Antecedents) + { + a->m_Dependants.insert(this); + a->m_Dependants.insert(m_Dependants.begin(), m_Dependants.end()); + } + } + else + { + // Defer request to the representative + GetRepresentative()->AddDirectAntecedent(antecedent); + } + } + + /// Checks if this subgraph is dependent on the given subgraph, either directly or indirectly. + bool HasAntecedent(PartialSubgraph* antecedent) + { + if (m_Parent == nullptr) + { + antecedent = antecedent->GetRepresentative(); + // Thanks to keeping this set updated in MergeWith and AddDirectAntecedent, we can do an efficient lookup. + return m_Antecedents.count(antecedent) > 0; + } + else + { + // Defer request to the representative + return GetRepresentative()->HasAntecedent(antecedent); + } + } + +private: + /// Pointer to the parent node in the tree. If this is null then we are the representative for our merged subgraph. + PartialSubgraph* m_Parent; + /// The representatives of all the subgraphs which we depend on, either directly or indirectly. + std::unordered_set<PartialSubgraph*> m_Antecedents; + /// The representatives of all the subgraphs which depend on us, either directly or indirectly. + std::unordered_set<PartialSubgraph*> m_Dependants; +}; + +/// Intermediate data structure to store information associated with a particular layer. struct LayerSelectionInfo { - using SplitId = uint32_t; using LayerInfoContainer = std::map<Layer*, LayerSelectionInfo>; using LayerInfoQueue = std::queue<LayerSelectionInfo*>; - static constexpr uint32_t InitialSplitId() { return 1; } LayerSelectionInfo(Layer* layer, const SubgraphViewSelector::LayerSelectorFunction& selector) : m_Layer{layer} - , m_SplitId{0} + , m_Subgraph{nullptr} , m_IsSelected{selector(*layer)} , m_IsProcessed(false) { - // fill topology information by storing direct children - for (auto&& slot = m_Layer->BeginOutputSlots(); slot != m_Layer->EndOutputSlots(); ++slot) - { - for (InputSlot* childLayerInputSlot : slot->GetConnections()) - { - Layer& childLayer = childLayerInputSlot->GetOwningLayer(); - m_DirectChildren.push_back(&childLayer); - } - } } bool IsInputLayer() const @@ -57,7 +200,7 @@ struct LayerSelectionInfo Layer& parentLayer = parentLayerOutputSlot->GetOwningLayer(); auto parentInfo = layerInfos.find(&parentLayer); if (parentInfo == layerInfos.end() || - m_SplitId != parentInfo->second.m_SplitId) + !m_Subgraph->IsMergedWith(parentInfo->second.m_Subgraph.get())) { // Avoid collecting duplicate input slots InputSlot* inputSlot = &(*slot); @@ -80,7 +223,7 @@ struct LayerSelectionInfo Layer& childLayer = childLayerInputSlot->GetOwningLayer(); auto childInfo = layerInfos.find(&childLayer); if (childInfo == layerInfos.end() || - m_SplitId != childInfo->second.m_SplitId) + !m_Subgraph->IsMergedWith(childInfo->second.m_Subgraph.get())) { // Avoid collecting duplicate output slots OutputSlot* outputSlot = &(*slot); @@ -93,9 +236,11 @@ struct LayerSelectionInfo } } - std::vector<Layer*> m_DirectChildren; Layer* m_Layer; - SplitId m_SplitId; + /// Which subgraph this layer has been assigned to. Only valid once m_IsProcessed is true. + /// Two layers with different m_Subgraph pointers may in fact have been merged into the same subgraph - + /// see the description of the PartialSubgraph class. + std::shared_ptr<PartialSubgraph> m_Subgraph; bool m_IsSelected; bool m_IsProcessed; }; @@ -155,48 +300,67 @@ void ForEachLayerOutput(LayerSelectionInfo::LayerInfoContainer& layerInfos, void AssignSplitId(LayerSelectionInfo::LayerInfoContainer& layerInfos, LayerSelectionInfo& layerInfo) { - bool newSplit = false; - LayerSelectionInfo::SplitId minSplitId = std::numeric_limits<LayerSelectionInfo::SplitId>::max(); - LayerSelectionInfo::SplitId maxSplitId = std::numeric_limits<LayerSelectionInfo::SplitId>::lowest(); - LayerSelectionInfo::SplitId maxSelectableId = std::numeric_limits<LayerSelectionInfo::SplitId>::lowest(); - - ForEachLayerInput(layerInfos, layerInfo, [&newSplit, &minSplitId, &maxSplitId, &maxSelectableId, &layerInfo]( - LayerSelectionInfo& parentInfo) + // Check each input to see if we can attach ourselves to any of the subgraphs that have already been assigned. + ForEachLayerInput(layerInfos, layerInfo, [&](LayerSelectionInfo& parentInfo) + { + // We can only attach ourselves to the subgraph from this input if there isn't a cut here. + if (layerInfo.m_IsSelected == parentInfo.m_IsSelected) { - minSplitId = std::min(minSplitId, parentInfo.m_SplitId); - maxSplitId = std::max(maxSplitId, parentInfo.m_SplitId); - if (parentInfo.m_IsSelected && layerInfo.m_IsSelected) + // We also need to check that merging into this subgraph won't cause a dependency cycle between subgraphs. + // This will be the case if the subgraph that we will become part of is already a dependency + // of one of the subgraphs that are input to this layer, e.g: + // + // 0 | The numbers (0, 1) are the subgraph IDs of each layer and we are looking at layer X. + // / \ | + // 1 0 | We can't merge X into subgraph 0, because the left-hand input already depends on subgraph 0. + // \ / | We can however merge X into subgraph 1. + // X | + // + bool dependenciesOk = true; + ForEachLayerInput(layerInfos, layerInfo, [&](LayerSelectionInfo& otherParentInfo) { - maxSelectableId = std::max(maxSelectableId, parentInfo.m_SplitId); - } + // We call HasAntecedent() ~ n^2 times, where n is the number of inputs to this layer. + // Hence it is important that this is efficient - see PartialSubgraph class description. + if (otherParentInfo.m_Subgraph->HasAntecedent(parentInfo.m_Subgraph.get())) + { + dependenciesOk = false; + } + }); - if (layerInfo.m_IsSelected != parentInfo.m_IsSelected) + if (dependenciesOk) { - newSplit = true; + // Merge into the subgraph of this input. If we have already been merged into another subgraph + // (from another input of this layer), then merge both of them together. + if (layerInfo.m_Subgraph == nullptr) + { + layerInfo.m_Subgraph = parentInfo.m_Subgraph; + } + else + { + // We call MergeWith() ~ n times, where n is the number of inputs to this layer. + // Therefore it does not need to be as performant as HasAntecedent(). + layerInfo.m_Subgraph->MergeWith(parentInfo.m_Subgraph.get()); + } } + } + }); - }); + // If we weren't able to merge into an existing subgraph then we need to make a new one + if (layerInfo.m_Subgraph == nullptr) + { + layerInfo.m_Subgraph = std::make_shared<PartialSubgraph>(); + } - // Assign the split Id for the current layerInfo - if (newSplit) + // Record dependencies of the chosen subgraph based on the inputs of this layer. + ForEachLayerInput(layerInfos, layerInfo, [&](LayerSelectionInfo& parentInfo) { - if (maxSelectableId > minSplitId) + // These functions are called ~n times, where n is the number of inputs to this layer. + // Therefore it does not need to be as performant as HasAntecedent(). + if (!layerInfo.m_Subgraph->IsMergedWith(parentInfo.m_Subgraph.get())) { - // We can be overly aggressive when choosing to create a new split so - // here we determine if one of the parent branches are suitable candidates for continuation instead. - // Any splitId > minSplitId will come from a shorter branch...and therefore should not be from - // the split containing the original fork and thus we avoid the execution dependency. - layerInfo.m_SplitId = maxSelectableId; + layerInfo.m_Subgraph->AddDirectAntecedent(parentInfo.m_Subgraph.get()); } - else - { - layerInfo.m_SplitId = ++maxSplitId; - } - } else - { - // The branch with the highest splitId represents the shortest path of selected nodes. - layerInfo.m_SplitId = maxSplitId; - } + }); } bool IsReadyForSplitAssignment(LayerSelectionInfo::LayerInfoContainer& layerInfos, LayerSelectionInfo& layerInfo) @@ -249,7 +413,6 @@ SubgraphViewSelector::SelectSubgraphs(SubgraphView& subgraph, const LayerSelecto // This layerInfo may have been added to the queue multiple times, so skip if we have already processed it if (!layerInfo.m_IsProcessed) { - // Only process this layerInfo if all inputs have been processed if (!IsReadyForSplitAssignment(layerInfos, layerInfo)) { @@ -272,17 +435,18 @@ SubgraphViewSelector::SelectSubgraphs(SubgraphView& subgraph, const LayerSelecto } } - // Collect all selected layers keyed by split id into a map + // Collect all selected layers keyed by subgraph representative into a map using SelectionInfoPtrs = std::vector<LayerSelectionInfo*>; - std::map<uint32_t, SelectionInfoPtrs> splitMap; + std::map<PartialSubgraph*, SelectionInfoPtrs> splitMap; for (auto& info : layerInfos) { if (info.second.m_IsSelected) { - auto it = splitMap.find(info.second.m_SplitId); + auto it = splitMap.find(info.second.m_Subgraph->GetRepresentative()); if (it == splitMap.end()) { - splitMap.insert(std::make_pair(info.second.m_SplitId, SelectionInfoPtrs{&info.second})); + splitMap.insert( + std::make_pair(info.second.m_Subgraph->GetRepresentative(), SelectionInfoPtrs{&info.second})); } else { @@ -291,26 +455,23 @@ SubgraphViewSelector::SelectSubgraphs(SubgraphView& subgraph, const LayerSelecto } } - // Now each non-empty split id represents a subgraph + // Now each entry in splitMap represents a subgraph Subgraphs result; for (auto& splitGraph : splitMap) { - if (splitGraph.second.empty() == false) + SubgraphView::InputSlots inputs; + SubgraphView::OutputSlots outputs; + SubgraphView::Layers layers; + for (auto&& infoPtr : splitGraph.second) { - SubgraphView::InputSlots inputs; - SubgraphView::OutputSlots outputs; - SubgraphView::Layers layers; - for (auto&& infoPtr : splitGraph.second) - { - infoPtr->CollectNonSelectedInputs(layerInfos, inputs); - infoPtr->CollectNonSelectedOutputSlots(layerInfos, outputs); - layers.push_back(infoPtr->m_Layer); - } - // Create a new sub-graph with the new lists of input/output slots and layer - result.emplace_back(std::make_unique<SubgraphView>(std::move(inputs), - std::move(outputs), - std::move(layers))); + infoPtr->CollectNonSelectedInputs(layerInfos, inputs); + infoPtr->CollectNonSelectedOutputSlots(layerInfos, outputs); + layers.push_back(infoPtr->m_Layer); } + // Create a new sub-graph with the new lists of input/output slots and layer + result.emplace_back(std::make_unique<SubgraphView>(std::move(inputs), + std::move(outputs), + std::move(layers))); } return result; |