diff options
Diffstat (limited to 'src/armnn/optimizations/OptimizeInversePermutes.hpp')
-rw-r--r-- | src/armnn/optimizations/OptimizeInversePermutes.hpp | 10 |
1 files changed, 7 insertions, 3 deletions
diff --git a/src/armnn/optimizations/OptimizeInversePermutes.hpp b/src/armnn/optimizations/OptimizeInversePermutes.hpp index 48bfa35440..77d62a50cb 100644 --- a/src/armnn/optimizations/OptimizeInversePermutes.hpp +++ b/src/armnn/optimizations/OptimizeInversePermutes.hpp @@ -13,6 +13,7 @@ namespace armnn namespace optimizations { +template <typename PermuteType> class OptimizeInversePermutesImpl { public: @@ -22,9 +23,9 @@ public: { boost::ignore_unused(graph); Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer(); - auto child = boost::polymorphic_downcast<PermuteLayer*>(&connection.GetOwningLayer()); + auto child = boost::polymorphic_downcast<PermuteType*>(&connection.GetOwningLayer()); - if (child->IsInverse(*boost::polymorphic_downcast<PermuteLayer*>(&base))) + if (child->IsInverse(*boost::polymorphic_downcast<PermuteType*>(&base))) { // Bypass both layers. Child will be removed as it's left unconnected. // Base layer will be removed if left unconnected. @@ -37,7 +38,10 @@ protected: ~OptimizeInversePermutesImpl() = default; }; -using OptimizeInversePermutes = OptimizeForConnection<PermuteLayer, PermuteLayer, OptimizeInversePermutesImpl>; +using OptimizeInversePermutes = OptimizeForConnection<PermuteLayer, PermuteLayer, + OptimizeInversePermutesImpl<PermuteLayer>>; +using OptimizeInverseTransposes = OptimizeForConnection<TransposeLayer, TransposeLayer, + OptimizeInversePermutesImpl<TransposeLayer>>; } // namespace optimizations } // namespace armnn |