From bceff2fb3fc68bb0aa88b886900c34b77340c826 Mon Sep 17 00:00:00 2001 From: surmeh01 Date: Thu, 29 Mar 2018 16:29:27 +0100 Subject: Release 18.03 --- src/armnn/optimizations/SquashEqualSiblings.hpp | 28 +++++++++++++++++-------- 1 file changed, 19 insertions(+), 9 deletions(-) (limited to 'src/armnn/optimizations/SquashEqualSiblings.hpp') diff --git a/src/armnn/optimizations/SquashEqualSiblings.hpp b/src/armnn/optimizations/SquashEqualSiblings.hpp index 2dfe91fdcc..c5ce28e723 100644 --- a/src/armnn/optimizations/SquashEqualSiblings.hpp +++ b/src/armnn/optimizations/SquashEqualSiblings.hpp @@ -26,19 +26,29 @@ public: if (!child.IsOutputUnconnected()) { OutputSlot& baseOutput = *connection.GetConnectedOutputSlot(); - auto& comparableChild = *boost::polymorphic_downcast(&child); - for (auto&& it : baseOutput.GetConnections()) + if (baseOutput.GetNumConnections() > 1) { - Layer& sibling = it->GetOwningLayer(); - if ((&sibling != &child) && comparableChild.IsEqual(sibling)) + auto& comparableChild = *boost::polymorphic_downcast(&child); + + Layer* lowestPriorityChild = &child; + for (auto&& it : baseOutput.GetConnections()) { - // Bypass sibling. It will be removed as it's left unconnected. - auto siblingOut = sibling.BeginOutputSlots(); - for (auto childOut = child.BeginOutputSlots(); childOut != child.EndOutputSlots(); ++childOut) + Layer* sibling = &it->GetOwningLayer(); + if ((sibling != lowestPriorityChild) && comparableChild.IsEqual(*sibling)) { - siblingOut->MoveAllConnections(*childOut); - ++siblingOut; + if (sibling->GetPriority() < lowestPriorityChild->GetPriority()) + { + std::swap(sibling, lowestPriorityChild); + } + // Bypass sibling. It will be removed as it's left unconnected. + auto siblingOut = sibling->BeginOutputSlots(); + for (auto lowestPriorityChildOut = lowestPriorityChild->BeginOutputSlots(); + lowestPriorityChildOut != lowestPriorityChild->EndOutputSlots(); ++lowestPriorityChildOut) + { + siblingOut->MoveAllConnections(*lowestPriorityChildOut); + ++siblingOut; + } } } } -- cgit v1.2.1