diff options
Diffstat (limited to 'src/armnn/Graph.hpp')
-rw-r--r-- | src/armnn/Graph.hpp | 33 |
1 files changed, 31 insertions, 2 deletions
diff --git a/src/armnn/Graph.hpp b/src/armnn/Graph.hpp index 4623461302..482d9277e8 100644 --- a/src/armnn/Graph.hpp +++ b/src/armnn/Graph.hpp @@ -246,6 +246,15 @@ private: } return it; } + Iterator ForwardToEndOfInputsAndConstants(Iterator it) const + { + while ((it != m_Layers.end()) && + ((*it)->GetType() == LayerType::Input || (*it)->GetType() == LayerType::Constant)) + { + ++it; + } + return it; + } Iterator RewindToBeginOfOutputs(Iterator it) const { @@ -335,7 +344,7 @@ protected: Graph* m_Graph; }; -/// Input/Output layers specialize this template. +/// Input/Output/Constant layers specialize this template. template <typename LayerT> class Graph::LayerInGraph final : public LayerInGraphBase<LayerT> { @@ -352,12 +361,32 @@ public: LayerInGraph(Graph& graph, Iterator insertBefore, Args&&... args) : LayerInGraphBase<LayerT>(graph, // Make sure it's inserted after all inputs and before all outputs. - graph.ForwardToEndOfInputs(graph.RewindToBeginOfOutputs(insertBefore)), + graph.ForwardToEndOfInputsAndConstants(graph.RewindToBeginOfOutputs(insertBefore)), std::forward<Args>(args)...) { } }; +template <> +class Graph::LayerInGraph<ConstantLayer> final : public LayerInGraphBase<ConstantLayer> +{ +public: + template <typename... Args> + LayerInGraph(Graph& graph, Args&&... args) + : LayerInGraphBase<ConstantLayer>(graph, + // Always add to the back of the inputs. + std::next(graph.begin(), IteratorDifference(graph.GetNumInputs())), + std::forward<Args>(args)...) + {} + template <typename... Args> + LayerInGraph(Graph& graph, Iterator, Args&&... args) + // Ignore Iterator argument. Always add to the back of the inputs. + : LayerInGraph(graph, std::forward<Args>(args)...) + {} + ~LayerInGraph() override + {} +}; + /// Inputs add/remove their binding id to m_InputIds in the graph. template <> class Graph::LayerInGraph<InputLayer> final : public LayerInGraphBase<InputLayer> |