ArmNN
 21.02
Optimizer Class Reference

#include <Optimizer.hpp>

Public Types

using OptimizationPtr = std::unique_ptr< Optimization >
 
using Optimizations = std::vector< OptimizationPtr >
 

Static Public Member Functions

static void Pass (Graph &graph, const Optimizations &optimizations)
 

Detailed Description

Definition at line 14 of file Optimizer.hpp.

Member Typedef Documentation

◆ OptimizationPtr

using OptimizationPtr = std::unique_ptr<Optimization>

Definition at line 17 of file Optimizer.hpp.

◆ Optimizations

using Optimizations = std::vector<OptimizationPtr>

Definition at line 18 of file Optimizer.hpp.

Member Function Documentation

◆ Pass()

void Pass ( Graph & graph,
const Optimizations & optimizations
)
static

Definition at line 16 of file Optimizer.cpp.

References ARMNN_ASSERT, Graph::begin(), GraphObservable< ObservedType >::Clear(), Graph::end(), Graph::EraseLayer(), Graph::GetPosInGraph(), and Graph::TopologicalSort().

Referenced by AddBroadcastReshapeLayerOptimizerTest(), BOOST_AUTO_TEST_CASE(), and armnn::Optimize().

17 {
18  // Create observables to observe changes to the graph
19  AddedLayerObservable addedLayerObservable(graph);
20  ErasedLayerNamesObservable erasedLayerNamesObservable(graph);
21 
22  bool graphNeedsSorting = false;
23  auto it = graph.TopologicalSort().end();
24 
25  // Calls TopologicalSort() for every iteration to re-order the list in case layers were added/removed.
26  while (it != graph.TopologicalSort().begin())
27  {
28  --it;
29  for (auto&& optimization : optimizations)
30  {
31  ARMNN_ASSERT(*it);
32  optimization->Run(graph, **it);
33 
34  if ((*it)->IsOutputUnconnected())
35  {
36  auto next = std::next(graph.GetPosInGraph(**it));
37  graph.EraseLayer(it);
38  it = next;
39  graphNeedsSorting = true;
40  }
41 
42  // Add the names of erased layers as related layers to the new added layers
43  for (auto& erasedLayerName : erasedLayerNamesObservable)
44  {
45  for (auto& addedLayer : addedLayerObservable)
46  {
47  addedLayer->AddRelatedLayerName(erasedLayerName);
48  }
49  }
50 
51  erasedLayerNamesObservable.Clear();
52  addedLayerObservable.Clear();
53 
54  if (graphNeedsSorting)
55  {
56  graphNeedsSorting = false;
57  break;
58  }
59  }
60  }
61 }
#define ARMNN_ASSERT(COND)
Definition: Assert.hpp:14

The documentation for this class was generated from the following files: