aboutsummaryrefslogtreecommitdiff
path: root/include/armnn/backends/SubgraphView.hpp
blob: 33593319cfc5cbdda1c32b1c60f6e6948e06bb7c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <Layer.hpp>
#include <Graph.hpp>

#include <vector>
#include <list>

namespace armnn
{

///
/// The SubgraphView class represents a subgraph of a Graph.
/// The data it holds points to data held by layers of the Graph, so
/// the contents of the SubgraphView become invalid when the Layers are destroyed
/// or changed.
///
class SubgraphView final
{
public:
    template <typename Func>
    void ForEachLayer(Func func) const
    {
        for (auto it = m_Layers.begin(); it != m_Layers.end(); )
        {
             auto next = std::next(it);
             func(*it);
             it = next;
        }
    }

    template <typename Func>
    void ForEachIConnectableLayer(Func func) const
    {
        for (auto it = m_IConnectableLayers.begin(); it != m_IConnectableLayers.end(); )
        {
             auto next = std::next(it);
             func(*it);
             it = next;
        }
    }

    using SubgraphViewPtr = std::unique_ptr<SubgraphView>;
    using InputSlots = std::vector<InputSlot*>;
    using IInputSlots = std::vector<IInputSlot*>;
    using OutputSlots = std::vector<OutputSlot*>;
    using IOutputSlots = std::vector<IOutputSlot*>;
    using Layers = std::list<Layer*>;
    using IConnectableLayers = std::list<IConnectableLayer*>;
    using Iterator = Layers::iterator;
    using IConnectableLayerIterator = IConnectableLayers::iterator;
    using ConstIterator = Layers::const_iterator;
    using ConstIConnectableIterator = IConnectableLayers::const_iterator;

    /// Constructs a sub-graph from the entire given graph.
    explicit SubgraphView(Graph& graph);

    /// Constructs a sub-graph with the given arguments.
    ARMNN_DEPRECATED_MSG_REMOVAL_DATE("This function has been deprecated, please use constructor with arguments: "
                                      "IConnectableLayers, IInputSlots and IOutputSlots", "22.11")
    SubgraphView(InputSlots&& inputs, OutputSlots&& outputs, Layers&& layers);

    /// Constructs a sub-graph with the given arguments.
    SubgraphView(IConnectableLayers&& layers, IInputSlots&& inputs, IOutputSlots&& outputs);

    /// Copy-constructor.
    SubgraphView(const SubgraphView& subgraph);

    /// Move-constructor.
    SubgraphView(SubgraphView&& subgraph);

    /// Constructs a sub-graph with only the given layer.
    SubgraphView(IConnectableLayer* layer);

    /// Move-assignment operator.
    SubgraphView& operator=(SubgraphView&& other);

    ARMNN_DEPRECATED_MSG_REMOVAL_DATE("This function has been deprecated, please use GetIInputSlots() returning"
                                      " public IInputSlots", "22.11")
    const InputSlots& GetInputSlots() const;
    const IInputSlots& GetIInputSlots() const;

    ARMNN_DEPRECATED_MSG_REMOVAL_DATE("This function has been deprecated, please use GetIOutputSlots() returning"
                                      " public IOutputSlots", "22.11")
    const OutputSlots& GetOutputSlots() const;
    const IOutputSlots& GetIOutputSlots() const;

    ARMNN_DEPRECATED_MSG_REMOVAL_DATE("This function has been deprecated, please use GetIConnectableLayers() "
                                      "returning public IConnectableLayers", "22.11")
    const Layers& GetLayers() const;
    const IConnectableLayers& GetIConnectableLayers() const;

