aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/Graph.hpp
diff options
context:
space:
mode:
authorTeresa Charlin <teresa.charlinreyes@arm.com>2022-06-28 16:52:18 +0100
committerTeresaARM <teresa.charlinreyes@arm.com>2022-06-29 10:42:49 +0000
commit03ee5d8a21688555c4e0a68d8400f4c3e3d844e2 (patch)
treeaf30388c80cfa9002980bc41de403e9b4a52f7a2 /src/armnn/Graph.hpp
parenta96489a2fd459bd3d73297fa5fdaef5d13a57a4e (diff)
downloadarmnn-03ee5d8a21688555c4e0a68d8400f4c3e3d844e2.tar.gz
IVGCVSW-6962 Adding Const layer in the graph immediately after Input
instead of immediately before output Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com> Change-Id: I2d89a1efdabfdb4be24a8998a03fe1f502d26183
Diffstat (limited to 'src/armnn/Graph.hpp')
-rw-r--r--src/armnn/Graph.hpp33
1 files changed, 31 insertions, 2 deletions
diff --git a/src/armnn/Graph.hpp b/src/armnn/Graph.hpp
index 4623461302..482d9277e8 100644
--- a/src/armnn/Graph.hpp
+++ b/src/armnn/Graph.hpp
@@ -246,6 +246,15 @@ private:
}
return it;
}
+ Iterator ForwardToEndOfInputsAndConstants(Iterator it) const
+ {
+ while ((it != m_Layers.end()) &&
+ ((*it)->GetType() == LayerType::Input || (*it)->GetType() == LayerType::Constant))
+ {
+ ++it;
+ }
+ return it;
+ }
Iterator RewindToBeginOfOutputs(Iterator it) const
{
@@ -335,7 +344,7 @@ protected:
Graph* m_Graph;
};
-/// Input/Output layers specialize this template.
+/// Input/Output/Constant layers specialize this template.
template <typename LayerT>
class Graph::LayerInGraph final : public LayerInGraphBase<LayerT>
{
@@ -352,12 +361,32 @@ public:
LayerInGraph(Graph& graph, Iterator insertBefore, Args&&... args)
: LayerInGraphBase<LayerT>(graph,
// Make sure it's inserted after all inputs and before all outputs.
- graph.ForwardToEndOfInputs(graph.RewindToBeginOfOutputs(insertBefore)),
+ graph.ForwardToEndOfInputsAndConstants(graph.RewindToBeginOfOutputs(insertBefore)),
std::forward<Args>(args)...)
{
}
};
+template <>
+class Graph::LayerInGraph<ConstantLayer> final : public LayerInGraphBase<ConstantLayer>
+{
+public:
+ template <typename... Args>
+ LayerInGraph(Graph& graph, Args&&... args)
+ : LayerInGraphBase<ConstantLayer>(graph,
+ // Always add to the back of the inputs.
+ std::next(graph.begin(), IteratorDifference(graph.GetNumInputs())),
+ std::forward<Args>(args)...)
+ {}
+ template <typename... Args>
+ LayerInGraph(Graph& graph, Iterator, Args&&... args)
+ // Ignore Iterator argument. Always add to the back of the inputs.
+ : LayerInGraph(graph, std::forward<Args>(args)...)
+ {}
+ ~LayerInGraph() override
+ {}
+};
+
/// Inputs add/remove their binding id to m_InputIds in the graph.
template <>
class Graph::LayerInGraph<InputLayer> final : public LayerInGraphBase<InputLayer>