diff options
author | Teresa Charlin <teresa.charlinreyes@arm.com> | 2020-10-15 13:16:07 +0100 |
---|---|---|
committer | Jim Flynn <jim.flynn@arm.com> | 2020-10-29 19:15:01 +0000 |
commit | 06e0300ccf279c6b0fcbb5ef3b6fa36e00229492 (patch) | |
tree | cea4eec69904c40a326b3e4c043c88e441b77b7a /src/armnn/optimizations/Optimization.hpp | |
parent | 34515a1897410adc08390888a6643db390a53d05 (diff) | |
download | armnn-06e0300ccf279c6b0fcbb5ef3b6fa36e00229492.tar.gz |
IVGCVSW-5314 Create OptimizeForExclusiveConnection
* FuseBatchNorm class has been added to facilitate testing
* Only Convolution2D FP32 being fused
Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com>
Change-Id: I049c4770946ddca21b08516d4c9f4d0d22bf9b45
Diffstat (limited to 'src/armnn/optimizations/Optimization.hpp')
-rw-r--r-- | src/armnn/optimizations/Optimization.hpp | 56 |
1 files changed, 56 insertions, 0 deletions
diff --git a/src/armnn/optimizations/Optimization.hpp b/src/armnn/optimizations/Optimization.hpp index 1796ac842b..320cae2b75 100644 --- a/src/armnn/optimizations/Optimization.hpp +++ b/src/armnn/optimizations/Optimization.hpp @@ -122,4 +122,60 @@ public: using OptimizeForTypeImpl<BaseType, OptimizeForConnectionImpl<BaseType, ChildType, Wrapped>>::OptimizeForTypeImpl; }; +/// Wrapper Optimization class that calls Wrapped::Run for every connection BaseType -> ChildType. +/// - Wrapped class mustn't remove the base layer. The optimizer will remove it if left unconnected +/// after applying each optimization. +/// - Wrapped class mustn't affect existing connections in the same output. It might add new ones. +/// - Children layers are removed if left unconnected after applying the wrapped optimization. +template <typename BaseType, typename ChildType, typename Wrapped> +class OptimizeForExclusiveConnectionImpl : public Wrapped +{ +public: + using Wrapped::Wrapped; + + void Run(Graph& graph, BaseType& base) const + { + for (auto output = base.BeginOutputSlots(); output != base.EndOutputSlots(); ++output) + { + if (output->GetNumConnections() == 1) + { + for (auto&& childInput : output->GetConnections()) + { + if (childInput->GetOwningLayer().GetType() == LayerEnumOf<ChildType>()) + { + Wrapped::Run(graph, *childInput); + } + } + + // Removes unconnected children. + for (unsigned int i = 0; i < output->GetNumConnections();) + { + Layer* child = &output->GetConnection(i)->GetOwningLayer(); + + if (child->IsOutputUnconnected()) + { + graph.EraseLayer(child); + } + else + { + ++i; + } + } + } + } + } + +protected: + ~OptimizeForExclusiveConnectionImpl() = default; +}; + +template <typename BaseType, typename ChildType, typename Wrapped> +class OptimizeForExclusiveConnection final + : public OptimizeForTypeImpl<BaseType, OptimizeForExclusiveConnectionImpl<BaseType, ChildType, Wrapped>> +{ +public: + using OptimizeForTypeImpl<BaseType, + OptimizeForExclusiveConnectionImpl<BaseType, ChildType, Wrapped>>::OptimizeForTypeImpl; +}; + } // namespace armnn |