8 #include <boost/assert.hpp> 12 #include <unordered_set> 43 PartialSubgraph* GetRepresentative()
// GetRepresentative(): resolve the representative (root) of the merge tree this PartialSubgraph
// belongs to. Nodes are linked through m_Parent; a non-null parent means the question is forwarded
// up the chain until a node without a parent is reached.
// NOTE(review): the root-case body (presumably `return this;`) and whatever is done with `result`
// (e.g. caching it back into m_Parent as path compression) are on lines not visible in this excerpt.
46 if (m_Parent ==
nullptr)
// Not a root: ask the parent, recursing until a node with no parent is reached.
52 PartialSubgraph* result = m_Parent->GetRepresentative();
// MergeWith(): merge this subgraph with `other` so that both resolve to the same representative.
// Only a root node (m_Parent == nullptr) performs the merge itself; a non-root node forwards the
// request to its representative.
61 void MergeWith(PartialSubgraph* other)
63 if (m_Parent ==
nullptr)
// Always link against the other tree's root, never an interior node.
65 other = other->GetRepresentative();
// NOTE(review): the lines that attach this node under `other` (assigning m_Parent) and any early-out
// when `other` already is this node are not visible here; the code below relies on m_Parent already
// pointing at the new representative.
//
// Rewire reverse links. Every antecedent lists `this` as a dependant, and every dependant lists
// `this` as an antecedent. Replace those entries with the new representative so the sets keep
// holding representatives only. HasAntecedent looks set elements up without resolving them.
77 for (PartialSubgraph* a : m_Antecedents)
79 size_t numErased = a->m_Dependants.erase(
this);
// Exactly one entry must have been removed, i.e. the back-link is expected to exist.
// ignore_unused presumably silences unused-variable warnings when BOOST_ASSERT is compiled out.
80 BOOST_ASSERT(numErased == 1);
81 boost::ignore_unused(numErased);
82 a->m_Dependants.insert(m_Parent);
84 for (PartialSubgraph* a : m_Dependants)
86 size_t numErased = a->m_Antecedents.erase(
this);
87 BOOST_ASSERT(numErased == 1);
88 boost::ignore_unused(numErased);
89 a->m_Antecedents.insert(m_Parent);
// Fold this node's dependency sets into the new representative. From now on, requests on this node
// are forwarded to the representative, so its own antecedent set is emptied.
// NOTE(review): clearing of m_Dependants is not visible in this excerpt; confirm it is cleared too.
94 m_Parent->m_Antecedents.insert(m_Antecedents.begin(), m_Antecedents.end());
95 m_Antecedents.clear();
96 m_Parent->m_Dependants.insert(m_Dependants.begin(), m_Dependants.end());
// Non-root case (its `else` is not visible here): defer to the representative.
102 GetRepresentative()->MergeWith(other);
107 bool IsMergedWith(PartialSubgraph* other)
109 return GetRepresentative() == other->GetRepresentative();
// AddDirectAntecedent(): record that this subgraph depends on `antecedent`. The dependency sets are
// kept transitively closed:
//  - the new antecedent and all of its antecedents are added to this subgraph and to every
//    subgraph that depends on it;
//  - the reverse (dependant) sets are updated to match.
// A non-root node forwards the request to its representative.
// NOTE(review): callers must ensure `antecedent` is not already merged with this subgraph, otherwise
// the subgraph would be recorded as its own antecedent.
113 void AddDirectAntecedent(PartialSubgraph* antecedent)
115 if (m_Parent ==
nullptr)
// Store representatives only, consistent with how MergeWith rewrites these sets.
117 antecedent = antecedent->GetRepresentative();
// Add the direct antecedent plus its own (already closed) antecedent set, so HasAntecedent
// needs only one lookup.
119 m_Antecedents.insert(antecedent);
122 m_Antecedents.insert(antecedent->m_Antecedents.begin(), antecedent->m_Antecedents.end());
// Everything that depends on this subgraph now also depends on the new antecedents.
124 for (PartialSubgraph* d : m_Dependants)
126 d->m_Antecedents.insert(antecedent);
127 d->m_Antecedents.insert(antecedent->m_Antecedents.begin(), antecedent->m_Antecedents.end());
// Mirror the update in the reverse direction: the antecedent and each of its antecedents gain this
// subgraph and all of its dependants as dependants.
132 antecedent->m_Dependants.insert(
this);
133 antecedent->m_Dependants.insert(m_Dependants.begin(), m_Dependants.end());
134 for (PartialSubgraph* a : antecedent->m_Antecedents)
136 a->m_Dependants.insert(
this);
137 a->m_Dependants.insert(m_Dependants.begin(), m_Dependants.end());
// Non-root case (its `else` is not visible here): defer to the representative.
143 GetRepresentative()->AddDirectAntecedent(antecedent);
// HasAntecedent(): true if this subgraph depends on `antecedent`, directly or indirectly.
// m_Antecedents is kept transitively closed and holds representatives. The query is therefore a
// single hash-set lookup once the argument has been resolved to its representative.
148 bool HasAntecedent(PartialSubgraph* antecedent)
150 if (m_Parent ==
nullptr)
152 antecedent = antecedent->GetRepresentative();
154 return m_Antecedents.count(antecedent) > 0;
// Non-root case (its `else` is not visible here): defer to the representative.
159 return GetRepresentative()->HasAntecedent(antecedent);
165 PartialSubgraph* m_Parent;
167 std::unordered_set<PartialSubgraph*> m_Antecedents;
169 std::unordered_set<PartialSubgraph*> m_Dependants;
// LayerSelectionInfo: per-layer bookkeeping for subgraph selection. Its constructor and data members
// are not visible in this excerpt. From usage it holds m_Layer, m_IsSelected, m_IsProcessed and a
// shared m_Subgraph.
173 struct LayerSelectionInfo
// Maps each Layer* to its selection info.
// NOTE(review): a std::map keyed by pointer iterates in address order, so anything that walks this
// container inherits an allocation-dependent order.
175 using LayerInfoContainer = std::map<Layer*, LayerSelectionInfo>;
// Work queue of pointers into a LayerInfoContainer. Holding raw pointers is safe while the map is
// still being filled, because std::map insertion does not invalidate references to existing elements.
176 using LayerInfoQueue = std::queue<LayerSelectionInfo*>;
// IsInputLayer(): used to seed the processing queue in SelectSubgraphs; its body is not visible in
// this excerpt.
186 bool IsInputLayer()
const 191 void CollectNonSelectedInputs(LayerSelectionInfo::LayerInfoContainer& layerInfos,
// CollectNonSelectedInputs(): append to `inputSlots` every input slot of this layer whose producer
// lies outside this layer's merged subgraph. That means the producer is either not tracked in
// layerInfos, or tracked but in a different subgraph. Slots already in the list are skipped.
// NOTE(review): the `inputSlots` parameter is declared on a line not visible here (presumably
// SubgraphView::InputSlots&, i.e. std::vector<InputSlot*>).
194 for (
auto&& slot =
m_Layer->BeginInputSlots(); slot !=
m_Layer->EndInputSlots(); ++slot)
196 OutputSlot* parentLayerOutputSlot = slot->GetConnectedOutputSlot();
// The connection is asserted in debug builds. The `if` below also skips unconnected slots when the
// assertion is compiled out.
197 BOOST_ASSERT_MSG(parentLayerOutputSlot !=
nullptr,
"The input slots must be connected here.");
198 if (parentLayerOutputSlot)
200 Layer& parentLayer = parentLayerOutputSlot->GetOwningLayer();
201 auto parentInfo = layerInfos.find(&parentLayer);
// Boundary input: the producer is untracked or belongs to a different merged subgraph.
202 if (parentInfo == layerInfos.end() ||
203 !
m_Subgraph->IsMergedWith(parentInfo->second.m_Subgraph.get()))
206 InputSlot* inputSlot = &(*slot);
// NOTE(review): linear std::find keeps the list duplicate-free at O(size) per insertion; acceptable
// only while the collected list stays small.
207 if (std::find(inputSlots.begin(), inputSlots.end(), inputSlot) == inputSlots.end())
209 inputSlots.push_back(inputSlot);
// CollectNonSelectedOutputSlots(): append to `outputSlots` every output slot of this layer that feeds
// at least one consumer outside this layer's merged subgraph. That means the consumer is either not
// tracked in layerInfos, or tracked but in a different subgraph. Slots already in the list are skipped.
// NOTE(review): the `outputSlots` parameter is declared on a line not visible here (presumably
// SubgraphView::OutputSlots&, i.e. std::vector<OutputSlot*>).
216 void CollectNonSelectedOutputSlots(LayerSelectionInfo::LayerInfoContainer& layerInfos,
219 for (
auto&& slot =
m_Layer->BeginOutputSlots(); slot !=
m_Layer->EndOutputSlots(); ++slot)
// Inspect every consumer connected to this output slot.
221 for (InputSlot* childLayerInputSlot : slot->GetConnections())
223 Layer& childLayer = childLayerInputSlot->GetOwningLayer();
224 auto childInfo = layerInfos.find(&childLayer);
// Boundary output: the consumer is untracked or belongs to a different merged subgraph.
225 if (childInfo == layerInfos.end() ||
226 !
m_Subgraph->IsMergedWith(childInfo->second.m_Subgraph.get()))
229 OutputSlot* outputSlot = &(*slot);
// The same slot can qualify through several consumers; the std::find guard records it only once.
230 if (std::find(outputSlots.begin(), outputSlots.end(), outputSlot) == outputSlots.end())
232 outputSlots.push_back(outputSlot);
// ForEachLayerInput(): invoke `function` on the LayerSelectionInfo of each layer that feeds one of
// this layer's input slots, skipping producers that are not present in layerInfos.
// NOTE(review): the first signature line, the trailing `Delegate function` parameter and the loop over
// the layer's input slots (which introduces `inputSlot`) are not visible in this excerpt.
258 template<
typename Delegate>
260 LayerSelectionInfo& layerInfo,
263 Layer& layer = *layerInfo.m_Layer;
// boost::polymorphic_downcast is checked only in debug builds (a static_cast otherwise).
267 auto connectedInput = boost::polymorphic_downcast<OutputSlot*>(inputSlot.GetConnection());
// NOTE(review): unlike CollectNonSelectedInputs, a null connection here is only asserted, not
// guarded. With assertions disabled, a dangling slot would be dereferenced on the next statement.
268 BOOST_ASSERT_MSG(connectedInput,
"Dangling input slot detected.");
269 Layer& inputLayer = connectedInput->GetOwningLayer();
271 auto parentInfo = layerInfos.find(&inputLayer);
// Producers outside the tracked set are ignored.
272 if (parentInfo != layerInfos.end())
274 function(parentInfo->second);
// ForEachLayerOutput(): invoke `function` on the LayerSelectionInfo of every layer consuming one of
// this layer's outputs, skipping consumers not present in layerInfos. There is no de-duplication, so
// a consumer reached through several connections is visited once per connection.
// NOTE(review): the signature's first and last lines and the outer loop over the layer's output
// slots (which introduces `outputSlot`) are not visible in this excerpt.
279 template<
typename Delegate>
281 LayerSelectionInfo& layerInfo,
284 Layer& layer= *layerInfo.m_Layer;
288 for (
auto& output : outputSlot.GetConnections())
290 Layer& childLayer = output->GetOwningLayer();
292 auto childInfo = layerInfos.find(&childLayer);
// Consumers outside the tracked set are ignored.
293 if (childInfo != layerInfos.end())
295 function(childInfo->second);
// AssignSplitId(): choose the PartialSubgraph for `layerInfo` from the layers feeding it.
//  1. For each input whose selection state matches this layer's, join that input's subgraph unless
//     another input's subgraph already depends on it. Joining in that case would create a cycle
//     between subgraphs. The first acceptable subgraph is shared; later acceptable ones are merged in.
//  2. If nothing could be joined, start a new subgraph.
//  3. Record every input subgraph that is not merged with ours as a direct antecedent.
// NOTE(review): the enclosing ForEachLayerInput calls that introduce `parentInfo`, the
// `dependenciesOk` test and the `else` between the share/merge branches are on lines not visible in
// this excerpt.
301 void AssignSplitId(LayerSelectionInfo::LayerInfoContainer& layerInfos, LayerSelectionInfo& layerInfo)
// A layer can only join an input's subgraph when both are selected or both are unselected.
307 if (layerInfo.m_IsSelected == parentInfo.m_IsSelected)
// Cycle check: if any input's subgraph already depends on the candidate, putting this layer into the
// candidate would make the candidate depend on itself through that input.
319 bool dependenciesOk = true;
320 ForEachLayerInput(layerInfos, layerInfo, [&](LayerSelectionInfo& otherParentInfo)
324 if (otherParentInfo.m_Subgraph->HasAntecedent(parentInfo.m_Subgraph.get()))
326 dependenciesOk = false;
// First compatible subgraph: share it. Subsequent ones: merge them with the one already chosen.
334 if (layerInfo.m_Subgraph ==
nullptr)
336 layerInfo.m_Subgraph = parentInfo.m_Subgraph;
342 layerInfo.m_Subgraph->MergeWith(parentInfo.m_Subgraph.get());
// No compatible input subgraph: this layer starts a fresh one.
349 if (layerInfo.m_Subgraph ==
nullptr)
351 layerInfo.m_Subgraph = std::make_shared<PartialSubgraph>();
// Inputs in other subgraphs become antecedents. The IsMergedWith guard keeps a subgraph from being
// recorded as its own antecedent.
359 if (!layerInfo.m_Subgraph->IsMergedWith(parentInfo.m_Subgraph.get()))
361 layerInfo.m_Subgraph->AddDirectAntecedent(parentInfo.m_Subgraph.get());
// Tail of IsReadyForSplitAssignment(): a lambda applied to each parent's LayerSelectionInfo that
// checks whether that parent has been processed yet.
// NOTE(review): the function header, the declaration of `ready` and the body of the branch
// (presumably clearing `ready`) are not visible in this excerpt.
370 [&ready](LayerSelectionInfo& parentInfo)
372 if (!parentInfo.m_IsProcessed)
// Body of SubgraphViewSelector::SelectSubgraphs. Its signature, and the declarations of `subgraph`,
// `subgraphInputSlots`, `inputs`, `outputs`, `layers` and `result`, are not visible in this excerpt.
//
// Phase 1: build a LayerSelectionInfo for every layer and seed the work queue.
383 LayerSelectionInfo::LayerInfoContainer layerInfos;
385 LayerSelectionInfo::LayerInfoQueue processQueue;
386 for (
auto& layer : subgraph)
388 auto emplaced = layerInfos.emplace(layer, LayerSelectionInfo{layer, selector});
389 LayerSelectionInfo& layerInfo = emplaced.first->second;
// Input layers are starting points for the traversal.
392 if (layerInfo.IsInputLayer())
394 processQueue.push(&layerInfo);
// The layers owning the subgraph's input slots are also starting points.
// NOTE(review): emplace constructs a LayerSelectionInfo temporary even when the layer is already in the
// map (its existing entry is kept); try_emplace would skip that construction for repeated layers.
399 for (
auto& inputSlot : subgraphInputSlots)
401 Layer& layer = inputSlot->GetOwningLayer();
402 auto emplaced = layerInfos.emplace(&layer, LayerSelectionInfo{&layer, selector});
403 LayerSelectionInfo& layerInfo = emplaced.first->second;
405 processQueue.push(&layerInfo);
// Phase 2: process layers in dependency order. A layer may be queued several times, but it is
// handled only once (guarded by m_IsProcessed), after which its children are queued.
408 while (!processQueue.empty())
410 LayerSelectionInfo& layerInfo = *processQueue.front();
// (Removal of the front element happens on a line not visible here.)
414 if (!layerInfo.m_IsProcessed)
// Not ready yet (the readiness check is on a line not visible here): requeue at the back.
420 processQueue.push(&layerInfo);
// After assignment, queue every tracked consumer of this layer.
428 ForEachLayerOutput(layerInfos, layerInfo, [&processQueue](LayerSelectionInfo& childInfo)
430 processQueue.push(&childInfo);
434 layerInfo.m_IsProcessed =
true;
// Phase 3: group the selected layers by the representative of their merged subgraph.
// NOTE(review): splitMap and layerInfos are keyed by pointer, so the order of the resulting subgraphs,
// and of the layers within each, follows memory addresses and may differ between runs.
439 using SelectionInfoPtrs = std::vector<LayerSelectionInfo*>;
440 std::map<PartialSubgraph*, SelectionInfoPtrs> splitMap;
441 for (
auto&
info : layerInfos)
443 if (
info.second.m_IsSelected)
445 auto it = splitMap.find(
info.second.m_Subgraph->GetRepresentative());
446 if (it == splitMap.end())
449 std::make_pair(
info.second.m_Subgraph->GetRepresentative(), SelectionInfoPtrs{&
info.second}));
453 it->second.push_back(&
info.second);
// Phase 4: for each group, collect its boundary input/output slots and member layers, then build a
// SubgraphView from them.
460 for (
auto& splitGraph : splitMap)
465 for (
auto&& infoPtr : splitGraph.second)
467 infoPtr->CollectNonSelectedInputs(layerInfos, inputs);
468 infoPtr->CollectNonSelectedOutputSlots(layerInfos, outputs);
470 layers.push_back(infoPtr->m_Layer);
472 result.emplace_back(std::make_unique<SubgraphView>(std::move(inputs),
// Declaration listing: signatures and type aliases for the functions, members and types used above.
// NOTE(review): these lines have no bodies or terminating semicolons and look like extracted
// reference residue rather than compilable declarations; confirm against the real source.
void ForEachLayerInput(LayerSelectionInfo::LayerInfoContainer &layerInfos, LayerSelectionInfo &layerInfo, Delegate function)
std::list< Layer * > Layers
void ForEachLayerOutput(LayerSelectionInfo::LayerInfoContainer &layerInfos, LayerSelectionInfo &layerInfo, Delegate function)
const std::vector< OutputSlot > & GetOutputSlots() const
static Subgraphs SelectSubgraphs(Graph &graph, const LayerSelectorFunction &selector)
std::shared_ptr< PartialSubgraph > m_Subgraph
const std::vector< InputSlot > & GetInputSlots() const
std::vector< InputSlot * > InputSlots
std::function< bool(const Layer &)> LayerSelectorFunction
std::vector< SubgraphViewPtr > Subgraphs
std::vector< OutputSlot * > OutputSlots
void AssignSplitId(LayerSelectionInfo::LayerInfoContainer &layerInfos, LayerSelectionInfo &layerInfo)
bool IsReadyForSplitAssignment(LayerSelectionInfo::LayerInfoContainer &layerInfos, LayerSelectionInfo &layerInfo)