diff options
author | Teresa Charlin <teresa.charlinreyes@arm.com> | 2022-06-28 16:52:18 +0100 |
---|---|---|
committer | TeresaARM <teresa.charlinreyes@arm.com> | 2022-06-29 10:42:49 +0000 |
commit | 03ee5d8a21688555c4e0a68d8400f4c3e3d844e2 (patch) | |
tree | af30388c80cfa9002980bc41de403e9b4a52f7a2 /src/armnn/Graph.hpp | |
parent | a96489a2fd459bd3d73297fa5fdaef5d13a57a4e (diff) | |
download | armnn-03ee5d8a21688555c4e0a68d8400f4c3e3d844e2.tar.gz |
IVGCVSW-6962 Adding Const layer in the graph immediately after Input
instead of immediately before output
Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com>
Change-Id: I2d89a1efdabfdb4be24a8998a03fe1f502d26183
Diffstat (limited to 'src/armnn/Graph.hpp')
-rw-r--r-- | src/armnn/Graph.hpp | 33 |
1 files changed, 31 insertions, 2 deletions
diff --git a/src/armnn/Graph.hpp b/src/armnn/Graph.hpp index 4623461302..482d9277e8 100644 --- a/src/armnn/Graph.hpp +++ b/src/armnn/Graph.hpp @@ -246,6 +246,15 @@ private: } return it; } + Iterator ForwardToEndOfInputsAndConstants(Iterator it) const + { + while ((it != m_Layers.end()) && + ((*it)->GetType() == LayerType::Input || (*it)->GetType() == LayerType::Constant)) + { + ++it; + } + return it; + } Iterator RewindToBeginOfOutputs(Iterator it) const { @@ -335,7 +344,7 @@ protected: Graph* m_Graph; }; -/// Input/Output layers specialize this template. +/// Input/Output/Constant layers specialize this template. template <typename LayerT> class Graph::LayerInGraph final : public LayerInGraphBase<LayerT> { @@ -352,12 +361,32 @@ public: LayerInGraph(Graph& graph, Iterator insertBefore, Args&&... args) : LayerInGraphBase<LayerT>(graph, // Make sure it's inserted after all inputs and before all outputs. - graph.ForwardToEndOfInputs(graph.RewindToBeginOfOutputs(insertBefore)), + graph.ForwardToEndOfInputsAndConstants(graph.RewindToBeginOfOutputs(insertBefore)), std::forward<Args>(args)...) { } }; +template <> +class Graph::LayerInGraph<ConstantLayer> final : public LayerInGraphBase<ConstantLayer> +{ +public: + template <typename... Args> + LayerInGraph(Graph& graph, Args&&... args) + : LayerInGraphBase<ConstantLayer>(graph, + // Always add to the back of the inputs. + std::next(graph.begin(), IteratorDifference(graph.GetNumInputs())), + std::forward<Args>(args)...) + {} + template <typename... Args> + LayerInGraph(Graph& graph, Iterator, Args&&... args) + // Ignore Iterator argument. Always add to the back of the inputs. + : LayerInGraph(graph, std::forward<Args>(args)...) + {} + ~LayerInGraph() override + {} +}; + /// Inputs add/remove their binding id to m_InputIds in the graph. template <> class Graph::LayerInGraph<InputLayer> final : public LayerInGraphBase<InputLayer> |