16 #include <unordered_set> 47 PartialSubgraph* GetRepresentative()
50 if (m_Parent ==
nullptr)
56 PartialSubgraph* result = m_Parent->GetRepresentative();
65 void MergeWith(PartialSubgraph* other)
67 if (m_Parent ==
nullptr)
69 other = other->GetRepresentative();
81 for (PartialSubgraph* a : m_Antecedents)
83 size_t numErased = a->m_Dependants.erase(
this);
86 a->m_Dependants.insert(m_Parent);
88 for (PartialSubgraph* a : m_Dependants)
90 size_t numErased = a->m_Antecedents.erase(
this);
93 a->m_Antecedents.insert(m_Parent);
98 m_Parent->m_Antecedents.insert(m_Antecedents.begin(), m_Antecedents.end());
99 m_Antecedents.clear();
100 m_Parent->m_Dependants.insert(m_Dependants.begin(), m_Dependants.end());
101 m_Dependants.clear();
106 GetRepresentative()->MergeWith(other);
111 bool IsMergedWith(PartialSubgraph* other)
113 return GetRepresentative() == other->GetRepresentative();
117 void AddDirectAntecedent(PartialSubgraph* antecedent)
119 if (m_Parent ==
nullptr)
121 antecedent = antecedent->GetRepresentative();
123 m_Antecedents.insert(antecedent);
126 m_Antecedents.insert(antecedent->m_Antecedents.begin(), antecedent->m_Antecedents.end());
128 for (PartialSubgraph* d : m_Dependants)
130 d->m_Antecedents.insert(antecedent);
131 d->m_Antecedents.insert(antecedent->m_Antecedents.begin(), antecedent->m_Antecedents.end());
136 antecedent->m_Dependants.insert(
this);
137 antecedent->m_Dependants.insert(m_Dependants.begin(), m_Dependants.end());
138 for (PartialSubgraph* a : antecedent->m_Antecedents)
140 a->m_Dependants.insert(
this);
141 a->m_Dependants.insert(m_Dependants.begin(), m_Dependants.end());
147 GetRepresentative()->AddDirectAntecedent(antecedent);
152 bool HasAntecedent(PartialSubgraph* antecedent)
154 if (m_Parent ==
nullptr)
156 antecedent = antecedent->GetRepresentative();
158 return m_Antecedents.count(antecedent) > 0;
163 return GetRepresentative()->HasAntecedent(antecedent);
169 PartialSubgraph* m_Parent;
171 std::unordered_set<PartialSubgraph*> m_Antecedents;
173 std::unordered_set<PartialSubgraph*> m_Dependants;
177 struct LayerSelectionInfo
179 using LayerInfoContainer = std::map<Layer*, LayerSelectionInfo>;
180 using LayerInfoQueue = std::queue<LayerSelectionInfo*>;
190 bool IsInputLayer()
const 195 void CollectNonSelectedInputs(LayerSelectionInfo::LayerInfoContainer& layerInfos,
198 for (
auto&& slot =
m_Layer->BeginInputSlots(); slot !=
m_Layer->EndInputSlots(); ++slot)
200 OutputSlot* parentLayerOutputSlot = slot->GetConnectedOutputSlot();
201 ARMNN_ASSERT_MSG(parentLayerOutputSlot !=
nullptr,
"The input slots must be connected here.");
202 if (parentLayerOutputSlot)
204 Layer& parentLayer = parentLayerOutputSlot->GetOwningLayer();
205 auto parentInfo = layerInfos.find(&parentLayer);
206 if (parentInfo == layerInfos.end() ||
207 !
m_Subgraph->IsMergedWith(parentInfo->second.m_Subgraph.get()))
210 InputSlot* inputSlot = &(*slot);
211 if (std::find(inputSlots.begin(), inputSlots.end(), inputSlot) == inputSlots.end())
213 inputSlots.push_back(inputSlot);
220 void CollectNonSelectedOutputSlots(LayerSelectionInfo::LayerInfoContainer& layerInfos,
223 for (
auto&& slot =
m_Layer->BeginOutputSlots(); slot !=
m_Layer->EndOutputSlots(); ++slot)
225 for (InputSlot* childLayerInputSlot : slot->GetConnections())
227 Layer& childLayer = childLayerInputSlot->GetOwningLayer();
228 auto childInfo = layerInfos.find(&childLayer);
229 if (childInfo == layerInfos.end() ||
230 !
m_Subgraph->IsMergedWith(childInfo->second.m_Subgraph.get()))
233 OutputSlot* outputSlot = &(*slot);
234 if (std::find(outputSlots.begin(), outputSlots.end(), outputSlot) == outputSlots.end())
236 outputSlots.push_back(outputSlot);
262 template<
typename Delegate>
264 LayerSelectionInfo& layerInfo,
267 Layer& layer = *layerInfo.m_Layer;
271 auto connectedInput = PolymorphicDowncast<OutputSlot*>(inputSlot.GetConnection());
273 Layer& inputLayer = connectedInput->GetOwningLayer();
275 auto parentInfo = layerInfos.find(&inputLayer);
276 if (parentInfo != layerInfos.end())
278 function(parentInfo->second);
283 template<
typename Delegate>
285 LayerSelectionInfo& layerInfo,
288 Layer& layer= *layerInfo.m_Layer;
292 for (
auto& output : outputSlot.GetConnections())
294 Layer& childLayer = output->GetOwningLayer();
296 auto childInfo = layerInfos.find(&childLayer);
297 if (childInfo != layerInfos.end())
299 function(childInfo->second);
305 void AssignSplitId(LayerSelectionInfo::LayerInfoContainer& layerInfos, LayerSelectionInfo& layerInfo)
311 if (layerInfo.m_IsSelected == parentInfo.m_IsSelected)
323 bool dependenciesOk = true;
324 ForEachLayerInput(layerInfos, layerInfo, [&](LayerSelectionInfo& otherParentInfo)
328 if (otherParentInfo.m_Subgraph->HasAntecedent(parentInfo.m_Subgraph.get()))
330 dependenciesOk = false;
338 if (layerInfo.m_Subgraph ==
nullptr)
340 layerInfo.m_Subgraph = parentInfo.m_Subgraph;
346 layerInfo.m_Subgraph->MergeWith(parentInfo.m_Subgraph.get());
353 if (layerInfo.m_Subgraph ==
nullptr)
355 layerInfo.m_Subgraph = std::make_shared<PartialSubgraph>();
363 if (!layerInfo.m_Subgraph->IsMergedWith(parentInfo.m_Subgraph.get()))
365 layerInfo.m_Subgraph->AddDirectAntecedent(parentInfo.m_Subgraph.get());
374 [&ready](LayerSelectionInfo& parentInfo)
376 if (!parentInfo.m_IsProcessed)
387 LayerSelectionInfo::LayerInfoContainer layerInfos;
389 LayerSelectionInfo::LayerInfoQueue processQueue;
390 for (
auto& layer : subgraph)
392 auto emplaced = layerInfos.emplace(layer, LayerSelectionInfo{layer, selector});
393 LayerSelectionInfo& layerInfo = emplaced.first->second;
396 if (layerInfo.IsInputLayer())
398 processQueue.push(&layerInfo);
403 for (
auto& inputSlot : subgraphInputSlots)
405 Layer& layer = inputSlot->GetOwningLayer();
406 auto emplaced = layerInfos.emplace(&layer, LayerSelectionInfo{&layer, selector});
407 LayerSelectionInfo& layerInfo = emplaced.first->second;
409 processQueue.push(&layerInfo);
412 while (!processQueue.empty())
414 LayerSelectionInfo& layerInfo = *processQueue.front();
418 if (!layerInfo.m_IsProcessed)
424 processQueue.push(&layerInfo);
432 ForEachLayerOutput(layerInfos, layerInfo, [&processQueue](LayerSelectionInfo& childInfo)
434 processQueue.push(&childInfo);
438 layerInfo.m_IsProcessed =
true;
443 using SelectionInfoPtrs = std::vector<LayerSelectionInfo*>;
444 std::map<PartialSubgraph*, SelectionInfoPtrs> splitMap;
445 for (
auto&
info : layerInfos)
447 if (
info.second.m_IsSelected)
449 auto it = splitMap.find(
info.second.m_Subgraph->GetRepresentative());
450 if (it == splitMap.end())
453 std::make_pair(
info.second.m_Subgraph->GetRepresentative(), SelectionInfoPtrs{&
info.second}));
457 it->second.push_back(&
info.second);
464 for (
auto& splitGraph : splitMap)
469 for (
auto&& infoPtr : splitGraph.second)
471 infoPtr->CollectNonSelectedInputs(layerInfos, inputs);
472 infoPtr->CollectNonSelectedOutputSlots(layerInfos, outputs);
473 layers.push_back(infoPtr->m_Layer);
481 const LayerGuid guidB = b->GetOwningLayer().GetGuid();
486 else if (guidA == guidB)
495 const LayerGuid guidB = b->GetOwningLayer().GetGuid();
500 else if (guidA == guidB)
509 result.emplace_back(std::make_unique<SubgraphView>(std::move(inputs),
void ForEachLayerInput(LayerSelectionInfo::LayerInfoContainer &layerInfos, LayerSelectionInfo &layerInfo, Delegate function)
void AssignSplitId(LayerSelectionInfo::LayerInfoContainer &layerInfos, LayerSelectionInfo &layerInfo)
Layer & GetOwningLayer() const
std::vector< OutputSlot * > OutputSlots
std::function< bool(const Layer &)> LayerSelectorFunction
Copyright (c) 2021 ARM Limited and Contributors.
void IgnoreUnused(Ts &&...)
const std::vector< InputSlot > & GetInputSlots() const
bool IsReadyForSplitAssignment(LayerSelectionInfo::LayerInfoContainer &layerInfos, LayerSelectionInfo &layerInfo)
The SubgraphView class represents a subgraph of a Graph.
#define ARMNN_ASSERT_MSG(COND, MSG)
std::vector< SubgraphViewPtr > Subgraphs
#define ARMNN_ASSERT(COND)
const std::vector< OutputSlot > & GetOutputSlots() const
std::vector< InputSlot * > InputSlots
static Subgraphs SelectSubgraphs(Graph &graph, const LayerSelectorFunction &selector)
Selects subgraphs from a graph based on the selector function and the algorithm.
profiling::ProfilingGuid LayerGuid
Define LayerGuid type.
std::shared_ptr< PartialSubgraph > m_Subgraph
Which subgraph this layer has been assigned to.
std::list< Layer * > Layers
void ForEachLayerOutput(LayerSelectionInfo::LayerInfoContainer &layerInfos, LayerSelectionInfo &layerInfo, Delegate function)
LayerGuid GetGuid() const final
Returns the unique id of the layer.
unsigned int CalculateIndexOnOwner() const override