ArmNN
 20.02
Optimization.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "Graph.hpp"
8 #include "LayersFwd.hpp"
9 
10 namespace armnn
11 {
12 
14 {
15 public:
16  Optimization() = default;
17  virtual ~Optimization() = default;
18  virtual void Run(Graph& graph, Layer& base) const = 0;
19 protected:
20 };
21 
22 // Wrappers
23 // The implementations of the following wrappers make use of the CRTP C++ idiom
24 // (curiously recurring template pattern).
25 // For details, see https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern
26 
27 /// Wrapper Optimization base class that calls Wrapped::Run() for every layer of type BaseType.
28 /// - Wrapped class mustn't remove the base layer. The optimizer will remove it if left unconnected
29 /// after applying each optimization.
30 template <typename BaseType, typename Wrapped>
31 class OptimizeForTypeImpl : public armnn::Optimization, public Wrapped
32 {
33 public:
34  using Wrapped::Wrapped;
35 
36  void Run(Graph& graph, Layer& base) const override
37  {
38  if (base.GetType() == LayerEnumOf<BaseType>())
39  {
40  Wrapped::Run(graph, *boost::polymorphic_downcast<BaseType*>(&base));
41  }
42  }
43 
44 protected:
45  ~OptimizeForTypeImpl() = default;
46 };
47 
48 /// Specialization that calls Wrapped::Run() for any layer type.
49 template <typename Wrapped>
50 class OptimizeForTypeImpl<Layer, Wrapped> : public armnn::Optimization, public Wrapped
51 {
52 public:
53  using Wrapped::Wrapped;
54 
55  void Run(Graph& graph, Layer& base) const override
56  {
57  Wrapped::Run(graph, base);
58  }
59 
60 protected:
61  ~OptimizeForTypeImpl() = default;
62 };
63 
64 template <typename BaseType, typename Wrapped>
65 class OptimizeForType final : public OptimizeForTypeImpl<BaseType, Wrapped>
66 {
67 public:
69 };
70 
71 /// Wrapper Optimization class that calls Wrapped::Run for every connection BaseType -> ChildType.
72 /// - Wrapped class mustn't remove the base layer. The optimizer will remove it if left unconnected
73 /// after applying each optimization.
74 /// - Wrapped class mustn't affect existing connections in the same output. It might add new ones.
75 /// - Children layers are removed if left unconnected after applying the wrapped optimization.
76 template <typename BaseType, typename ChildType, typename Wrapped>
77 class OptimizeForConnectionImpl : public Wrapped
78 {
79 public:
80  using Wrapped::Wrapped;
81 
82  void Run(Graph& graph, BaseType& base) const
83  {
84  for (auto output = base.BeginOutputSlots(); output != base.EndOutputSlots(); ++output)
85  {
86  for (auto&& childInput : output->GetConnections())
87  {
88  if (childInput->GetOwningLayer().GetType() == LayerEnumOf<ChildType>())
89  {
90  Wrapped::Run(graph, *childInput);
91  }
92  }
93 
94  // Removes unconnected children.
95  for (unsigned int i = 0; i < output->GetNumConnections();)
96  {
97  Layer* child = &output->GetConnection(i)->GetOwningLayer();
98 
99  if (child->IsOutputUnconnected())
100  {
101  graph.EraseLayer(child);
102  }
103  else
104  {
105  ++i;
106  }
107  }
108  }
109  }
110 
111 protected:
112  ~OptimizeForConnectionImpl() = default;
113 };
114 
115 template <typename BaseType, typename ChildType, typename Wrapped>
117  : public OptimizeForTypeImpl<BaseType, OptimizeForConnectionImpl<BaseType, ChildType, Wrapped>>
118 {
119 public:
121 };
122 
123 } // namespace armnn
virtual void Run(Graph &graph, Layer &base) const =0
Optimization()=default
void EraseLayer(Iterator pos)
Deletes the layer at the specified position.
Definition: Graph.hpp:442
Copyright (c) 2020 ARM Limited.
bool IsOutputUnconnected()
Definition: Layer.hpp:243
Wrapper Optimization class that calls Wrapped::Run for every connection BaseType -> ChildType...
virtual ~Optimization()=default
void Run(Graph &graph, Layer &base) const override
Wrapper Optimization base class that calls Wrapped::Run() for every layer of type BaseType...
void Run(Graph &graph, BaseType &base) const
void Run(Graph &graph, Layer &base) const override
LayerType GetType() const
Definition: Layer.hpp:259