aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/optimizations/Optimization.hpp
diff options
context:
space:
mode:
authorTeresa Charlin <teresa.charlinreyes@arm.com>2020-10-15 13:16:07 +0100
committerJim Flynn <jim.flynn@arm.com>2020-10-29 19:15:01 +0000
commit06e0300ccf279c6b0fcbb5ef3b6fa36e00229492 (patch)
treecea4eec69904c40a326b3e4c043c88e441b77b7a /src/armnn/optimizations/Optimization.hpp
parent34515a1897410adc08390888a6643db390a53d05 (diff)
downloadarmnn-06e0300ccf279c6b0fcbb5ef3b6fa36e00229492.tar.gz
IVGCVSW-5314 Create OptimizeForExclusiveConnection
* FuseBatchNorm class has been added to facilitate testing * Only Convolution2D FP32 being fused Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com> Change-Id: I049c4770946ddca21b08516d4c9f4d0d22bf9b45
Diffstat (limited to 'src/armnn/optimizations/Optimization.hpp')
-rw-r--r--src/armnn/optimizations/Optimization.hpp56
1 files changed, 56 insertions, 0 deletions
diff --git a/src/armnn/optimizations/Optimization.hpp b/src/armnn/optimizations/Optimization.hpp
index 1796ac842b..320cae2b75 100644
--- a/src/armnn/optimizations/Optimization.hpp
+++ b/src/armnn/optimizations/Optimization.hpp
@@ -122,4 +122,60 @@ public:
using OptimizeForTypeImpl<BaseType, OptimizeForConnectionImpl<BaseType, ChildType, Wrapped>>::OptimizeForTypeImpl;
};
+/// Wrapper Optimization class that calls Wrapped::Run for every connection BaseType -> ChildType.
+/// - Wrapped class mustn't remove the base layer. The optimizer will remove it if left unconnected
+/// after applying each optimization.
+/// - Wrapped class mustn't affect existing connections in the same output. It might add new ones.
+/// - Children layers are removed if left unconnected after applying the wrapped optimization.
+template <typename BaseType, typename ChildType, typename Wrapped>
+class OptimizeForExclusiveConnectionImpl : public Wrapped
+{
+public:
+ using Wrapped::Wrapped;
+
+ void Run(Graph& graph, BaseType& base) const
+ {
+ for (auto output = base.BeginOutputSlots(); output != base.EndOutputSlots(); ++output)
+ {
+ if (output->GetNumConnections() == 1)
+ {
+ for (auto&& childInput : output->GetConnections())
+ {
+ if (childInput->GetOwningLayer().GetType() == LayerEnumOf<ChildType>())
+ {
+ Wrapped::Run(graph, *childInput);
+ }
+ }
+
+ // Removes unconnected children.
+ for (unsigned int i = 0; i < output->GetNumConnections();)
+ {
+ Layer* child = &output->GetConnection(i)->GetOwningLayer();
+
+ if (child->IsOutputUnconnected())
+ {
+ graph.EraseLayer(child);
+ }
+ else
+ {
+ ++i;
+ }
+ }
+ }
+ }
+ }
+
+protected:
+ ~OptimizeForExclusiveConnectionImpl() = default;
+};
+
+template <typename BaseType, typename ChildType, typename Wrapped>
+class OptimizeForExclusiveConnection final
+ : public OptimizeForTypeImpl<BaseType, OptimizeForExclusiveConnectionImpl<BaseType, ChildType, Wrapped>>
+{
+public:
+ using OptimizeForTypeImpl<BaseType,
+ OptimizeForExclusiveConnectionImpl<BaseType, ChildType, Wrapped>>::OptimizeForTypeImpl;
+};
+
} // namespace armnn