aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/SubgraphViewSelector.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/SubgraphViewSelector.cpp')
-rw-r--r--src/armnn/SubgraphViewSelector.cpp297
1 files changed, 229 insertions, 68 deletions
diff --git a/src/armnn/SubgraphViewSelector.cpp b/src/armnn/SubgraphViewSelector.cpp
index 4357ec4381..8798b7285d 100644
--- a/src/armnn/SubgraphViewSelector.cpp
+++ b/src/armnn/SubgraphViewSelector.cpp
@@ -9,6 +9,7 @@
#include <algorithm>
#include <map>
#include <queue>
+#include <unordered_set>
namespace armnn
{
@@ -16,28 +17,170 @@ namespace armnn
namespace
{
+/// Intermediate data-structure to store the subgraph that a layer has been assigned to.
+/// This is a "disjoint set" data structure that allows efficient merging of subgraphs,
+/// which is a key part of the algorithm. Subgraphs are arranged in singly-linked trees
+/// (with each node storing a pointer to its parent). Subgraphs in the same tree are considered
+/// to have been merged. Merging subgraphs is performed by attaching one tree to another,
+/// which is a simple pointer update.
+///
+/// NOTE: Due to the way this is stored, it is almost never correct to directly compare pointers
+/// to two PartialSubgraphs to check if two layers belong in the same subgraph. Instead you
+/// should use IsMergedWith().
+///
+/// This structure also stores information about the dependencies of each subgraph, which is needed
+/// to determine whether certain subgraphs can be merged. Checking whether a subgraph
+/// depends on another subgraph is a frequent operation in the algorithm (see AssignSplitId) and so this is optimized
+/// in preference to the merging of subgraphs. This leads to an approach where each subgraph stores
+/// a set of all the subgraphs it depends on (for a fast lookup). In order to efficiently update this
+/// set as subgraphs are merged means we also store a set of subgraphs which *depend on us* (i.e. the
+/// complement of our dependencies).
+class PartialSubgraph
+{
+public:
+ /// If this subgraph has been merged with another then there is an agreed "representative" for the combined
+ /// subgraph, which uniquely identifies the subgraph.
+ PartialSubgraph* GetRepresentative()
+ {
+ // Recurse up the tree to find the root node.
+ if (m_Parent == nullptr)
+ {
+ return this;
+ }
+ else
+ {
+ PartialSubgraph* result = m_Parent->GetRepresentative();
+ // Update our parent pointer to point directly to the root in order to speed up future calls to this method.
+ // This essentially "flattens" the tree.
+ m_Parent = result;
+ return result;
+ }
+ }
+
+ /// Merges this subgraph with another.
+ void MergeWith(PartialSubgraph* other)
+ {
+ if (m_Parent == nullptr)
+ {
+ other = other->GetRepresentative();
+ if (this == other)
+ {
+ // Already merged - no-op
+ return;
+ }
+ m_Parent = other;
+
+ // Update others' dependency sets to point to the new representative rather than us.
+ // Keeping these up-to-date means we can rely on these sets containing representatives when
+ // we perform a lookup in HasAntecedent() and so don't need to resolve the representative for each element
+ // of the set. See description at the top of this class for more rationale.
+ for (PartialSubgraph* a : m_Antecedents)
+ {
+ size_t numErased = a->m_Dependants.erase(this);
+ BOOST_ASSERT(numErased == 1);
+ boost::ignore_unused(numErased);
+ a->m_Dependants.insert(m_Parent);
+ }
+ for (PartialSubgraph* a : m_Dependants)
+ {
+ size_t numErased = a->m_Antecedents.erase(this);
+ BOOST_ASSERT(numErased == 1);
+ boost::ignore_unused(numErased);
+ a->m_Antecedents.insert(m_Parent);
+ }
+
+ // Merge our dependency sets into our new representative.
+ // We no longer need to maintain our own sets, as requests will always be forwarded to the representative.
+ m_Parent->m_Antecedents.insert(m_Antecedents.begin(), m_Antecedents.end());
+ m_Antecedents.clear();
+ m_Parent->m_Dependants.insert(m_Dependants.begin(), m_Dependants.end());
+ m_Dependants.clear();
+ }
+ else
+ {
+ // Defer request to the representative
+ GetRepresentative()->MergeWith(other);
+ }
+ }
+
+ /// Checks if this subgraph has been merged with the given subgraph.
+ bool IsMergedWith(PartialSubgraph* other)
+ {
+ return GetRepresentative() == other->GetRepresentative();
+ }
+
+ /// Marks the given subgraph as a direct antecedent (dependency) of this one.
+ void AddDirectAntecedent(PartialSubgraph* antecedent)
+ {
+ if (m_Parent == nullptr)
+ {
+ antecedent = antecedent->GetRepresentative();
+
+ m_Antecedents.insert(antecedent);
+ // Also record all of its antecedents, so that we end up with direct and indirect antecedents.
+ // This makes the lookup in HasAntecedent() faster.
+ m_Antecedents.insert(antecedent->m_Antecedents.begin(), antecedent->m_Antecedents.end());
+ // All of our dependents also need to include the new antecedents
+ for (PartialSubgraph* d : m_Dependants)
+ {
+ d->m_Antecedents.insert(antecedent);
+ d->m_Antecedents.insert(antecedent->m_Antecedents.begin(), antecedent->m_Antecedents.end());
+ }
+
+ // Store reverse dependencies as well, required so that we can efficiently navigate the graph
+ // when making updates.
+ antecedent->m_Dependants.insert(this);
+ antecedent->m_Dependants.insert(m_Dependants.begin(), m_Dependants.end());
+ for (PartialSubgraph* a : antecedent->m_Antecedents)
+ {
+ a->m_Dependants.insert(this);
+ a->m_Dependants.insert(m_Dependants.begin(), m_Dependants.end());
+ }
+ }
+ else
+ {
+ // Defer request to the representative
+ GetRepresentative()->AddDirectAntecedent(antecedent);
+ }
+ }
+
+ /// Checks if this subgraph is dependent on the given subgraph, either directly or indirectly.
+ bool HasAntecedent(PartialSubgraph* antecedent)
+ {
+ if (m_Parent == nullptr)
+ {
+ antecedent = antecedent->GetRepresentative();
+ // Thanks to keeping this set updated in MergeWith and AddDirectAntecedent, we can do an efficient lookup.
+ return m_Antecedents.count(antecedent) > 0;
+ }
+ else
+ {
+ // Defer request to the representative
+ return GetRepresentative()->HasAntecedent(antecedent);
+ }
+ }
+
+private:
+ /// Pointer to the parent node in the tree. If this is null then we are the representative for our merged subgraph.
+ PartialSubgraph* m_Parent;
+ /// The representatives of all the subgraphs which we depend on, either directly or indirectly.
+ std::unordered_set<PartialSubgraph*> m_Antecedents;
+ /// The representatives of all the subgraphs which depend on us, either directly or indirectly.
+ std::unordered_set<PartialSubgraph*> m_Dependants;
+};
+
+/// Intermediate data structure to store information associated with a particular layer.
struct LayerSelectionInfo
{
- using SplitId = uint32_t;
using LayerInfoContainer = std::map<Layer*, LayerSelectionInfo>;
using LayerInfoQueue = std::queue<LayerSelectionInfo*>;
- static constexpr uint32_t InitialSplitId() { return 1; }
LayerSelectionInfo(Layer* layer, const SubgraphViewSelector::LayerSelectorFunction& selector)
: m_Layer{layer}
- , m_SplitId{0}
+ , m_Subgraph{nullptr}
, m_IsSelected{selector(*layer)}
, m_IsProcessed(false)
{
- // fill topology information by storing direct children
- for (auto&& slot = m_Layer->BeginOutputSlots(); slot != m_Layer->EndOutputSlots(); ++slot)
- {
- for (InputSlot* childLayerInputSlot : slot->GetConnections())
- {
- Layer& childLayer = childLayerInputSlot->GetOwningLayer();
- m_DirectChildren.push_back(&childLayer);
- }
- }
}
bool IsInputLayer() const
@@ -57,7 +200,7 @@ struct LayerSelectionInfo
Layer& parentLayer = parentLayerOutputSlot->GetOwningLayer();
auto parentInfo = layerInfos.find(&parentLayer);
if (parentInfo == layerInfos.end() ||
- m_SplitId != parentInfo->second.m_SplitId)
+ !m_Subgraph->IsMergedWith(parentInfo->second.m_Subgraph.get()))
{
// Avoid collecting duplicate input slots
InputSlot* inputSlot = &(*slot);
@@ -80,7 +223,7 @@ struct LayerSelectionInfo
Layer& childLayer = childLayerInputSlot->GetOwningLayer();
auto childInfo = layerInfos.find(&childLayer);
if (childInfo == layerInfos.end() ||
- m_SplitId != childInfo->second.m_SplitId)
+ !m_Subgraph->IsMergedWith(childInfo->second.m_Subgraph.get()))
{
// Avoid collecting duplicate output slots
OutputSlot* outputSlot = &(*slot);
@@ -93,9 +236,11 @@ struct LayerSelectionInfo
}
}
- std::vector<Layer*> m_DirectChildren;
Layer* m_Layer;
- SplitId m_SplitId;
+ /// Which subgraph this layer has been assigned to. Only valid once m_IsProcessed is true.
+ /// Two layers with different m_Subgraph pointers may in fact have been merged into the same subgraph -
+ /// see the description of the PartialSubgraph class.
+ std::shared_ptr<PartialSubgraph> m_Subgraph;
bool m_IsSelected;
bool m_IsProcessed;
};
@@ -155,48 +300,67 @@ void ForEachLayerOutput(LayerSelectionInfo::LayerInfoContainer& layerInfos,
void AssignSplitId(LayerSelectionInfo::LayerInfoContainer& layerInfos, LayerSelectionInfo& layerInfo)
{
- bool newSplit = false;
- LayerSelectionInfo::SplitId minSplitId = std::numeric_limits<LayerSelectionInfo::SplitId>::max();
- LayerSelectionInfo::SplitId maxSplitId = std::numeric_limits<LayerSelectionInfo::SplitId>::lowest();
- LayerSelectionInfo::SplitId maxSelectableId = std::numeric_limits<LayerSelectionInfo::SplitId>::lowest();
-
- ForEachLayerInput(layerInfos, layerInfo, [&newSplit, &minSplitId, &maxSplitId, &maxSelectableId, &layerInfo](
- LayerSelectionInfo& parentInfo)
+ // Check each input to see if we can attach ourselves to any of the subgraphs that have already been assigned.
+ ForEachLayerInput(layerInfos, layerInfo, [&](LayerSelectionInfo& parentInfo)
+ {
+ // We can only attach ourselves to the subgraph from this input if there isn't a cut here.
+ if (layerInfo.m_IsSelected == parentInfo.m_IsSelected)
{
- minSplitId = std::min(minSplitId, parentInfo.m_SplitId);
- maxSplitId = std::max(maxSplitId, parentInfo.m_SplitId);
- if (parentInfo.m_IsSelected && layerInfo.m_IsSelected)
+ // We also need to check that merging into this subgraph won't cause a dependency cycle between subgraphs.
+ // This will be the case if the subgraph that we will become part of is already a dependency
+ // of one of the subgraphs that are input to this layer, e.g:
+ //
+ // 0 | The numbers (0, 1) are the subgraph IDs of each layer and we are looking at layer X.
+ // / \ |
+ // 1 0 | We can't merge X into subgraph 0, because the left-hand input already depends on subgraph 0.
+ // \ / | We can however merge X into subgraph 1.
+ // X |
+ //
+ bool dependenciesOk = true;
+ ForEachLayerInput(layerInfos, layerInfo, [&](LayerSelectionInfo& otherParentInfo)
{
- maxSelectableId = std::max(maxSelectableId, parentInfo.m_SplitId);
- }
+ // We call HasAntecedent() ~ n^2 times, where n is the number of inputs to this layer.
+ // Hence it is important that this is efficient - see PartialSubgraph class description.
+ if (otherParentInfo.m_Subgraph->HasAntecedent(parentInfo.m_Subgraph.get()))
+ {
+ dependenciesOk = false;
+ }
+ });
- if (layerInfo.m_IsSelected != parentInfo.m_IsSelected)
+ if (dependenciesOk)
{
- newSplit = true;
+ // Merge into the subgraph of this input. If we have already been merged into another subgraph
+ // (from another input of this layer), then merge both of them together.
+ if (layerInfo.m_Subgraph == nullptr)
+ {
+ layerInfo.m_Subgraph = parentInfo.m_Subgraph;
+ }
+ else
+ {
+ // We call MergeWith() ~ n times, where n is the number of inputs to this layer.
+ // Therefore it does not need to be as performant as HasAntecedent().
+ layerInfo.m_Subgraph->MergeWith(parentInfo.m_Subgraph.get());
+ }
}
+ }
+ });
- });
+ // If we weren't able to merge into an existing subgraph then we need to make a new one
+ if (layerInfo.m_Subgraph == nullptr)
+ {
+ layerInfo.m_Subgraph = std::make_shared<PartialSubgraph>();
+ }
- // Assign the split Id for the current layerInfo
- if (newSplit)
+ // Record dependencies of the chosen subgraph based on the inputs of this layer.
+ ForEachLayerInput(layerInfos, layerInfo, [&](LayerSelectionInfo& parentInfo)
{
- if (maxSelectableId > minSplitId)
+ // These functions are called ~n times, where n is the number of inputs to this layer.
+ // Therefore it does not need to be as performant as HasAntecedent().
+ if (!layerInfo.m_Subgraph->IsMergedWith(parentInfo.m_Subgraph.get()))
{
- // We can be overly aggressive when choosing to create a new split so
- // here we determine if one of the parent branches are suitable candidates for continuation instead.
- // Any splitId > minSplitId will come from a shorter branch...and therefore should not be from
- // the split containing the original fork and thus we avoid the execution dependency.
- layerInfo.m_SplitId = maxSelectableId;
+ layerInfo.m_Subgraph->AddDirectAntecedent(parentInfo.m_Subgraph.get());
}
- else
- {
- layerInfo.m_SplitId = ++maxSplitId;
- }
- } else
- {
- // The branch with the highest splitId represents the shortest path of selected nodes.
- layerInfo.m_SplitId = maxSplitId;
- }
+ });
}
bool IsReadyForSplitAssignment(LayerSelectionInfo::LayerInfoContainer& layerInfos, LayerSelectionInfo& layerInfo)
@@ -249,7 +413,6 @@ SubgraphViewSelector::SelectSubgraphs(SubgraphView& subgraph, const LayerSelecto
// This layerInfo may have been added to the queue multiple times, so skip if we have already processed it
if (!layerInfo.m_IsProcessed)
{
-
// Only process this layerInfo if all inputs have been processed
if (!IsReadyForSplitAssignment(layerInfos, layerInfo))
{
@@ -272,17 +435,18 @@ SubgraphViewSelector::SelectSubgraphs(SubgraphView& subgraph, const LayerSelecto
}
}
- // Collect all selected layers keyed by split id into a map
+ // Collect all selected layers keyed by subgraph representative into a map
using SelectionInfoPtrs = std::vector<LayerSelectionInfo*>;
- std::map<uint32_t, SelectionInfoPtrs> splitMap;
+ std::map<PartialSubgraph*, SelectionInfoPtrs> splitMap;
for (auto& info : layerInfos)
{
if (info.second.m_IsSelected)
{
- auto it = splitMap.find(info.second.m_SplitId);
+ auto it = splitMap.find(info.second.m_Subgraph->GetRepresentative());
if (it == splitMap.end())
{
- splitMap.insert(std::make_pair(info.second.m_SplitId, SelectionInfoPtrs{&info.second}));
+ splitMap.insert(
+ std::make_pair(info.second.m_Subgraph->GetRepresentative(), SelectionInfoPtrs{&info.second}));
}
else
{
@@ -291,26 +455,23 @@ SubgraphViewSelector::SelectSubgraphs(SubgraphView& subgraph, const LayerSelecto
}
}
- // Now each non-empty split id represents a subgraph
+ // Now each entry in splitMap represents a subgraph
Subgraphs result;
for (auto& splitGraph : splitMap)
{
- if (splitGraph.second.empty() == false)
+ SubgraphView::InputSlots inputs;
+ SubgraphView::OutputSlots outputs;
+ SubgraphView::Layers layers;
+ for (auto&& infoPtr : splitGraph.second)
{
- SubgraphView::InputSlots inputs;
- SubgraphView::OutputSlots outputs;
- SubgraphView::Layers layers;
- for (auto&& infoPtr : splitGraph.second)
- {
- infoPtr->CollectNonSelectedInputs(layerInfos, inputs);
- infoPtr->CollectNonSelectedOutputSlots(layerInfos, outputs);
- layers.push_back(infoPtr->m_Layer);
- }
- // Create a new sub-graph with the new lists of input/output slots and layer
- result.emplace_back(std::make_unique<SubgraphView>(std::move(inputs),
- std::move(outputs),
- std::move(layers)));
+ infoPtr->CollectNonSelectedInputs(layerInfos, inputs);
+ infoPtr->CollectNonSelectedOutputSlots(layerInfos, outputs);
+ layers.push_back(infoPtr->m_Layer);
}
+ // Create a new sub-graph with the new lists of input/output slots and layer
+ result.emplace_back(std::make_unique<SubgraphView>(std::move(inputs),
+ std::move(outputs),
+ std::move(layers)));
}
return result;