From 06e0300ccf279c6b0fcbb5ef3b6fa36e00229492 Mon Sep 17 00:00:00 2001 From: Teresa Charlin Date: Thu, 15 Oct 2020 13:16:07 +0100 Subject: IVGCVSW-5314 Create OptimizeForExclusiveConnection * FuseBatchNorm class has been added to facilitate testing * Only Convolution2D FP32 being fused Signed-off-by: Teresa Charlin Change-Id: I049c4770946ddca21b08516d4c9f4d0d22bf9b45 --- src/armnn/optimizations/Optimization.hpp | 56 ++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) (limited to 'src/armnn/optimizations/Optimization.hpp') diff --git a/src/armnn/optimizations/Optimization.hpp b/src/armnn/optimizations/Optimization.hpp index 1796ac842b..320cae2b75 100644 --- a/src/armnn/optimizations/Optimization.hpp +++ b/src/armnn/optimizations/Optimization.hpp @@ -122,4 +122,60 @@ public: using OptimizeForTypeImpl>::OptimizeForTypeImpl; }; +/// Wrapper Optimization class that calls Wrapped::Run for every connection BaseType -> ChildType. +/// - Wrapped class mustn't remove the base layer. The optimizer will remove it if left unconnected +/// after applying each optimization. +/// - Wrapped class mustn't affect existing connections in the same output. It might add new ones. +/// - Children layers are removed if left unconnected after applying the wrapped optimization. +template +class OptimizeForExclusiveConnectionImpl : public Wrapped +{ +public: + using Wrapped::Wrapped; + + void Run(Graph& graph, BaseType& base) const + { + for (auto output = base.BeginOutputSlots(); output != base.EndOutputSlots(); ++output) + { + if (output->GetNumConnections() == 1) + { + for (auto&& childInput : output->GetConnections()) + { + if (childInput->GetOwningLayer().GetType() == LayerEnumOf()) + { + Wrapped::Run(graph, *childInput); + } + } + + // Removes unconnected children. + for (unsigned int i = 0; i < output->GetNumConnections();) + { + Layer* child = &output->GetConnection(i)->GetOwningLayer(); + + if (child->IsOutputUnconnected()) + { + graph.EraseLayer(child); + } + else + { + ++i; + } + } + } + } + } + +protected: + ~OptimizeForExclusiveConnectionImpl() = default; +}; + +template +class OptimizeForExclusiveConnection final + : public OptimizeForTypeImpl> +{ +public: + using OptimizeForTypeImpl>::OptimizeForTypeImpl; +}; + } // namespace armnn -- cgit v1.2.1