diff options
Diffstat (limited to 'src/armnn/optimizations/Optimization.hpp')
-rw-r--r-- | src/armnn/optimizations/Optimization.hpp | 56 |
1 files changed, 56 insertions, 0 deletions
diff --git a/src/armnn/optimizations/Optimization.hpp b/src/armnn/optimizations/Optimization.hpp index 1796ac842b..320cae2b75 100644 --- a/src/armnn/optimizations/Optimization.hpp +++ b/src/armnn/optimizations/Optimization.hpp @@ -122,4 +122,60 @@ public: using OptimizeForTypeImpl<BaseType, OptimizeForConnectionImpl<BaseType, ChildType, Wrapped>>::OptimizeForTypeImpl; }; +/// Wrapper Optimization class that calls Wrapped::Run for every connection BaseType -> ChildType. +/// - Wrapped class mustn't remove the base layer. The optimizer will remove it if left unconnected +/// after applying each optimization. +/// - Wrapped class mustn't affect existing connections in the same output. It might add new ones. +/// - Children layers are removed if left unconnected after applying the wrapped optimization. +template <typename BaseType, typename ChildType, typename Wrapped> +class OptimizeForExclusiveConnectionImpl : public Wrapped +{ +public: + using Wrapped::Wrapped; + + void Run(Graph& graph, BaseType& base) const + { + for (auto output = base.BeginOutputSlots(); output != base.EndOutputSlots(); ++output) + { + if (output->GetNumConnections() == 1) + { + for (auto&& childInput : output->GetConnections()) + { + if (childInput->GetOwningLayer().GetType() == LayerEnumOf<ChildType>()) + { + Wrapped::Run(graph, *childInput); + } + } + + // Removes unconnected children. + for (unsigned int i = 0; i < output->GetNumConnections();) + { + Layer* child = &output->GetConnection(i)->GetOwningLayer(); + + if (child->IsOutputUnconnected()) + { + graph.EraseLayer(child); + } + else + { + ++i; + } + } + } + } + } + +protected: + ~OptimizeForExclusiveConnectionImpl() = default; +}; + +template <typename BaseType, typename ChildType, typename Wrapped> +class OptimizeForExclusiveConnection final + : public OptimizeForTypeImpl<BaseType, OptimizeForExclusiveConnectionImpl<BaseType, ChildType, Wrapped>> +{ +public: + using OptimizeForTypeImpl<BaseType, + OptimizeForExclusiveConnectionImpl<BaseType, ChildType, Wrapped>>::OptimizeForTypeImpl; +}; + } // namespace armnn |