aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/Graph.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/Graph.hpp')
-rw-r--r--src/armnn/Graph.hpp33
1 files changed, 31 insertions, 2 deletions
diff --git a/src/armnn/Graph.hpp b/src/armnn/Graph.hpp
index 4623461302..482d9277e8 100644
--- a/src/armnn/Graph.hpp
+++ b/src/armnn/Graph.hpp
@@ -246,6 +246,15 @@ private:
}
return it;
}
+ Iterator ForwardToEndOfInputsAndConstants(Iterator it) const
+ {
+ while ((it != m_Layers.end()) &&
+ ((*it)->GetType() == LayerType::Input || (*it)->GetType() == LayerType::Constant))
+ {
+ ++it;
+ }
+ return it;
+ }
Iterator RewindToBeginOfOutputs(Iterator it) const
{
@@ -335,7 +344,7 @@ protected:
Graph* m_Graph;
};
-/// Input/Output layers specialize this template.
+/// Input/Output/Constant layers specialize this template.
template <typename LayerT>
class Graph::LayerInGraph final : public LayerInGraphBase<LayerT>
{
@@ -352,12 +361,32 @@ public:
LayerInGraph(Graph& graph, Iterator insertBefore, Args&&... args)
: LayerInGraphBase<LayerT>(graph,
// Make sure it's inserted after all inputs and before all outputs.
- graph.ForwardToEndOfInputs(graph.RewindToBeginOfOutputs(insertBefore)),
+ graph.ForwardToEndOfInputsAndConstants(graph.RewindToBeginOfOutputs(insertBefore)),
std::forward<Args>(args)...)
{
}
};
+template <>
+class Graph::LayerInGraph<ConstantLayer> final : public LayerInGraphBase<ConstantLayer>
+{
+public:
+ template <typename... Args>
+ LayerInGraph(Graph& graph, Args&&... args)
+ : LayerInGraphBase<ConstantLayer>(graph,
+ // Always add to the back of the inputs.
+ std::next(graph.begin(), IteratorDifference(graph.GetNumInputs())),
+ std::forward<Args>(args)...)
+ {}
+ template <typename... Args>
+ LayerInGraph(Graph& graph, Iterator, Args&&... args)
+ // Ignore Iterator argument. Always add to the back of the inputs.
+ : LayerInGraph(graph, std::forward<Args>(args)...)
+ {}
+ ~LayerInGraph() override
+ {}
+};
+
/// Inputs add/remove their binding id to m_InputIds in the graph.
template <>
class Graph::LayerInGraph<InputLayer> final : public LayerInGraphBase<InputLayer>