From bceff2fb3fc68bb0aa88b886900c34b77340c826 Mon Sep 17 00:00:00 2001 From: surmeh01 Date: Thu, 29 Mar 2018 16:29:27 +0100 Subject: Release 18.03 --- src/armnn/optimizations/Optimization.hpp | 27 ++++++++++----------- .../optimizations/OptimizeConsecutiveReshapes.hpp | 4 ++-- src/armnn/optimizations/SquashEqualSiblings.hpp | 28 +++++++++++++++------- 3 files changed, 34 insertions(+), 25 deletions(-) (limited to 'src/armnn/optimizations') diff --git a/src/armnn/optimizations/Optimization.hpp b/src/armnn/optimizations/Optimization.hpp index 89e03ff88d..f81071891b 100644 --- a/src/armnn/optimizations/Optimization.hpp +++ b/src/armnn/optimizations/Optimization.hpp @@ -13,7 +13,7 @@ namespace armnn class Optimization { public: - virtual void Run(Graph& graph, Graph::Iterator& pos) const = 0; + virtual void Run(Graph& graph, Layer& base) const = 0; protected: ~Optimization() = default; }; @@ -23,22 +23,20 @@ protected: // (curiously recurring template pattern). // For details, see https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern -/// Wrapper Optimization base class that calls Wrapped::Run for every layer of type BaseType. -/// - Wrapped class mustn't remove the base layer. -/// - Base layer is removed if left unconnected after applying the wrapped optimization. +/// Wrapper Optimization base class that calls Wrapped::Run() for every layer of type BaseType. +/// - Wrapped class mustn't remove the base layer. The optimizer will remove it if left unconnected +/// after applying each optimization. template class OptimizeForTypeImpl : public armnn::Optimization, public Wrapped { public: using Wrapped::Wrapped; - void Run(Graph& graph, Graph::Iterator& pos) const override + void Run(Graph& graph, Layer& base) const override { - Layer* const base = *pos; - - if (base->GetType() == LayerEnumOf()) + if (base.GetType() == LayerEnumOf()) { - Wrapped::Run(graph, *boost::polymorphic_downcast(base)); + Wrapped::Run(graph, *boost::polymorphic_downcast(&base)); } } @@ -46,16 +44,16 @@ protected: ~OptimizeForTypeImpl() = default; }; -/// Specialization that calls Wrapped::Run for any layer type +/// Specialization that calls Wrapped::Run() for any layer type template class OptimizeForTypeImpl : public armnn::Optimization, public Wrapped { public: using Wrapped::Wrapped; - void Run(Graph& graph, Graph::Iterator& pos) const override + void Run(Graph& graph, Layer& base) const override { - Wrapped::Run(graph, **pos); + Wrapped::Run(graph, base); } protected: @@ -70,9 +68,10 @@ public: }; /// Wrapper Optimization class that calls Wrapped::Run for every connection BaseType -> ChildType. -/// - Wrapped class mustn't remove the base layer. +/// - Wrapped class mustn't remove the base layer. The optimizer will remove it if left unconnected +/// after applying each optimization. /// - Wrapped class mustn't affect existing connections in the same output. It might add new ones. -/// - Base and children layers are removed if left unconnected after applying the wrapped optimization. +/// - Children layers are removed if left unconnected after applying the wrapped optimization. template class OptimizeForConnectionImpl : public Wrapped { diff --git a/src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp b/src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp index deb49c6884..9a926a57a4 100644 --- a/src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp +++ b/src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp @@ -18,8 +18,8 @@ public: /// Inserts an equivalent ReshapeLayer that bypasses both for that connection. void Run(Graph& graph, InputSlot& connection) const { - auto& base = connection.GetConnectedOutputSlot()->GetOwningLayer(); - auto& child = connection.GetOwningLayer(); + Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer(); + Layer& child = connection.GetOwningLayer(); BOOST_ASSERT(base.GetType() == LayerType::Reshape); BOOST_ASSERT(child.GetType() == LayerType::Reshape); diff --git a/src/armnn/optimizations/SquashEqualSiblings.hpp b/src/armnn/optimizations/SquashEqualSiblings.hpp index 2dfe91fdcc..c5ce28e723 100644 --- a/src/armnn/optimizations/SquashEqualSiblings.hpp +++ b/src/armnn/optimizations/SquashEqualSiblings.hpp @@ -26,19 +26,29 @@ public: if (!child.IsOutputUnconnected()) { OutputSlot& baseOutput = *connection.GetConnectedOutputSlot(); - auto& comparableChild = *boost::polymorphic_downcast(&child); - for (auto&& it : baseOutput.GetConnections()) + if (baseOutput.GetNumConnections() > 1) { - Layer& sibling = it->GetOwningLayer(); - if ((&sibling != &child) && comparableChild.IsEqual(sibling)) + auto& comparableChild = *boost::polymorphic_downcast(&child); + + Layer* lowestPriorityChild = &child; + for (auto&& it : baseOutput.GetConnections()) { - // Bypass sibling. It will be removed as it's left unconnected. - auto siblingOut = sibling.BeginOutputSlots(); - for (auto childOut = child.BeginOutputSlots(); childOut != child.EndOutputSlots(); ++childOut) + Layer* sibling = &it->GetOwningLayer(); + if ((sibling != lowestPriorityChild) && comparableChild.IsEqual(*sibling)) { - siblingOut->MoveAllConnections(*childOut); - ++siblingOut; + if (sibling->GetPriority() < lowestPriorityChild->GetPriority()) + { + std::swap(sibling, lowestPriorityChild); + } + // Bypass sibling. It will be removed as it's left unconnected. + auto siblingOut = sibling->BeginOutputSlots(); + for (auto lowestPriorityChildOut = lowestPriorityChild->BeginOutputSlots(); + lowestPriorityChildOut != lowestPriorityChild->EndOutputSlots(); ++lowestPriorityChildOut) + { + siblingOut->MoveAllConnections(*lowestPriorityChildOut); + ++siblingOut; + } } } } -- cgit v1.2.1