ArmNN
 21.08
Optimization.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "Graph.hpp"
8 #include "LayersFwd.hpp"
9 
11 
12 namespace armnn
13 {
14 
// NOTE(review): the class-head line for this block (doxygen line 15) is missing from this
// extraction; presumably it is `class Optimization`, the base referred to as
// `armnn::Optimization` elsewhere in this file — confirm against the real header.
//
// Abstract interface for a graph optimization: concrete optimizations override Run().
16 {
17 public:
18  Optimization() = default;
// Virtual, so a concrete optimization can be destroyed through an Optimization pointer/reference.
19  virtual ~Optimization() = default;
// Applies the optimization to layer `base` in the context of `graph`. Pure virtual: every
// concrete optimization must provide it. Declared const, so it cannot modify the optimization
// object's own (non-mutable) state.
20  virtual void Run(Graph& graph, Layer& base) const = 0;
// Currently empty: no protected members are declared.
21 protected:
22 };
23 
24 // Wrappers
25 // The implementations of the following wrappers make use of the CRTP C++ idiom
26 // (curiously recurring template pattern).
27 // For details, see https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern
28 
29 /// Wrapper Optimization base class that calls Wrapped::Run() for every layer of type BaseType.
30 /// - Wrapped class mustn't remove the base layer. The optimizer will remove it if left unconnected
31 /// after applying each optimization.
32 template <typename BaseType, typename Wrapped>
33 class OptimizeForTypeImpl : public armnn::Optimization, public Wrapped
34 {
35 public:
36  using Wrapped::Wrapped;
37 
38  void Run(Graph& graph, Layer& base) const override
39  {
40  if (base.GetType() == LayerEnumOf<BaseType>())
41  {
42  Wrapped::Run(graph, *PolymorphicDowncast<BaseType*>(&base));
43  }
44  }
45 
46 protected:
47  ~OptimizeForTypeImpl() = default;
48 };
49 
50 /// Specialization that calls Wrapped::Run() for any layer type.
51 template <typename Wrapped>
52 class OptimizeForTypeImpl<Layer, Wrapped> : public armnn::Optimization, public Wrapped
53 {
54 public:
55  using Wrapped::Wrapped;
56 
57  void Run(Graph& graph, Layer& base) const override
58  {
59  Wrapped::Run(graph, base);
60  }
61 
62 protected:
63  ~OptimizeForTypeImpl() = default;
64 };
65 
/// Concrete, `final` optimization built on OptimizeForTypeImpl<BaseType, Wrapped>: Wrapped::Run()
/// is applied to each layer whose type is BaseType (or to every layer when BaseType is Layer).
// NOTE(review): doxygen line 70 (the only line of the public section) is missing from this
// extraction; presumably it re-exposes the OptimizeForTypeImpl constructors via a
// using-declaration — confirm against the real header.
66 template <typename BaseType, typename Wrapped>
67 class OptimizeForType final : public OptimizeForTypeImpl<BaseType, Wrapped>
68 {
69 public:
71 };
72 
73 /// Wrapper Optimization class that calls Wrapped::Run for every connection BaseType -> ChildType.
74 /// - Wrapped class mustn't remove the base layer. The optimizer will remove it if left unconnected
75 /// after applying each optimization.
76 /// - Wrapped class mustn't affect existing connections in the same output. It might add new ones.
77 /// - Children layers are removed if left unconnected after applying the wrapped optimization.
78 template <typename BaseType, typename ChildType, typename Wrapped>
79 class OptimizeForConnectionImpl : public Wrapped
80 {
81 public:
82  using Wrapped::Wrapped;
83 
84  void Run(Graph& graph, BaseType& base) const
85  {
86  for (auto output = base.BeginOutputSlots(); output != base.EndOutputSlots(); ++output)
87  {
88  for (auto&& childInput : output->GetConnections())
89  {
90  if (childInput->GetOwningLayer().GetType() == LayerEnumOf<ChildType>())
91  {
92  Wrapped::Run(graph, *childInput);
93  }
94  }
95 
96  // Removes unconnected children.
97  for (unsigned int i = 0; i < output->GetNumConnections();)
98  {
99  Layer* child = &output->GetConnection(i)->GetOwningLayer();
100 
101  if (child->IsOutputUnconnected())
102  {
103  graph.EraseLayer(child);
104  }
105  else
106  {
107  ++i;
108  }
109  }
110  }
111  }
112 
113 protected:
114  ~OptimizeForConnectionImpl() = default;
115 };
116 
/// Concrete optimization that applies OptimizeForConnectionImpl<BaseType, ChildType, Wrapped> to
/// each layer of type BaseType: OptimizeForTypeImpl filters layers by type, and the connection
/// wrapper then calls Wrapped::Run() for each BaseType -> ChildType connection.
// NOTE(review): doxygen lines 118 (class head) and 122 (public member) are missing from this
// extraction — confirm against the real header before relying on this block.
117 template <typename BaseType, typename ChildType, typename Wrapped>
119  : public OptimizeForTypeImpl<BaseType, OptimizeForConnectionImpl<BaseType, ChildType, Wrapped>>
120 {
121 public:
123 };
124 
125 /// Wrapper Optimization class that calls Wrapped::Run for the connection BaseType -> ChildType,
///   but only on output slots of BaseType that have exactly one connection (exclusive connections);
///   output slots with zero or several consumers are skipped entirely.
126 /// - Wrapped class mustn't remove the base layer. The optimizer will remove it if left unconnected
127 /// after applying each optimization.
128 /// - Wrapped class mustn't affect existing connections in the same output. It might add new ones.
129 /// - Child layers are removed if left unconnected after applying the wrapped optimization.
130 template <typename BaseType, typename ChildType, typename Wrapped>
// NOTE(review): the class-head line (doxygen line 131) is missing from this extraction;
// presumably it declares this Impl as deriving from Wrapped (cf. `using Wrapped::Wrapped;`)
// — confirm against the real header.
132 {
133 public:
134  using Wrapped::Wrapped;
135 
// For each output slot of `base` that has exactly one connection: invokes Wrapped::Run() on that
// connection if its owning layer is a ChildType, then erases any child on the slot whose outputs
// are left unconnected.
136  void Run(Graph& graph, BaseType& base) const
137  {
138  for (auto output = base.BeginOutputSlots(); output != base.EndOutputSlots(); ++output)
139  {
// Only exclusive connections (a single consumer on this output slot) are considered.
140  if (output->GetNumConnections() == 1)
141  {
// NOTE(review): the contract above allows Wrapped::Run() to add connections to this same
// output slot. If GetConnections() returns a reference to the underlying container, appending
// to it during this range-for would invalidate the loop's iterators — consider iterating over
// a copy of the connection list. Confirm against OutputSlot's implementation.
142  for (auto&& childInput : output->GetConnections())
143  {
144  if (childInput->GetOwningLayer().GetType() == LayerEnumOf<ChildType>())
145  {
146  Wrapped::Run(graph, *childInput);
147  }
148  }
149 
150  // Removes unconnected children.
// The index only advances when the child is kept: erasing the child is expected to drop its
// connection from this slot, shifting the next connection into position i.
151  for (unsigned int i = 0; i < output->GetNumConnections();)
152  {
153  Layer* child = &output->GetConnection(i)->GetOwningLayer();
154 
155  if (child->IsOutputUnconnected())
156  {
157  graph.EraseLayer(child);
158  }
159  else
160  {
161  ++i;
162  }
163  }
164  }
165  }
166  }
167 
168 protected:
// NOTE(review): doxygen line 169 is missing from this extraction; presumably a protected
// defaulted destructor, mirroring OptimizeForConnectionImpl — confirm.
170 };
171 
/// Concrete optimization that applies OptimizeForExclusiveConnectionImpl<BaseType, ChildType,
/// Wrapped> to each layer of type BaseType: OptimizeForTypeImpl filters layers by type, and the
/// exclusive-connection wrapper then calls Wrapped::Run() only on output slots with exactly one
/// connection to a ChildType.
// NOTE(review): doxygen lines 173 (class head) and 178 (continuation of the using-declaration
// on line 177) are missing from this extraction — confirm against the real header.
172 template <typename BaseType, typename ChildType, typename Wrapped>
174  : public OptimizeForTypeImpl<BaseType, OptimizeForExclusiveConnectionImpl<BaseType, ChildType, Wrapped>>
175 {
176 public:
177  using OptimizeForTypeImpl<BaseType,
179 };
180 
181 } // namespace armnn
virtual void Run(Graph &graph, Layer &base) const =0
void Run(Graph &graph, BaseType &base) const
Optimization()=default
void EraseLayer(Iterator pos)
Deletes the layer at the specified position.
Definition: Graph.hpp:449
Copyright (c) 2021 ARM Limited and Contributors.
bool IsOutputUnconnected()
Definition: Layer.hpp:249
Wrapper Optimization class that calls Wrapped::Run for every connection BaseType -> ChildType...
virtual ~Optimization()=default
LayerType GetType() const override
Returns the armnn::LayerType of this layer.
Definition: Layer.hpp:265
void Run(Graph &graph, Layer &base) const override
Wrapper Optimization base class that calls Wrapped::Run() for every layer of type BaseType...
void Run(Graph &graph, BaseType &base) const
void Run(Graph &graph, Layer &base) const override
Wrapper Optimization class that calls Wrapped::Run for every connection BaseType -> ChildType...