aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/SubgraphViewSelector.cpp
blob: 4357ec4381d7405e55e416c92a547cea9298891c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "SubgraphViewSelector.hpp"
#include "Graph.hpp"
#include <boost/assert.hpp>
#include <algorithm>
#include <map>
#include <queue>

namespace armnn
{

namespace
{

struct LayerSelectionInfo
{
    using SplitId = uint32_t;
    using LayerInfoContainer = std::map<Layer*, LayerSelectionInfo>;
    using LayerInfoQueue = std::queue<LayerSelectionInfo*>;
    static constexpr uint32_t InitialSplitId() { return 1; }

    LayerSelectionInfo(Layer* layer, const SubgraphViewSelector::LayerSelectorFunction& selector)
    : m_Layer{layer}
    , m_SplitId{0}
    , m_IsSelected{selector(*layer)}
    , m_IsProcessed(false)
    {
        // fill topology information by storing direct children
        for (auto&& slot = m_Layer->BeginOutputSlots(); slot != m_Layer->EndOutputSlots(); ++slot)
        {
            for (InputSlot* childLayerInputSlot : slot->GetConnections())
            {
                Layer& childLayer = childLayerInputSlot->GetOwningLayer();
                m_DirectChildren.push_back(&childLayer);
            }
        }
    }

    bool IsInputLayer() const
    {
        return m_Layer->GetType() == armnn::LayerType::Input || m_Layer->GetType() == armnn::LayerType::Constant;
    }

    void CollectNonSelectedInputs(LayerSelectionInfo::LayerInfoContainer& layerInfos,
                                  SubgraphView::InputSlots& inputSlots)
    {
        for (auto&& slot = m_Layer->BeginInputSlots(); slot != m_Layer->EndInputSlots(); ++slot)
        {
            OutputSlot* parentLayerOutputSlot = slot->GetConnectedOutputSlot();
            BOOST_ASSERT_MSG(parentLayerOutputSlot != nullptr, "The input slots must be connected here.");
            if (parentLayerOutputSlot)
            {
                Layer& parentLayer = parentLayerOutputSlot->GetOwningLayer();
                auto parentInfo = layerInfos.find(&parentLayer);
                if (parentInfo == layerInfos.end() ||
                        m_SplitId != parentInfo->second.m_SplitId)
                {
                    // Avoid collecting duplicate input slots
                    InputSlot* inputSlot = &(*slot);
                    if (std::find(inputSlots.begin(), inputSlots.end(), inputSlot) == inputSlots.end())
                    {
                        inputSlots.push_back(inputSlot);
                    }
                }
            }
        }
    }

    void CollectNonSelectedOutputSlots(LayerSelectionInfo::LayerInfoContainer& layerInfos,
                                       SubgraphView::OutputSlots& outputSlots)
    {
        for (auto&& slot = m_Layer->BeginOutputSlots(); slot != m_Layer->EndOutputSlots(); ++slot)
        {
            for (InputSlot* childLayerInputSlot : slot->GetConnections())
            {
                Layer& childLayer = childLayerInputSlot->GetOwningLayer();
                auto childInfo = layerInfos.find(&childLayer);
                if (childInfo == layerInfos.end() ||
                        m_SplitId != childInfo->second.m_SplitId)
                {
                    // Avoid collecting duplicate output slots
                    OutputSlot* outputSlot = &(*slot);
                    if (std::find(outputSlots.begin(), outputSlots.end(), outputSlot) == outputSlots.end())
                    {
                        outputSlots.push_back(outputSlot);
                    }
                }
            }
        }
    }

