aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/optimizations/Optimization.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/optimizations/Optimization.hpp')
-rw-r--r--src/armnn/optimizations/Optimization.hpp27
1 files changed, 13 insertions, 14 deletions
diff --git a/src/armnn/optimizations/Optimization.hpp b/src/armnn/optimizations/Optimization.hpp
index 89e03ff88d..f81071891b 100644
--- a/src/armnn/optimizations/Optimization.hpp
+++ b/src/armnn/optimizations/Optimization.hpp
@@ -13,7 +13,7 @@ namespace armnn
class Optimization
{
public:
- virtual void Run(Graph& graph, Graph::Iterator& pos) const = 0;
+ virtual void Run(Graph& graph, Layer& base) const = 0;
protected:
~Optimization() = default;
};
@@ -23,22 +23,20 @@ protected:
// (curiously recurring template pattern).
// For details, see https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern
-/// Wrapper Optimization base class that calls Wrapped::Run for every layer of type BaseType.
-/// - Wrapped class mustn't remove the base layer.
-/// - Base layer is removed if left unconnected after applying the wrapped optimization.
+/// Wrapper Optimization base class that calls Wrapped::Run() for every layer of type BaseType.
+/// - Wrapped class mustn't remove the base layer. The optimizer will remove it if left unconnected
+/// after applying each optimization.
template <typename BaseType, typename Wrapped>
class OptimizeForTypeImpl : public armnn::Optimization, public Wrapped
{
public:
using Wrapped::Wrapped;
- void Run(Graph& graph, Graph::Iterator& pos) const override
+ void Run(Graph& graph, Layer& base) const override
{
- Layer* const base = *pos;
-
- if (base->GetType() == LayerEnumOf<BaseType>())
+ if (base.GetType() == LayerEnumOf<BaseType>())
{
- Wrapped::Run(graph, *boost::polymorphic_downcast<BaseType*>(base));
+ Wrapped::Run(graph, *boost::polymorphic_downcast<BaseType*>(&base));
}
}
@@ -46,16 +44,16 @@ protected:
~OptimizeForTypeImpl() = default;
};
-/// Specialization that calls Wrapped::Run for any layer type
+/// Specialization that calls Wrapped::Run() for any layer type
template <typename Wrapped>
class OptimizeForTypeImpl<Layer, Wrapped> : public armnn::Optimization, public Wrapped
{
public:
using Wrapped::Wrapped;
- void Run(Graph& graph, Graph::Iterator& pos) const override
+ void Run(Graph& graph, Layer& base) const override
{
- Wrapped::Run(graph, **pos);
+ Wrapped::Run(graph, base);
}
protected:
@@ -70,9 +68,10 @@ public:
};
/// Wrapper Optimization class that calls Wrapped::Run for every connection BaseType -> ChildType.
-/// - Wrapped class mustn't remove the base layer.
+/// - Wrapped class mustn't remove the base layer. The optimizer will remove it if left unconnected
+/// after applying each optimization.
/// - Wrapped class mustn't affect existing connections in the same output. It might add new ones.
-/// - Base and children layers are removed if left unconnected after applying the wrapped optimization.
+/// - Children layers are removed if left unconnected after applying the wrapped optimization.
template <typename BaseType, typename ChildType, typename Wrapped>
class OptimizeForConnectionImpl : public Wrapped
{