    ARMNN_DEPRECATED_MSG_REMOVAL_DATE("This function has been deprecated, please use GetIInputSlot() returning public "
                                      "IInputSlot", "22.11")
    const InputSlot* GetInputSlot(unsigned int index) const;
    const IInputSlot* GetIInputSlot(unsigned int index) const;
    ARMNN_DEPRECATED_MSG_REMOVAL_DATE("This function has been deprecated, please use GetIInputSlot() returning public "
                                      "IInputSlot", "22.11")
    InputSlot* GetInputSlot(unsigned int index);
    IInputSlot* GetIInputSlot(unsigned int index);

    ARMNN_DEPRECATED_MSG_REMOVAL_DATE("This function has been deprecated, please use GetIOutputSlot() returning"
                                      " public IOutputSlot", "22.11")
    const OutputSlot* GetOutputSlot(unsigned int index) const;
    const IOutputSlot* GetIOutputSlot(unsigned int index) const;
    ARMNN_DEPRECATED_MSG_REMOVAL_DATE("This function has been deprecated, please use GetIOutputSlot() returning"
                                      " public IOutputSlot", "22.11")
    OutputSlot* GetOutputSlot(unsigned int index);
    IOutputSlot* GetIOutputSlot(unsigned int index);

    unsigned int GetNumInputSlots() const;
    unsigned int GetNumOutputSlots() const;

    ARMNN_DEPRECATED_MSG_CHANGE_DATE("This function is deprecated and will be changed to return an "
                                     "IConnectableLayerIterator, until that occurs in 23.02; please use "
                                     "beginIConnectable() returning public IConnectableLayerIterator", "23.02")
    Iterator begin();
    IConnectableLayerIterator beginIConnectable();
    ARMNN_DEPRECATED_MSG_CHANGE_DATE("This function is deprecated and will be changed to return an "
                                     "IConnectableLayerIterator, until that occurs in 23.02; please use "
                                     "endIConnectable() returning public IConnectableLayerIterator", "23.02")
    Iterator end();
    IConnectableLayerIterator endIConnectable();

    ARMNN_DEPRECATED_MSG_CHANGE_DATE("This function is deprecated and will be changed to return an "
                                     "ConstIConnectableIterator, until that occurs in 23.02; please use "
                                     "beginIConnectable() returning public ConstIConnectableIterator", "23.02")
    ConstIterator begin() const;
    ConstIConnectableIterator beginIConnectable() const;
    ARMNN_DEPRECATED_MSG_CHANGE_DATE("This function is deprecated and will be changed to return an "
                                     "ConstIConnectableIterator, until that occurs in 23.02; please use "
                                     "endIConnectable() returning public ConstIConnectableIterator", "23.02")
    ConstIterator end() const;
    ConstIConnectableIterator endIConnectable() const;

    ARMNN_DEPRECATED_MSG_CHANGE_DATE("This function is deprecated and will be changed to return an "
                                     "ConstIConnectableIterator, until that occurs in 23.02; please use "
                                     "cbeginIConnectable() returning public ConstIConnectableIterator", "23.02")
    ConstIterator cbegin() const;
    ConstIConnectableIterator cbeginIConnectable() const;
    ARMNN_DEPRECATED_MSG_CHANGE_DATE("This function is deprecated and will be changed to return an "
                                     "ConstIConnectableIterator, until that occurs in 23.02; please use "
                                     "cendIConnectable() returning public ConstIConnectableIterator", "23.02")
    ConstIterator cend() const;
    ConstIConnectableIterator cendIConnectable() const;

    void Clear();

private:
    void CheckSubgraph();

    /// Arrange the order of layers topologically so that nodes can be visited in valid order
    void ArrangeBySortOrder();

    /// The list of pointers to the input slots of the parent graph.
    InputSlots m_InputSlots;
    IInputSlots m_IInputSlots;

    /// The list of pointers to the output slots of the parent graph.
    OutputSlots m_OutputSlots;
    IOutputSlots m_IOutputSlots;

    /// The list of pointers to the layers of the parent graph.
    Layers m_Layers;
    IConnectableLayers m_IConnectableLayers;
};
} // namespace armnn