diff options
Diffstat (limited to 'src/armnn/Graph.hpp')
-rw-r--r-- | src/armnn/Graph.hpp | 100 |
1 files changed, 62 insertions, 38 deletions
diff --git a/src/armnn/Graph.hpp b/src/armnn/Graph.hpp index 8888034197..34aefbf085 100644 --- a/src/armnn/Graph.hpp +++ b/src/armnn/Graph.hpp @@ -92,6 +92,8 @@ public: Status Print() const; + Status SerializeToDot(std::ostream& stream); + /// Adds a new layer of type LaterType to the graph constructed with the arguments passed. template <typename LayerT, typename... Args> LayerT* AddLayer(Args&&... args); @@ -121,6 +123,11 @@ public: /// Return const iterator pointing to end of list. Lowercase for range-based for loops. ConstIterator end() const { return {m_Layers.end(), &PtrCast<const Layer>}; } + /// Return const iterator pointing to begin of list. Lowercase for range-based for loops. + ConstIterator cbegin() const { return begin(); } + /// Return const iterator pointing to end of list. Lowercase for range-based for loops. + ConstIterator cend() const { return end(); } + /// Sort layers in topological order and return this. Graph& TopologicalSort() { const_cast<const Graph*>(this)->TopologicalSort(); return *this; } const Graph& TopologicalSort() const; @@ -154,13 +161,27 @@ private: template <typename LayerT> class LayerInGraph; + Iterator ForwardToEndOfInputs(Iterator it) const + { + while ((it != m_Layers.end()) && ((*it)->GetType() == LayerType::Input)) + { + ++it; + } + return it; + } + + Iterator RewindToBeginOfOutputs(Iterator it) const + { + while ((it != m_Layers.begin()) && ((*std::prev(it))->GetType() == LayerType::Output)) + { + --it; + } + return it; + } + /// Get the position of a layer in the graph. Iterator GetPosInGraph(Layer& layer); - /// Adds a new layer of type LaterType to the graph constructed with the arguments passed. - template <typename LayerT, typename... Args> - LayerInGraph<LayerT>* AddLayerImpl(Iterator insertBefore, Args&&... args); - std::unordered_set<LayerBindingId> m_InputIds; std::unordered_set<LayerBindingId> m_OutputIds; std::unordered_map<const Layer*, Iterator> m_PosInGraphMap; @@ -197,8 +218,19 @@ class Graph::LayerInGraph final : public LayerInGraphBase<LayerT> { public: template <typename... Args> + LayerInGraph(Graph& graph, Args&&... args) + : LayerInGraphBase<LayerT>(graph, + // Insert at the back of the intermediate layers (before outputs). + std::prev(graph.end(), IteratorDifference(graph.GetNumOutputs())), + std::forward<Args>(args)...) + { + } + template <typename... Args> LayerInGraph(Graph& graph, Iterator insertBefore, Args&&... args) - : LayerInGraphBase<LayerT>(graph, insertBefore, std::forward<Args>(args)...) + : LayerInGraphBase<LayerT>(graph, + // Make sure it's inserted after all inputs and before all outputs. + graph.ForwardToEndOfInputs(graph.RewindToBeginOfOutputs(insertBefore)), + std::forward<Args>(args)...) { } }; @@ -209,8 +241,11 @@ class Graph::LayerInGraph<InputLayer> final : public LayerInGraphBase<InputLayer { public: template <typename... Args> - LayerInGraph(Graph& graph, Iterator insertBefore, Args&&... args) - : LayerInGraphBase<InputLayer>(graph, insertBefore, std::forward<Args>(args)...) + LayerInGraph(Graph& graph, Args&&... args) + : LayerInGraphBase<InputLayer>(graph, + // Always add to the back of the inputs. + std::next(graph.begin(), IteratorDifference(graph.GetNumInputs())), + std::forward<Args>(args)...) { const bool isNewId = m_Graph.m_InputIds.emplace(GetBindingId()).second; if (!isNewId) @@ -218,6 +253,12 @@ public: throw InvalidArgumentException("A layer already exists with the specified id"); } } + template <typename... Args> + LayerInGraph(Graph& graph, Iterator insertBefore, Args&&... args) + // Ignore insertBefore. Always add to the back of the inputs. + : LayerInGraph(graph, std::forward<Args>(args)...) + { + } ~LayerInGraph() override { const size_t numErased = m_Graph.m_InputIds.erase(GetBindingId()); @@ -232,8 +273,11 @@ class Graph::LayerInGraph<OutputLayer> final : public LayerInGraphBase<OutputLay { public: template <typename... Args> - LayerInGraph(Graph& graph, Iterator insertBefore, Args&&... args) - : LayerInGraphBase<OutputLayer>(graph, insertBefore, std::forward<Args>(args)...) + LayerInGraph(Graph& graph, Args&&... args) + : LayerInGraphBase<OutputLayer>(graph, + // Always add to the back of the outputs. + graph.end(), + std::forward<Args>(args)...) { const bool isNewId = m_Graph.m_OutputIds.emplace(GetBindingId()).second; if (!isNewId) @@ -257,42 +301,22 @@ inline Graph::Iterator Graph::GetPosInGraph(Layer& layer) } template <typename LayerT, typename... Args> -inline Graph::LayerInGraph<LayerT>* Graph::AddLayerImpl(Iterator insertBefore, Args&&... args) -{ - return new LayerInGraph<LayerT>(*this, insertBefore, std::forward<Args>(args)...); -} - -/// Inputs are inserted at the front of the list, to keep the order correct if the list is sorted. -/// Outputs are inserted at the back of the list, to keep the order correct if the list is sorted. -/// Other layers are inserted before existing outputs, so the latter remain at the back of the list. -template <typename LayerT, typename... Args> inline LayerT* Graph::AddLayer(Args&&... args) { - switch (LayerEnumOf<LayerT>()) - { - case LayerType::Input: - { - return AddLayerImpl<LayerT>(begin(), std::forward<Args>(args)...); - } - case LayerType::Output: - { - return AddLayerImpl<LayerT>(end(), std::forward<Args>(args)...); - } - default: - { - m_LayersInOrder = false; - const auto pos = std::prev(end(), IteratorDifference(GetNumOutputs())); - return AddLayerImpl<LayerT>(pos, std::forward<Args>(args)...); - } - } + m_LayersInOrder = m_LayersInOrder && + ((LayerEnumOf<LayerT>() == LayerType::Input) || (LayerEnumOf<LayerT>() == LayerType::Output)); + return new LayerInGraph<LayerT>(*this, std::forward<Args>(args)...); } template <typename LayerT, typename... Args> inline LayerT* Graph::InsertNewLayer(InputSlot& insertBefore, Args&&... args) { - // Insert before the child layer so topological order is kept. - const Iterator pos = GetPosInGraph(insertBefore.GetOwningLayer()); - LayerT* const layer = AddLayerImpl<LayerT>(pos, std::forward<Args>(args)...); + // Insert after the parent if any, or before the child otherwise, so topological order is kept. + OutputSlot* parentOut = insertBefore.GetConnectedOutputSlot(); + const Iterator pos = (parentOut != nullptr) + ? std::next(GetPosInGraph(parentOut->GetOwningLayer())) + : GetPosInGraph(insertBefore.GetOwningLayer()); + LayerT* const layer = new LayerInGraph<LayerT>(*this, pos, std::forward<Args>(args)...); insertBefore.Insert(*layer); return layer; } |