aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/optimizations/Optimization.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/optimizations/Optimization.hpp')
-rw-r--r--src/armnn/optimizations/Optimization.hpp56
1 files changed, 56 insertions, 0 deletions
diff --git a/src/armnn/optimizations/Optimization.hpp b/src/armnn/optimizations/Optimization.hpp
index 1796ac842b..320cae2b75 100644
--- a/src/armnn/optimizations/Optimization.hpp
+++ b/src/armnn/optimizations/Optimization.hpp
@@ -122,4 +122,60 @@ public:
using OptimizeForTypeImpl<BaseType, OptimizeForConnectionImpl<BaseType, ChildType, Wrapped>>::OptimizeForTypeImpl;
};
+/// Wrapper Optimization class that calls Wrapped::Run for every connection BaseType -> ChildType.
+/// - Wrapped class mustn't remove the base layer. The optimizer will remove it if left unconnected
+/// after applying each optimization.
+/// - Wrapped class mustn't affect existing connections in the same output. It might add new ones.
+/// - Children layers are removed if left unconnected after applying the wrapped optimization.
+template <typename BaseType, typename ChildType, typename Wrapped>
+class OptimizeForExclusiveConnectionImpl : public Wrapped
+{
+public:
+ using Wrapped::Wrapped;
+
+ void Run(Graph& graph, BaseType& base) const
+ {
+ for (auto output = base.BeginOutputSlots(); output != base.EndOutputSlots(); ++output)
+ {
+ if (output->GetNumConnections() == 1)
+ {
+ for (auto&& childInput : output->GetConnections())
+ {
+ if (childInput->GetOwningLayer().GetType() == LayerEnumOf<ChildType>())
+ {
+ Wrapped::Run(graph, *childInput);
+ }
+ }
+
+ // Removes unconnected children.
+ for (unsigned int i = 0; i < output->GetNumConnections();)
+ {
+ Layer* child = &output->GetConnection(i)->GetOwningLayer();
+
+ if (child->IsOutputUnconnected())
+ {
+ graph.EraseLayer(child);
+ }
+ else
+ {
+ ++i;
+ }
+ }
+ }
+ }
+ }
+
+protected:
+ ~OptimizeForExclusiveConnectionImpl() = default;
+};
+
+template <typename BaseType, typename ChildType, typename Wrapped>
+class OptimizeForExclusiveConnection final
+ : public OptimizeForTypeImpl<BaseType, OptimizeForExclusiveConnectionImpl<BaseType, ChildType, Wrapped>>
+{
+public:
+ using OptimizeForTypeImpl<BaseType,
+ OptimizeForExclusiveConnectionImpl<BaseType, ChildType, Wrapped>>::OptimizeForTypeImpl;
+};
+
} // namespace armnn