17 #include <unordered_map> 18 #include <unordered_set> 21 #include <boost/assert.hpp> 22 #include <boost/iterator/transform_iterator.hpp> 32 template <
typename LayerType>
35 return boost::polymorphic_downcast<LayerType*>(layer);
38 template <
typename Func>
41 for (
auto it = m_Layers.begin(); it != m_Layers.end(); )
43 auto next = std::next(it);
64 return {
m_Graph.m_Layers.begin(), &(PtrCast<const InputLayer>) };
70 &(PtrCast<const InputLayer>) };
84 &(PtrCast<const OutputLayer>) };
89 return {
m_Graph.m_Layers.end(), &(PtrCast<const OutputLayer>) };
103 *
this = std::move(other);
108 m_InputIds = std::move(other.m_InputIds);
109 m_OutputIds = std::move(other.m_OutputIds);
110 m_LayersInOrder = std::move(other.m_LayersInOrder);
111 m_Views = std::move(other.m_Views);
115 otherLayer->
Reparent(*
this, m_Layers.end());
118 BOOST_ASSERT(other.m_PosInGraphMap.empty());
119 BOOST_ASSERT(other.m_Layers.empty());
137 template <
typename LayerT,
typename... Args>
142 template <
typename LayerT,
typename... Args>
146 template <
typename LayerT,
typename... Args>
154 template <
typename LayerT>
205 m_Views[notifyOnEvent].emplace_back(observable);
209 m_Views[notifyOnEvent].remove(observable);
216 template <
typename LayerT>
217 class LayerInGraphBase;
219 template <
typename LayerT>
233 while ((it != m_Layers.begin()) && ((*std::prev(it))->GetType() ==
LayerType::Output))
243 for (
auto& observable : m_Views[event])
245 observable->Update(graphState);
249 std::unordered_set<LayerBindingId> m_InputIds;
250 std::unordered_set<LayerBindingId> m_OutputIds;
251 std::unordered_map<const Layer*, Iterator> m_PosInGraphMap;
259 mutable bool m_LayersInOrder;
261 std::map<const GraphEvent, std::list<IGraphObservable*>> m_Views;
265 template <
typename LayerT>
266 class Graph::LayerInGraphBase :
public LayerT
269 template <
typename... Args>
270 LayerInGraphBase(
Graph& graph,
Iterator insertBefore, Args&&... args)
271 : LayerT(std::forward<Args>(args)...),
m_Graph(&graph)
273 Insert(*
m_Graph, insertBefore);
280 void Reparent(
Graph& destGraph,
Iterator insertBefore)
override 282 Insert(destGraph, insertBefore);
291 graph.m_PosInGraphMap.emplace(
this, graph.m_Layers.emplace(insertBefore,
this));
297 graph.m_Layers.erase(layerIt);
299 const size_t numErased = graph.m_PosInGraphMap.erase(
this);
300 boost::ignore_unused(numErased);
301 BOOST_ASSERT(numErased == 1);
309 template <
typename LayerT>
310 class Graph::LayerInGraph final :
public LayerInGraphBase<LayerT>
313 template <
typename... Args>
314 LayerInGraph(
Graph& graph, Args&&... args)
315 : LayerInGraphBase<LayerT>(graph,
318 std::forward<Args>(args)...)
321 template <
typename... Args>
322 LayerInGraph(
Graph& graph,
Iterator insertBefore, Args&&... args)
323 : LayerInGraphBase<LayerT>(graph,
325 graph.ForwardToEndOfInputs(graph.RewindToBeginOfOutputs(insertBefore)),
326 std::forward<Args>(args)...)
333 class Graph::LayerInGraph<
InputLayer> final :
public LayerInGraphBase<InputLayer>
336 template <
typename... Args>
341 std::forward<Args>(args)...)
343 const bool isNewId =
m_Graph->m_InputIds.emplace(GetBindingId()).second;
349 template <
typename... Args>
352 : LayerInGraph(graph,
std::forward<Args>(args)...)
357 const size_t numErased =
m_Graph->m_InputIds.erase(GetBindingId());
358 boost::ignore_unused(numErased);
359 BOOST_ASSERT(numErased == 1);
365 class Graph::LayerInGraph<
OutputLayer> final :
public LayerInGraphBase<OutputLayer>
368 template <
typename... Args>
373 std::forward<Args>(args)...)
375 const bool isNewId =
m_Graph->m_OutputIds.emplace(GetBindingId()).second;
383 const size_t numErased =
m_Graph->m_OutputIds.erase(GetBindingId());
384 boost::ignore_unused(numErased);
385 BOOST_ASSERT(numErased == 1);
391 auto it = m_PosInGraphMap.find(&layer);
392 BOOST_ASSERT(it != m_PosInGraphMap.end());
396 template <
typename LayerT,
typename... Args>
399 m_LayersInOrder = m_LayersInOrder &&
401 LayerT*
const layer =
new LayerInGraph<LayerT>(*
this, std::forward<Args>(args)...);
408 template <
typename LayerT,
typename... Args>
413 const Iterator pos = (parentOut !=
nullptr)
416 LayerT*
const layer =
new LayerInGraph<LayerT>(*
this, pos, std::forward<Args>(args)...);
417 insertBefore.
Insert(*layer);
424 template <
typename LayerT,
typename... Args>
430 LayerT*
const layer =
new LayerInGraph<LayerT>(*
this, pos, std::forward<Args>(args)...);
432 BOOST_ASSERT(layer->GetNumInputSlots() == 1);
435 insertAfter.
Connect(layer->GetInputSlot(0));
449 template <
typename LayerT>
452 BOOST_ASSERT(layer !=
nullptr);
size_t GetNumLayers() const
Graph & TopologicalSort()
Sorts layers in topological order and returns this graph.
ConstIteratorOutputs begin() const
size_t GetNumOutputs() const
void DetachObservable(IGraphObservable *const observable, GraphEvent notifyOnEvent)
Iterator end()
Returns iterator pointing to the end of the list. Lowercase for range-based for loops.
OutputLayersAccessor GetOutputLayers() const
void AttachObservable(IGraphObservable *const observable, GraphEvent notifyOnEvent)
boost::transform_iterator< decltype(&PtrCast< const OutputLayer >), Iterator > ConstIteratorOutputs
boost::transform_iterator< decltype(&PtrCast< const InputLayer >), Iterator > ConstIteratorInputs
LayerT * AddLayer(Args &&... args)
Adds a new layer, of type LayerType, to the graph constructed with the arguments passed.
LayerInGraph(Graph &graph, Args &&... args)
Status SerializeToDot(std::ostream &stream)
DataLayout::NCHW DataLayout::NCHW DataLayout::NHWC DataLayout::NHWC true
int Connect(InputSlot &destination)
Iterator begin()
Returns iterator pointing to the beginning of the list. Lowercase for range-based for loops...
bool Remove(const char *path)
void AddCompatibilityLayers(std::map< BackendId, std::unique_ptr< class IBackendInternal >> &backends, TensorHandleFactoryRegistry &registry)
Iterator GetPosInGraph(Layer &layer)
Gets the position of a layer in the graph.
Graph & operator=(Graph &&other)
Iterator::difference_type IteratorDifference
void ForEachLayer(Func func) const
A layer to which user-provided data can be bound (e.g. inputs, outputs).
static LayerType * PtrCast(Layer *const layer)
boost::transform_iterator< decltype(&PtrCast< const Layer >), Iterator > ConstIterator
void SubstituteSubgraph(SubgraphView &subgraph, IConnectableLayer *substituteLayer)
InputLayersAccessor GetInputLayers() const
void MoveAllConnections(OutputSlot &destination)
Moves all connections to another OutputSlot.
ConstIterator cbegin() const
Returns const iterator pointing to the beginning of the list. Lowercase for range-based for loops...
Wrapper class returned by Graph::GetOutputLayers()
LayerList::const_iterator Iterator
std::list< Layer * > LayerList
virtual void Reparent(Graph &dest, std::list< Layer *>::const_iterator iterator)=0
ConstIterator end() const
Returns const iterator pointing to the end of the list. Lowercase for range-based for loops...
Status AllocateDynamicBuffers()
Allocates memory for all tensors under the output tensor handlers of each layer.
Interface for a layer that is connectable to other layers via InputSlots and OutputSlots.
Graph & operator=(const Graph &other)=delete
OutputLayersAccessor(const Graph &graph)
ConstIterator cend() const
Returns const iterator pointing to the end of the list. Lowercase for range-based for loops...
void EraseLayer(Iterator pos)
Deletes the layer at the specified position.
size_t GetNumInputs() const
ConstIterator begin() const
Returns const iterator pointing to the beginning of the list. Lowercase for range-based for loops...
LayerT * InsertNewLayer(InputSlot &insertBefore, Args &&... args)
Layer & GetOwningLayer() const
ConstIteratorOutputs end() const