diff options
author | surmeh01 <surabhi.mehta@arm.com> | 2018-03-29 16:29:27 +0100 |
---|---|---|
committer | surmeh01 <surabhi.mehta@arm.com> | 2018-03-29 16:29:27 +0100 |
commit | bceff2fb3fc68bb0aa88b886900c34b77340c826 (patch) | |
tree | d867d3e090d58d3012dfbbac456e9ea8c7f789bc /src/armnn/optimizations/Optimization.hpp | |
parent | 4fcda0101ec3d110c1d6d7bee5c83416b645528a (diff) | |
download | armnn-bceff2fb3fc68bb0aa88b886900c34b77340c826.tar.gz |
Release 18.03
Diffstat (limited to 'src/armnn/optimizations/Optimization.hpp')
-rw-r--r-- | src/armnn/optimizations/Optimization.hpp | 27 |
1 files changed, 13 insertions, 14 deletions
diff --git a/src/armnn/optimizations/Optimization.hpp b/src/armnn/optimizations/Optimization.hpp index 89e03ff88d..f81071891b 100644 --- a/src/armnn/optimizations/Optimization.hpp +++ b/src/armnn/optimizations/Optimization.hpp @@ -13,7 +13,7 @@ namespace armnn class Optimization { public: - virtual void Run(Graph& graph, Graph::Iterator& pos) const = 0; + virtual void Run(Graph& graph, Layer& base) const = 0; protected: ~Optimization() = default; }; @@ -23,22 +23,20 @@ protected: // (curiously recurring template pattern). // For details, see https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern -/// Wrapper Optimization base class that calls Wrapped::Run for every layer of type BaseType. -/// - Wrapped class mustn't remove the base layer. -/// - Base layer is removed if left unconnected after applying the wrapped optimization. +/// Wrapper Optimization base class that calls Wrapped::Run() for every layer of type BaseType. +/// - Wrapped class mustn't remove the base layer. The optimizer will remove it if left unconnected +/// after applying each optimization. template <typename BaseType, typename Wrapped> class OptimizeForTypeImpl : public armnn::Optimization, public Wrapped { public: using Wrapped::Wrapped; - void Run(Graph& graph, Graph::Iterator& pos) const override + void Run(Graph& graph, Layer& base) const override { - Layer* const base = *pos; - - if (base->GetType() == LayerEnumOf<BaseType>()) + if (base.GetType() == LayerEnumOf<BaseType>()) { - Wrapped::Run(graph, *boost::polymorphic_downcast<BaseType*>(base)); + Wrapped::Run(graph, *boost::polymorphic_downcast<BaseType*>(&base)); } } @@ -46,16 +44,16 @@ protected: ~OptimizeForTypeImpl() = default; }; -/// Specialization that calls Wrapped::Run for any layer type +/// Specialization that calls Wrapped::Run() for any layer type template <typename Wrapped> class OptimizeForTypeImpl<Layer, Wrapped> : public armnn::Optimization, public Wrapped { public: using Wrapped::Wrapped; - void Run(Graph& graph, Graph::Iterator& pos) const override + void Run(Graph& graph, Layer& base) const override { - Wrapped::Run(graph, **pos); + Wrapped::Run(graph, base); } protected: @@ -70,9 +68,10 @@ public: }; /// Wrapper Optimization class that calls Wrapped::Run for every connection BaseType -> ChildType. -/// - Wrapped class mustn't remove the base layer. +/// - Wrapped class mustn't remove the base layer. The optimizer will remove it if left unconnected +/// after applying each optimization. /// - Wrapped class mustn't affect existing connections in the same output. It might add new ones. -/// - Base and children layers are removed if left unconnected after applying the wrapped optimization. +/// - Children layers are removed if left unconnected after applying the wrapped optimization. template <typename BaseType, typename ChildType, typename Wrapped> class OptimizeForConnectionImpl : public Wrapped { |