ArmNN
 20.02
SubgraphViewSelector.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
7 #include "Graph.hpp"
8 
10 
11 #include <boost/assert.hpp>
12 #include <algorithm>
13 #include <map>
14 #include <queue>
15 #include <unordered_set>
16 
17 namespace armnn
18 {
19 
20 namespace
21 {
22 
23 /// Intermediate data-structure to store the subgraph that a layer has been assigned to.
24 /// This is a "disjoint set" data structure that allows efficient merging of subgraphs,
25 /// which is a key part of the algorithm. Subgraphs are arranged in singly-linked trees
26 /// (with each node storing a pointer to its parent). Subgraphs in the same tree are considered
27 /// to have been merged. Merging subgraphs is performed by attaching one tree to another,
28 /// which is a simple pointer update.
29 ///
30 /// NOTE: Due to the way this is stored, it is almost never correct to directly compare pointers
31 /// to two PartialSubgraphs to check if two layers belong in the same subgraph. Instead you
32 /// should use IsMergedWith().
33 ///
34 /// This structure also stores information about the dependencies of each subgraph, which is needed
35 /// to determine whether certain subgraphs can be merged. Checking whether a subgraph
36 /// depends on another subgraph is a frequent operation in the algorithm (see AssignSplitId) and so this is optimized
37 /// in preference to the merging of subgraphs. This leads to an approach where each subgraph stores
38 /// a set of all the subgraphs it depends on (for a fast lookup). In order to efficiently update this
39 /// set as subgraphs are merged means we also store a set of subgraphs which *depend on us* (i.e. the
40 /// complement of our dependencies).
41 class PartialSubgraph
42 {
43 public:
44  /// If this subgraph has been merged with another then there is an agreed "representative" for the combined
45  /// subgraph, which uniquely identifies the subgraph.
46  PartialSubgraph* GetRepresentative()
47  {
48  // Recurse up the tree to find the root node.
49  if (m_Parent == nullptr)
50  {
51  return this;
52  }
53  else
54  {
55  PartialSubgraph* result = m_Parent->GetRepresentative();
56  // Update our parent pointer to point directly to the root in order to speed up future calls to this method.
57  // This essentially "flattens" the tree.
58  m_Parent = result;
59  return result;
60  }
61  }
62 
63  /// Merges this subgraph with another.
64  void MergeWith(PartialSubgraph* other)
65  {
66  if (m_Parent == nullptr)
67  {
68  other = other->GetRepresentative();
69  if (this == other)
70  {
71  // Already merged - no-op
72  return;
73  }
74  m_Parent = other;
75 
76  // Update others' dependency sets to point to the new representative rather than us.
77  // Keeping these up-to-date means we can rely on these sets containing representatives when
78  // we perform a lookup in HasAntecedent() and so don't need to resolve the representative for each element
79  // of the set. See description at the top of this class for more rationale.
80  for (PartialSubgraph* a : m_Antecedents)
81  {
82  size_t numErased = a->m_Dependants.erase(this);
83  BOOST_ASSERT(numErased == 1);
84  IgnoreUnused(numErased);
85  a->m_Dependants.insert(m_Parent);
86  }
87  for (PartialSubgraph* a : m_Dependants)
88  {
89  size_t numErased = a->m_Antecedents.erase(this);
90  BOOST_ASSERT(numErased == 1);
91  IgnoreUnused(numErased);
92  a->m_Antecedents.insert(m_Parent);
93  }
94 
95  // Merge our dependency sets into our new representative.
96  // We no longer need to maintain our own sets, as requests will always be forwarded to the representative.
97  m_Parent->m_Antecedents.insert(m_Antecedents.begin(), m_Antecedents.end());
98  m_Antecedents.clear();
99  m_Parent->m_Dependants.insert(m_Dependants.begin(), m_Dependants.end());
100  m_Dependants.clear();
101  }
102  else
103  {
104  // Defer request to the representative
105  GetRepresentative()->MergeWith(other);
106  }
107  }
108 
109  /// Checks if this subgraph has been merged with the given subgraph.
110  bool IsMergedWith(PartialSubgraph* other)
111  {
112  return GetRepresentative() == other->GetRepresentative();
113  }
114 
115  /// Marks the given subgraph as a direct antecedent (dependency) of this one.
116  void AddDirectAntecedent(PartialSubgraph* antecedent)
117  {
118  if (m_Parent == nullptr)
119  {
120  antecedent = antecedent->GetRepresentative();
121 
122  m_Antecedents.insert(antecedent);
123  // Also record all of its antecedents, so that we end up with direct and indirect antecedents.
124  // This makes the lookup in HasAntecedent() faster.
125  m_Antecedents.insert(antecedent->m_Antecedents.begin(), antecedent->m_Antecedents.end());
126  // All of our dependents also need to include the new antecedents
127  for (PartialSubgraph* d : m_Dependants)
128  {
129  d->m_Antecedents.insert(antecedent);
130  d->m_Antecedents.insert(antecedent->m_Antecedents.begin(), antecedent->m_Antecedents.end());
131  }
132 
133  // Store reverse dependencies as well, required so that we can efficiently navigate the graph
134  // when making updates.
135  antecedent->m_Dependants.insert(this);
136  antecedent->m_Dependants.insert(m_Dependants.begin(), m_Dependants.end());
137  for (PartialSubgraph* a : antecedent->m_Antecedents)
138  {
139  a->m_Dependants.insert(this);
140  a->m_Dependants.insert(m_Dependants.begin(), m_Dependants.end());
141  }
142  }
143  else
144  {
145  // Defer request to the representative
146  GetRepresentative()->AddDirectAntecedent(antecedent);
147  }
148  }
149 
150  /// Checks if this subgraph is dependent on the given subgraph, either directly or indirectly.
151  bool HasAntecedent(PartialSubgraph* antecedent)
152  {
153  if (m_Parent == nullptr)
154  {
155  antecedent = antecedent->GetRepresentative();
156  // Thanks to keeping this set updated in MergeWith and AddDirectAntecedent, we can do an efficient lookup.
157  return m_Antecedents.count(antecedent) > 0;
158  }
159  else
160  {
161  // Defer request to the representative
162  return GetRepresentative()->HasAntecedent(antecedent);
163  }
164  }
165 
166 private:
167  /// Pointer to the parent node in the tree. If this is null then we are the representative for our merged subgraph.
168  PartialSubgraph* m_Parent;
169  /// The representatives of all the subgraphs which we depend on, either directly or indirectly.
170  std::unordered_set<PartialSubgraph*> m_Antecedents;
171  /// The representatives of all the subgraphs which depend on us, either directly or indirectly.
172  std::unordered_set<PartialSubgraph*> m_Dependants;
173 };
174 
175 /// Intermediate data structure to store information associated with a particular layer.
176 struct LayerSelectionInfo
177 {
178  using LayerInfoContainer = std::map<Layer*, LayerSelectionInfo>;
179  using LayerInfoQueue = std::queue<LayerSelectionInfo*>;
180 
181  LayerSelectionInfo(Layer* layer, const SubgraphViewSelector::LayerSelectorFunction& selector)
182  : m_Layer{layer}
183  , m_Subgraph{nullptr}
184  , m_IsSelected{selector(*layer)}
185  , m_IsProcessed(false)
186  {
187  }
188 
189  bool IsInputLayer() const
190  {
191  return m_Layer->GetType() == armnn::LayerType::Input || m_Layer->GetType() == armnn::LayerType::Constant;
192  }
193 
194  void CollectNonSelectedInputs(LayerSelectionInfo::LayerInfoContainer& layerInfos,
195  SubgraphView::InputSlots& inputSlots)
196  {
197  for (auto&& slot = m_Layer->BeginInputSlots(); slot != m_Layer->EndInputSlots(); ++slot)
198  {
199  OutputSlot* parentLayerOutputSlot = slot->GetConnectedOutputSlot();
200  BOOST_ASSERT_MSG(parentLayerOutputSlot != nullptr, "The input slots must be connected here.");
201  if (parentLayerOutputSlot)
202  {
203  Layer& parentLayer = parentLayerOutputSlot->GetOwningLayer();
204  auto parentInfo = layerInfos.find(&parentLayer);
205  if (parentInfo == layerInfos.end() ||
206  !m_Subgraph->IsMergedWith(parentInfo->second.m_Subgraph.get()))
207  {
208  // Avoid collecting duplicate input slots
209  InputSlot* inputSlot = &(*slot);
210  if (std::find(inputSlots.begin(), inputSlots.end(), inputSlot) == inputSlots.end())
211  {
212  inputSlots.push_back(inputSlot);
213  }
214  }
215  }
216  }
217  }
218 
219  void CollectNonSelectedOutputSlots(LayerSelectionInfo::LayerInfoContainer& layerInfos,
220  SubgraphView::OutputSlots& outputSlots)
221  {
222  for (auto&& slot = m_Layer->BeginOutputSlots(); slot != m_Layer->EndOutputSlots(); ++slot)
223  {
224  for (InputSlot* childLayerInputSlot : slot->GetConnections())
225  {
226  Layer& childLayer = childLayerInputSlot->GetOwningLayer();
227  auto childInfo = layerInfos.find(&childLayer);
228  if (childInfo == layerInfos.end() ||
229  !m_Subgraph->IsMergedWith(childInfo->second.m_Subgraph.get()))
230  {
231  // Avoid collecting duplicate output slots
232  OutputSlot* outputSlot = &(*slot);
233  if (std::find(outputSlots.begin(), outputSlots.end(), outputSlot) == outputSlots.end())
234  {
235  outputSlots.push_back(outputSlot);
236  }
237  }
238  }
239  }
240  }
241 
242  Layer* m_Layer;
243  /// Which subgraph this layer has been assigned to. Only valid once m_IsProcessed is true.
244  /// Two layers with different m_Subgraph pointers may in fact have been merged into the same subgraph -
245  /// see the description of the PartialSubgraph class.
246  std::shared_ptr<PartialSubgraph> m_Subgraph;
249 };
250 
251 } // namespace <anonymous>
252 
255 {
256  SubgraphView subgraph(graph);
257  return SubgraphViewSelector::SelectSubgraphs(subgraph, selector);
258 }
259 
260 
261 template<typename Delegate>
262 void ForEachLayerInput(LayerSelectionInfo::LayerInfoContainer& layerInfos,
263  LayerSelectionInfo& layerInfo,
264  Delegate function)
265 {
266  Layer& layer = *layerInfo.m_Layer;
267 
268  for (auto inputSlot : layer.GetInputSlots())
269  {
270  auto connectedInput = boost::polymorphic_downcast<OutputSlot*>(inputSlot.GetConnection());
271  BOOST_ASSERT_MSG(connectedInput, "Dangling input slot detected.");
272  Layer& inputLayer = connectedInput->GetOwningLayer();
273 
274  auto parentInfo = layerInfos.find(&inputLayer);
275  if (parentInfo != layerInfos.end())
276  {
277  function(parentInfo->second);
278  }
279  }
280 }
281 
282 template<typename Delegate>
283 void ForEachLayerOutput(LayerSelectionInfo::LayerInfoContainer& layerInfos,
284  LayerSelectionInfo& layerInfo,
285  Delegate function)
286 {
287  Layer& layer= *layerInfo.m_Layer;
288 
289  for (auto& outputSlot : layer.GetOutputSlots())
290  {
291  for (auto& output : outputSlot.GetConnections())
292  {
293  Layer& childLayer = output->GetOwningLayer();
294 
295  auto childInfo = layerInfos.find(&childLayer);
296  if (childInfo != layerInfos.end())
297  {
298  function(childInfo->second);
299  }
300  }
301  }
302 }
303 
304 void AssignSplitId(LayerSelectionInfo::LayerInfoContainer& layerInfos, LayerSelectionInfo& layerInfo)
305 {
306  // Check each input to see if we can attach ourselves to any of the subgraphs that have already been assigned.
307  ForEachLayerInput(layerInfos, layerInfo, [&](LayerSelectionInfo& parentInfo)
308  {
309  // We can only attach ourselves to the subgraph from this input if there isn't a cut here.
310  if (layerInfo.m_IsSelected == parentInfo.m_IsSelected)
311  {
312  // We also need to check that merging into this subgraph won't cause a dependency cycle between subgraphs.
313  // This will be the case if the subgraph that we will become part of is already a dependency
314  // of one of the subgraphs that are input to this layer, e.g:
315  //
316  // 0 | The numbers (0, 1) are the subgraph IDs of each layer and we are looking at layer X.
317  // / \ |
318  // 1 0 | We can't merge X into subgraph 0, because the left-hand input already depends on subgraph 0.
319  // \ / | We can however merge X into subgraph 1.
320  // X |
321  //
322  bool dependenciesOk = true;
323  ForEachLayerInput(layerInfos, layerInfo, [&](LayerSelectionInfo& otherParentInfo)
324  {
325  // We call HasAntecedent() ~ n^2 times, where n is the number of inputs to this layer.
326  // Hence it is important that this is efficient - see PartialSubgraph class description.
327  if (otherParentInfo.m_Subgraph->HasAntecedent(parentInfo.m_Subgraph.get()))
328  {
329  dependenciesOk = false;
330  }
331  });
332 
333  if (dependenciesOk)
334  {
335  // Merge into the subgraph of this input. If we have already been merged into another subgraph
336  // (from another input of this layer), then merge both of them together.
337  if (layerInfo.m_Subgraph == nullptr)
338  {
339  layerInfo.m_Subgraph = parentInfo.m_Subgraph;
340  }
341  else
342  {
343  // We call MergeWith() ~ n times, where n is the number of inputs to this layer.
344  // Therefore it does not need to be as performant as HasAntecedent().
345  layerInfo.m_Subgraph->MergeWith(parentInfo.m_Subgraph.get());
346  }
347  }
348  }
349  });
350 
351  // If we weren't able to merge into an existing subgraph then we need to make a new one
352  if (layerInfo.m_Subgraph == nullptr)
353  {
354  layerInfo.m_Subgraph = std::make_shared<PartialSubgraph>();
355  }
356 
357  // Record dependencies of the chosen subgraph based on the inputs of this layer.
358  ForEachLayerInput(layerInfos, layerInfo, [&](LayerSelectionInfo& parentInfo)
359  {
360  // These functions are called ~n times, where n is the number of inputs to this layer.
361  // Therefore it does not need to be as performant as HasAntecedent().
362  if (!layerInfo.m_Subgraph->IsMergedWith(parentInfo.m_Subgraph.get()))
363  {
364  layerInfo.m_Subgraph->AddDirectAntecedent(parentInfo.m_Subgraph.get());
365  }
366  });
367 }
368 
369 bool IsReadyForSplitAssignment(LayerSelectionInfo::LayerInfoContainer& layerInfos, LayerSelectionInfo& layerInfo)
370 {
371  bool ready = true;
372  ForEachLayerInput(layerInfos, layerInfo,
373  [&ready](LayerSelectionInfo& parentInfo)
374  {
375  if (!parentInfo.m_IsProcessed)
376  {
377  ready = false;
378  }
379  });
380  return ready;
381 }
382 
385 {
386  LayerSelectionInfo::LayerInfoContainer layerInfos;
387 
388  LayerSelectionInfo::LayerInfoQueue processQueue;
389  for (auto& layer : subgraph)
390  {
391  auto emplaced = layerInfos.emplace(layer, LayerSelectionInfo{layer, selector});
392  LayerSelectionInfo& layerInfo = emplaced.first->second;
393 
394  // Start with Input type layers
395  if (layerInfo.IsInputLayer())
396  {
397  processQueue.push(&layerInfo);
398  }
399  }
400 
401  const SubgraphView::InputSlots& subgraphInputSlots = subgraph.GetInputSlots();
402  for (auto& inputSlot : subgraphInputSlots)
403  {
404  Layer& layer = inputSlot->GetOwningLayer();
405  auto emplaced = layerInfos.emplace(&layer, LayerSelectionInfo{&layer, selector});
406  LayerSelectionInfo& layerInfo = emplaced.first->second;
407 
408  processQueue.push(&layerInfo);
409  }
410 
411  while (!processQueue.empty())
412  {
413  LayerSelectionInfo& layerInfo = *processQueue.front();
414  processQueue.pop(); // remove front from queue
415 
416  // This layerInfo may have been added to the queue multiple times, so skip if we have already processed it
417  if (!layerInfo.m_IsProcessed)
418  {
419  // Only process this layerInfo if all inputs have been processed
420  if (!IsReadyForSplitAssignment(layerInfos, layerInfo))
421  {
422  // Put back of the process queue if we can't process it just yet
423  processQueue.push(&layerInfo);
424  continue; // Skip to next iteration
425  }
426 
427  // Now we do the processing
428  AssignSplitId(layerInfos, layerInfo);
429 
430  // Queue any child nodes for processing
431  ForEachLayerOutput(layerInfos, layerInfo, [&processQueue](LayerSelectionInfo& childInfo)
432  {
433  processQueue.push(&childInfo);
434  });
435 
436  // We don't need to process this node again
437  layerInfo.m_IsProcessed = true;
438  }
439  }
440 
441  // Collect all selected layers keyed by subgraph representative into a map
442  using SelectionInfoPtrs = std::vector<LayerSelectionInfo*>;
443  std::map<PartialSubgraph*, SelectionInfoPtrs> splitMap;
444  for (auto& info : layerInfos)
445  {
446  if (info.second.m_IsSelected)
447  {
448  auto it = splitMap.find(info.second.m_Subgraph->GetRepresentative());
449  if (it == splitMap.end())
450  {
451  splitMap.insert(
452  std::make_pair(info.second.m_Subgraph->GetRepresentative(), SelectionInfoPtrs{&info.second}));
453  }
454  else
455  {
456  it->second.push_back(&info.second);
457  }
458  }
459  }
460 
461  // Now each entry in splitMap represents a subgraph
462  Subgraphs result;
463  for (auto& splitGraph : splitMap)
464  {
467  SubgraphView::Layers layers;
468  for (auto&& infoPtr : splitGraph.second)
469  {
470  infoPtr->CollectNonSelectedInputs(layerInfos, inputs);
471  infoPtr->CollectNonSelectedOutputSlots(layerInfos, outputs);
472  layers.push_back(infoPtr->m_Layer);
473  }
474  // Create a new sub-graph with the new lists of input/output slots and layer
475  result.emplace_back(std::make_unique<SubgraphView>(std::move(inputs),
476  std::move(outputs),
477  std::move(layers)));
478  }
479 
480  return result;
481 }
482 
483 } // namespace armnn
bool m_IsSelected
void ForEachLayerInput(LayerSelectionInfo::LayerInfoContainer &layerInfos, LayerSelectionInfo &layerInfo, Delegate function)
void AssignSplitId(LayerSelectionInfo::LayerInfoContainer &layerInfos, LayerSelectionInfo &layerInfo)
std::vector< OutputSlot * > OutputSlots
std::function< bool(const Layer &)> LayerSelectorFunction
Copyright (c) 2020 ARM Limited.
void IgnoreUnused(Ts &&...)
const std::vector< InputSlot > & GetInputSlots() const
Definition: Layer.hpp:231
bool IsReadyForSplitAssignment(LayerSelectionInfo::LayerInfoContainer &layerInfos, LayerSelectionInfo &layerInfo)
The SubgraphView class represents a subgraph of a Graph.
Layer * m_Layer
std::vector< SubgraphViewPtr > Subgraphs
const std::vector< OutputSlot > & GetOutputSlots() const
Definition: Layer.hpp:232
std::vector< InputSlot * > InputSlots
bool m_IsProcessed
static Subgraphs SelectSubgraphs(Graph &graph, const LayerSelectorFunction &selector)
Selects subgraphs from a graph based on the selector function and the algorithm.
std::shared_ptr< PartialSubgraph > m_Subgraph
Which subgraph this layer has been assigned to.
std::list< Layer * > Layers
void ForEachLayerOutput(LayerSelectionInfo::LayerInfoContainer &layerInfos, LayerSelectionInfo &layerInfo, Delegate function)