11 #include <boost/assert.hpp> 15 #include <unordered_set> 46 PartialSubgraph* GetRepresentative()
49 if (m_Parent ==
nullptr)
55 PartialSubgraph* result = m_Parent->GetRepresentative();
64 void MergeWith(PartialSubgraph* other)
66 if (m_Parent ==
nullptr)
68 other = other->GetRepresentative();
80 for (PartialSubgraph* a : m_Antecedents)
82 size_t numErased = a->m_Dependants.erase(
this);
83 BOOST_ASSERT(numErased == 1);
85 a->m_Dependants.insert(m_Parent);
87 for (PartialSubgraph* a : m_Dependants)
89 size_t numErased = a->m_Antecedents.erase(
this);
90 BOOST_ASSERT(numErased == 1);
92 a->m_Antecedents.insert(m_Parent);
97 m_Parent->m_Antecedents.insert(m_Antecedents.begin(), m_Antecedents.end());
98 m_Antecedents.clear();
99 m_Parent->m_Dependants.insert(m_Dependants.begin(), m_Dependants.end());
100 m_Dependants.clear();
105 GetRepresentative()->MergeWith(other);
110 bool IsMergedWith(PartialSubgraph* other)
112 return GetRepresentative() == other->GetRepresentative();
116 void AddDirectAntecedent(PartialSubgraph* antecedent)
118 if (m_Parent ==
nullptr)
120 antecedent = antecedent->GetRepresentative();
122 m_Antecedents.insert(antecedent);
125 m_Antecedents.insert(antecedent->m_Antecedents.begin(), antecedent->m_Antecedents.end());
127 for (PartialSubgraph* d : m_Dependants)
129 d->m_Antecedents.insert(antecedent);
130 d->m_Antecedents.insert(antecedent->m_Antecedents.begin(), antecedent->m_Antecedents.end());
135 antecedent->m_Dependants.insert(
this);
136 antecedent->m_Dependants.insert(m_Dependants.begin(), m_Dependants.end());
137 for (PartialSubgraph* a : antecedent->m_Antecedents)
139 a->m_Dependants.insert(
this);
140 a->m_Dependants.insert(m_Dependants.begin(), m_Dependants.end());
146 GetRepresentative()->AddDirectAntecedent(antecedent);
151 bool HasAntecedent(PartialSubgraph* antecedent)
153 if (m_Parent ==
nullptr)
155 antecedent = antecedent->GetRepresentative();
157 return m_Antecedents.count(antecedent) > 0;
162 return GetRepresentative()->HasAntecedent(antecedent);
168 PartialSubgraph* m_Parent;
170 std::unordered_set<PartialSubgraph*> m_Antecedents;
172 std::unordered_set<PartialSubgraph*> m_Dependants;
176 struct LayerSelectionInfo
178 using LayerInfoContainer = std::map<Layer*, LayerSelectionInfo>;
179 using LayerInfoQueue = std::queue<LayerSelectionInfo*>;
189 bool IsInputLayer()
const 194 void CollectNonSelectedInputs(LayerSelectionInfo::LayerInfoContainer& layerInfos,
197 for (
auto&& slot =
m_Layer->BeginInputSlots(); slot !=
m_Layer->EndInputSlots(); ++slot)
199 OutputSlot* parentLayerOutputSlot = slot->GetConnectedOutputSlot();
200 BOOST_ASSERT_MSG(parentLayerOutputSlot !=
nullptr,
"The input slots must be connected here.");
201 if (parentLayerOutputSlot)
203 Layer& parentLayer = parentLayerOutputSlot->GetOwningLayer();
204 auto parentInfo = layerInfos.find(&parentLayer);
205 if (parentInfo == layerInfos.end() ||
206 !
m_Subgraph->IsMergedWith(parentInfo->second.m_Subgraph.get()))
209 InputSlot* inputSlot = &(*slot);
210 if (std::find(inputSlots.begin(), inputSlots.end(), inputSlot) == inputSlots.end())
212 inputSlots.push_back(inputSlot);
219 void CollectNonSelectedOutputSlots(LayerSelectionInfo::LayerInfoContainer& layerInfos,
222 for (
auto&& slot =
m_Layer->BeginOutputSlots(); slot !=
m_Layer->EndOutputSlots(); ++slot)
224 for (InputSlot* childLayerInputSlot : slot->GetConnections())
226 Layer& childLayer = childLayerInputSlot->GetOwningLayer();
227 auto childInfo = layerInfos.find(&childLayer);
228 if (childInfo == layerInfos.end() ||
229 !
m_Subgraph->IsMergedWith(childInfo->second.m_Subgraph.get()))
232 OutputSlot* outputSlot = &(*slot);
233 if (std::find(outputSlots.begin(), outputSlots.end(), outputSlot) == outputSlots.end())
235 outputSlots.push_back(outputSlot);
261 template<
typename Delegate>
263 LayerSelectionInfo& layerInfo,
266 Layer& layer = *layerInfo.m_Layer;
270 auto connectedInput = boost::polymorphic_downcast<OutputSlot*>(inputSlot.GetConnection());
271 BOOST_ASSERT_MSG(connectedInput,
"Dangling input slot detected.");
272 Layer& inputLayer = connectedInput->GetOwningLayer();
274 auto parentInfo = layerInfos.find(&inputLayer);
275 if (parentInfo != layerInfos.end())
277 function(parentInfo->second);
282 template<
typename Delegate>
284 LayerSelectionInfo& layerInfo,
287 Layer& layer= *layerInfo.m_Layer;
291 for (
auto& output : outputSlot.GetConnections())
293 Layer& childLayer = output->GetOwningLayer();
295 auto childInfo = layerInfos.find(&childLayer);
296 if (childInfo != layerInfos.end())
298 function(childInfo->second);
304 void AssignSplitId(LayerSelectionInfo::LayerInfoContainer& layerInfos, LayerSelectionInfo& layerInfo)
310 if (layerInfo.m_IsSelected == parentInfo.m_IsSelected)
322 bool dependenciesOk = true;
323 ForEachLayerInput(layerInfos, layerInfo, [&](LayerSelectionInfo& otherParentInfo)
327 if (otherParentInfo.m_Subgraph->HasAntecedent(parentInfo.m_Subgraph.get()))
329 dependenciesOk = false;
337 if (layerInfo.m_Subgraph ==
nullptr)
339 layerInfo.m_Subgraph = parentInfo.m_Subgraph;
345 layerInfo.m_Subgraph->MergeWith(parentInfo.m_Subgraph.get());
352 if (layerInfo.m_Subgraph ==
nullptr)
354 layerInfo.m_Subgraph = std::make_shared<PartialSubgraph>();
362 if (!layerInfo.m_Subgraph->IsMergedWith(parentInfo.m_Subgraph.get()))
364 layerInfo.m_Subgraph->AddDirectAntecedent(parentInfo.m_Subgraph.get());
373 [&ready](LayerSelectionInfo& parentInfo)
375 if (!parentInfo.m_IsProcessed)
386 LayerSelectionInfo::LayerInfoContainer layerInfos;
388 LayerSelectionInfo::LayerInfoQueue processQueue;
389 for (
auto& layer : subgraph)
391 auto emplaced = layerInfos.emplace(layer, LayerSelectionInfo{layer, selector});
392 LayerSelectionInfo& layerInfo = emplaced.first->second;
395 if (layerInfo.IsInputLayer())
397 processQueue.push(&layerInfo);
402 for (
auto& inputSlot : subgraphInputSlots)
404 Layer& layer = inputSlot->GetOwningLayer();
405 auto emplaced = layerInfos.emplace(&layer, LayerSelectionInfo{&layer, selector});
406 LayerSelectionInfo& layerInfo = emplaced.first->second;
408 processQueue.push(&layerInfo);
411 while (!processQueue.empty())
413 LayerSelectionInfo& layerInfo = *processQueue.front();
417 if (!layerInfo.m_IsProcessed)
423 processQueue.push(&layerInfo);
431 ForEachLayerOutput(layerInfos, layerInfo, [&processQueue](LayerSelectionInfo& childInfo)
433 processQueue.push(&childInfo);
437 layerInfo.m_IsProcessed =
true;
442 using SelectionInfoPtrs = std::vector<LayerSelectionInfo*>;
443 std::map<PartialSubgraph*, SelectionInfoPtrs> splitMap;
444 for (
auto&
info : layerInfos)
446 if (
info.second.m_IsSelected)
448 auto it = splitMap.find(
info.second.m_Subgraph->GetRepresentative());
449 if (it == splitMap.end())
452 std::make_pair(
info.second.m_Subgraph->GetRepresentative(), SelectionInfoPtrs{&
info.second}));
456 it->second.push_back(&
info.second);
463 for (
auto& splitGraph : splitMap)
468 for (
auto&& infoPtr : splitGraph.second)
470 infoPtr->CollectNonSelectedInputs(layerInfos, inputs);
471 infoPtr->CollectNonSelectedOutputSlots(layerInfos, outputs);
472 layers.push_back(infoPtr->m_Layer);
475 result.emplace_back(std::make_unique<SubgraphView>(std::move(inputs),
void ForEachLayerInput(LayerSelectionInfo::LayerInfoContainer &layerInfos, LayerSelectionInfo &layerInfo, Delegate function)
void AssignSplitId(LayerSelectionInfo::LayerInfoContainer &layerInfos, LayerSelectionInfo &layerInfo)
std::vector< OutputSlot * > OutputSlots
std::function< bool(const Layer &)> LayerSelectorFunction
Copyright (c) 2020 ARM Limited.
void IgnoreUnused(Ts &&...)
const std::vector< InputSlot > & GetInputSlots() const
bool IsReadyForSplitAssignment(LayerSelectionInfo::LayerInfoContainer &layerInfos, LayerSelectionInfo &layerInfo)
The SubgraphView class represents a subgraph of a Graph.
std::vector< SubgraphViewPtr > Subgraphs
const std::vector< OutputSlot > & GetOutputSlots() const
std::vector< InputSlot * > InputSlots
static Subgraphs SelectSubgraphs(Graph &graph, const LayerSelectorFunction &selector)
Selects subgraphs from a graph based on the selector function and the algorithm.
std::shared_ptr< PartialSubgraph > m_Subgraph
Which subgraph this layer has been assigned to.
std::list< Layer * > Layers
void ForEachLayerOutput(LayerSelectionInfo::LayerInfoContainer &layerInfos, LayerSelectionInfo &layerInfo, Delegate function)