aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/SubgraphView.hpp
blob: 1b9c308daa209cd56da2eba1d237d35c352cbb50 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "Layer.hpp"
#include "Graph.hpp"

#include <iterator>
#include <list>
#include <memory>
#include <vector>

namespace armnn
{

///
/// The SubgraphView class represents a subgraph of a Graph.
/// The data it holds points to data held by the layers of the Graph, so the
/// contents of the SubgraphView become invalid when those Layers are destroyed
/// or changed. The view stores only raw pointers and does not own the slots or
/// layers it refers to.
///
class SubgraphView final
{
public:
    /// Invokes @p func once for every layer pointer in this subgraph, in list order.
    ///
    /// The iterator to the following element is taken before @p func is called, so the
    /// loop never advances through the iterator that was just handed to @p func.
    /// NOTE(review): this mirrors the usual "erase-safe" traversal shape; since this method
    /// is const it cannot itself erase from m_Layers — confirm whether any caller relies on it.
    ///
    /// @param func Callable invoked with each element of the layer list (a Layer*).
    template <typename Func>
    void ForEachLayer(Func func) const
    {
        for (auto it = m_Layers.begin(); it != m_Layers.end(); )
        {
            auto next = std::next(it);
            func(*it);
            it = next;
        }
    }

    using SubgraphViewPtr = std::unique_ptr<SubgraphView>;
    using InputSlots = std::vector<InputSlot*>;
    using OutputSlots = std::vector<OutputSlot*>;
    using Layers = std::list<Layer*>;
    using Iterator = Layers::iterator;
    using ConstIterator = Layers::const_iterator;

    /// Constructs a sub-graph from the entire given graph.
    explicit SubgraphView(Graph& graph);

    /// Constructs a sub-graph with the given arguments.
    SubgraphView(InputSlots&& inputs, OutputSlots&& outputs, Layers&& layers);

    /// Copy-constructor.
    SubgraphView(const SubgraphView& subgraph);

    /// Move-constructor.
    SubgraphView(SubgraphView&& subgraph);

    /// Constructs a sub-graph with only the given layer.
    SubgraphView(IConnectableLayer* layer);

    /// Move-assignment operator.
    SubgraphView& operator=(SubgraphView&& other);

    /// Copy-assignment operator.
    /// Declared explicitly because the user-declared move operations above would
    /// otherwise cause it to be implicitly deleted, leaving the type copy-constructible
    /// but not copy-assignable. A member-wise copy duplicates the slot and layer
    /// pointer lists; the pointed-to objects are not owned by the view, so nothing
    /// else needs to be copied.
    SubgraphView& operator=(const SubgraphView& other) = default;

    /// Accessors for the slot and layer lists stored in this view (defined out of line).
    const InputSlots& GetInputSlots() const;
    const OutputSlots& GetOutputSlots() const;
    const Layers& GetLayers() const;

    /// Index-based input-slot lookup (defined out of line; see the implementation for
    /// how an out-of-range @p index is handled).
    const InputSlot* GetInputSlot(unsigned int index) const;
    InputSlot* GetInputSlot(unsigned int index);

    /// Index-based output-slot lookup (defined out of line; see the implementation for
    /// how an out-of-range @p index is handled).
    const OutputSlot* GetOutputSlot(unsigned int index) const;
    OutputSlot* GetOutputSlot(unsigned int index);

    /// Number of input / output slots of this view (defined out of line).
    unsigned int GetNumInputSlots() const;
    unsigned int GetNumOutputSlots() const;

    /// Iteration over the layer list, so the view can be used in range-based for loops.
    Iterator begin();
    Iterator end();

    ConstIterator begin() const;
    ConstIterator end() const;

    ConstIterator cbegin() const;
    ConstIterator cend() const;

    /// Empties the view (defined out of line).
    void Clear();

private:
    /// Consistency check on the stored slots/layers (defined out of line).
    void CheckSubgraph();

    /// The list of pointers to the input slots of the parent graph.
    InputSlots m_InputSlots;

    /// The list of pointers to the output slots of the parent graph.
    OutputSlots m_OutputSlots;

    /// The list of pointers to the layers of the parent graph.
    Layers m_Layers;
};

///
/// Old SubGraph definition kept for backward compatibility only.
/// SubGraph is a plain alias of SubgraphView, so existing code written against the
/// old name keeps compiling unchanged. New code should use SubgraphView directly.
/// NOTE(review): ARMNN_DEPRECATED_MSG is defined elsewhere; presumably it expands to a
/// deprecation attribute so that uses of SubGraph emit the given message — confirm.
///
using SubGraph ARMNN_DEPRECATED_MSG("SubGraph is deprecated, use SubgraphView instead") = SubgraphView;

} // namespace armnn