aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/optimizations/OptimizeInversePermutes.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/optimizations/OptimizeInversePermutes.hpp')
-rw-r--r--src/armnn/optimizations/OptimizeInversePermutes.hpp40
1 files changed, 40 insertions, 0 deletions
diff --git a/src/armnn/optimizations/OptimizeInversePermutes.hpp b/src/armnn/optimizations/OptimizeInversePermutes.hpp
new file mode 100644
index 0000000000..63820cb7d3
--- /dev/null
+++ b/src/armnn/optimizations/OptimizeInversePermutes.hpp
@@ -0,0 +1,40 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+#pragma once
+
+#include "Optimization.hpp"
+
+namespace armnn
+{
+namespace optimizations
+{
+
+class OptimizeInversePermutesImpl
+{
+public:
+ /// Run for every connection between a base PermuteLayer and a child PermuteLayer.
+ /// Bypasses both layers for that connection if one is the inverse of the other.
+ void Run(Graph& graph, InputSlot& connection) const
+ {
+ Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer();
+ auto child = boost::polymorphic_downcast<PermuteLayer*>(&connection.GetOwningLayer());
+
+ if (child->IsInverse(*boost::polymorphic_downcast<PermuteLayer*>(&base)))
+ {
+ // Bypass both layers. Child will be removed as it's left unconnected.
+ // Base layer will be removed if left unconnected.
+ child->GetOutputSlot().MoveAllConnections(*base.GetInputSlot(0).GetConnectedOutputSlot());
+ }
+ }
+
+protected:
+ OptimizeInversePermutesImpl() = default;
+ ~OptimizeInversePermutesImpl() = default;
+};
+
+using OptimizeInversePermutes = OptimizeForConnection<PermuteLayer, PermuteLayer, OptimizeInversePermutesImpl>;
+
+} // namespace optimizations
+} // namespace armnn