    std::vector<Layer*> m_DirectChildren;
    Layer* m_Layer;
    SplitId m_SplitId;
    bool m_IsSelected;
    bool m_IsProcessed;
};

} // namespace <anonymous>

/// Convenience overload: wraps the whole @p graph in a SubgraphView and delegates to the
/// SubgraphView overload with the same @p selector.
SubgraphViewSelector::Subgraphs
SubgraphViewSelector::SelectSubgraphs(Graph& graph, const LayerSelectorFunction& selector)
{
    SubgraphView wholeGraph(graph);
    return SelectSubgraphs(wholeGraph, selector);
}


/// Invokes @p function for each input connection of @p layerInfo's layer whose producing (parent) layer
/// has an entry in @p layerInfos. Parents without an entry are skipped; a parent that feeds several
/// input slots is visited once per slot.
template<typename Delegate>
void ForEachLayerInput(LayerSelectionInfo::LayerInfoContainer& layerInfos,
                       LayerSelectionInfo& layerInfo,
                       Delegate function)
{
    Layer& layer = *layerInfo.m_Layer;

    // Walk the slots in place through the layer's iterators so no InputSlot object is copied.
    for (auto slot = layer.BeginInputSlots(); slot != layer.EndInputSlots(); ++slot)
    {
        OutputSlot* connectedOutput = slot->GetConnectedOutputSlot();
        BOOST_ASSERT_MSG(connectedOutput, "Dangling input slot detected.");
        if (connectedOutput == nullptr)
        {
            // Assertions may be compiled out; skip dangling slots instead of dereferencing null.
            continue;
        }

        Layer& inputLayer = connectedOutput->GetOwningLayer();

        auto parentInfo = layerInfos.find(&inputLayer);
        if (parentInfo != layerInfos.end())
        {
            function(parentInfo->second);
        }
    }
}

/// Invokes @p function for each consumer of @p layerInfo's layer that has an entry in @p layerInfos,
/// visiting output slots in order and, within a slot, its connections in order. Consumers without an
/// entry are skipped; a consumer connected several times is visited once per connection.
template<typename Delegate>
void ForEachLayerOutput(LayerSelectionInfo::LayerInfoContainer& layerInfos,
                        LayerSelectionInfo& layerInfo,
                        Delegate function)
{
    Layer* producer = layerInfo.m_Layer;

    for (auto slot = producer->BeginOutputSlots(); slot != producer->EndOutputSlots(); ++slot)
    {
        for (InputSlot* consumerSlot : slot->GetConnections())
        {
            const auto found = layerInfos.find(&consumerSlot->GetOwningLayer());
            if (found == layerInfos.end())
            {
                continue;
            }
            function(found->second);
        }
    }
}

/// Computes layerInfo.m_SplitId from the split ids and selection flags of the layer's parents that have
/// an entry in @p layerInfos (the caller processes a layer only after those parents).
///  - If every parent shares this layer's selection state (or there are no such parents), the layer
///    continues the parent branch with the highest split id (0 when there are no parents).
///  - Otherwise a selected layer reuses the highest split id among its selected parents when that id
///    exceeds the lowest parent id; in every other case it opens a new split one past the highest
///    parent id.
void AssignSplitId(LayerSelectionInfo::LayerInfoContainer& layerInfos, LayerSelectionInfo& layerInfo)
{
    using SplitId = LayerSelectionInfo::SplitId;

    const bool layerSelected = layerInfo.m_IsSelected;
    bool crossesSelectionBoundary = false;
    SplitId lowestParentId = std::numeric_limits<SplitId>::max();
    SplitId highestParentId = std::numeric_limits<SplitId>::lowest();
    SplitId highestSelectedParentId = std::numeric_limits<SplitId>::lowest();

    ForEachLayerInput(layerInfos, layerInfo, [&](LayerSelectionInfo& parentInfo)
    {
        const SplitId parentId = parentInfo.m_SplitId;
        lowestParentId = std::min(lowestParentId, parentId);
        highestParentId = std::max(highestParentId, parentId);

        if (parentInfo.m_IsSelected != layerSelected)
        {
            crossesSelectionBoundary = true;
        }
        else if (layerSelected)
        {
            // Both this layer and the parent are selected.
            highestSelectedParentId = std::max(highestSelectedParentId, parentId);
        }
    });

    if (!crossesSelectionBoundary)
    {
        // The branch with the highest splitId represents the shortest path of selected nodes.
        layerInfo.m_SplitId = highestParentId;
        return;
    }

    // Creating a new split can be overly aggressive. Any selected parent id above the minimum comes from a
    // shorter branch, so it is not the split containing the original fork; continuing it avoids the
    // execution dependency. Otherwise open a fresh split.
    layerInfo.m_SplitId = (highestSelectedParentId > lowestParentId)
                              ? highestSelectedParentId
                              : static_cast<SplitId>(highestParentId + 1);
}

/// Returns true when every parent of @p layerInfo that has an entry in @p layerInfos is already
/// processed. A layer with no such parents is always ready.
bool IsReadyForSplitAssignment(LayerSelectionInfo::LayerInfoContainer& layerInfos, LayerSelectionInfo& layerInfo)
{
    bool allParentsProcessed = true;
    ForEachLayerInput(layerInfos, layerInfo, [&allParentsProcessed](const LayerSelectionInfo& parentInfo)
    {
        allParentsProcessed = allParentsProcessed && parentInfo.m_IsProcessed;
    });
    return allParentsProcessed;
}

/// Partitions the layers of @p subgraph into subgraphs of selected layers.
/// Every layer is evaluated with @p selector. Split ids are then assigned in topological order, starting
/// from Input/Constant layers and from the owners of the subgraph's input slots. Each distinct split id
/// carried by selected layers becomes one SubgraphView; results are ordered by split id.
SubgraphViewSelector::Subgraphs
SubgraphViewSelector::SelectSubgraphs(SubgraphView& subgraph, const LayerSelectorFunction& selector)
{
    LayerSelectionInfo::LayerInfoContainer layerInfos;

    // Returns the info for `layer`, creating it on first use. Looking the layer up before constructing a
    // LayerSelectionInfo keeps the selector from running again for layers that are already known.
    // std::map::emplace may build the element even when the key is already present.
    auto findOrCreateInfo = [&layerInfos, &selector](Layer* layer) -> LayerSelectionInfo&
    {
        auto infoIt = layerInfos.find(layer);
        if (infoIt == layerInfos.end())
        {
            infoIt = layerInfos.emplace(layer, LayerSelectionInfo{layer, selector}).first;
        }
        return infoIt->second;
    };

    LayerSelectionInfo::LayerInfoQueue processQueue;
    for (auto& layer : subgraph)
    {
        LayerSelectionInfo& layerInfo = findOrCreateInfo(layer);

        // Start with Input type layers
        if (layerInfo.IsInputLayer())
        {
            processQueue.push(&layerInfo);
        }
    }

    // Also start from the layers owning the subgraph's input slots (normally already in layerInfos).
    for (auto& inputSlot : subgraph.GetInputSlots())
    {
        processQueue.push(&findOrCreateInfo(&inputSlot->GetOwningLayer()));
    }

    // Assign split ids so that a layer is only handled after all of its known parents.
    // NOTE(review): a layer whose parent is in layerInfos but never gets queued is re-queued forever, so this
    // loop relies on the subgraph being well formed (every entry point is an Input/Constant layer or owns a
    // subgraph input slot).
    while (!processQueue.empty())
    {
        LayerSelectionInfo& layerInfo = *processQueue.front();
        processQueue.pop();

        // A layer can be queued several times (once per parent connection, plus re-queues): skip if done.
        if (layerInfo.m_IsProcessed)
        {
            continue;
        }

        if (!IsReadyForSplitAssignment(layerInfos, layerInfo))
        {
            // Put back of the process queue if we can't process it just yet
            processQueue.push(&layerInfo);
            continue;
        }

        AssignSplitId(layerInfos, layerInfo);

        // Queue any child nodes for processing
        ForEachLayerOutput(layerInfos, layerInfo, [&processQueue](LayerSelectionInfo& childInfo)
        {
            processQueue.push(&childInfo);
        });

        // We don't need to process this node again
        layerInfo.m_IsProcessed = true;
    }

    // Group the selected layers by split id; std::map keeps the splits ordered by id.
    using SelectionInfoPtrs = std::vector<LayerSelectionInfo*>;
    std::map<LayerSelectionInfo::SplitId, SelectionInfoPtrs> splitMap;
    for (auto& info : layerInfos)
    {
        if (info.second.m_IsSelected)
        {
            splitMap[info.second.m_SplitId].push_back(&info.second);
        }
    }

    // Each split becomes one subgraph. Every entry holds at least one layer, because entries are only
    // created by the push_back above.
    Subgraphs result;
    for (auto& split : splitMap)
    {
        SubgraphView::InputSlots inputs;
        SubgraphView::OutputSlots outputs;
        SubgraphView::Layers layers;
        for (LayerSelectionInfo* infoPtr : split.second)
        {
            infoPtr->CollectNonSelectedInputs(layerInfos, inputs);
            infoPtr->CollectNonSelectedOutputSlots(layerInfos, outputs);
            layers.push_back(infoPtr->m_Layer);
        }
        // Create a new sub-graph with the new lists of input/output slots and layer
        result.emplace_back(std::make_unique<SubgraphView>(std::move(inputs),
                                                           std::move(outputs),
                                                           std::move(layers)));
    }

    return result;
}

} // namespace armnn