aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/optimizations/SquashEqualSiblings.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/optimizations/SquashEqualSiblings.hpp')
-rw-r--r--src/armnn/optimizations/SquashEqualSiblings.hpp28
1 files changed, 19 insertions, 9 deletions
diff --git a/src/armnn/optimizations/SquashEqualSiblings.hpp b/src/armnn/optimizations/SquashEqualSiblings.hpp
index 2dfe91fdcc..c5ce28e723 100644
--- a/src/armnn/optimizations/SquashEqualSiblings.hpp
+++ b/src/armnn/optimizations/SquashEqualSiblings.hpp
@@ -26,19 +26,29 @@ public:
if (!child.IsOutputUnconnected())
{
OutputSlot& baseOutput = *connection.GetConnectedOutputSlot();
- auto& comparableChild = *boost::polymorphic_downcast<Comparable*>(&child);
- for (auto&& it : baseOutput.GetConnections())
+ if (baseOutput.GetNumConnections() > 1)
{
- Layer& sibling = it->GetOwningLayer();
- if ((&sibling != &child) && comparableChild.IsEqual(sibling))
+ auto& comparableChild = *boost::polymorphic_downcast<Comparable*>(&child);
+
+ Layer* lowestPriorityChild = &child;
+ for (auto&& it : baseOutput.GetConnections())
{
- // Bypass sibling. It will be removed as it's left unconnected.
- auto siblingOut = sibling.BeginOutputSlots();
- for (auto childOut = child.BeginOutputSlots(); childOut != child.EndOutputSlots(); ++childOut)
+ Layer* sibling = &it->GetOwningLayer();
+ if ((sibling != lowestPriorityChild) && comparableChild.IsEqual(*sibling))
{
- siblingOut->MoveAllConnections(*childOut);
- ++siblingOut;
+ if (sibling->GetPriority() < lowestPriorityChild->GetPriority())
+ {
+ std::swap(sibling, lowestPriorityChild);
+ }
+ // Bypass sibling. It will be removed as it's left unconnected.
+ auto siblingOut = sibling->BeginOutputSlots();
+ for (auto lowestPriorityChildOut = lowestPriorityChild->BeginOutputSlots();
+ lowestPriorityChildOut != lowestPriorityChild->EndOutputSlots(); ++lowestPriorityChildOut)
+ {
+ siblingOut->MoveAllConnections(*lowestPriorityChildOut);
+ ++siblingOut;
+ }
}
}
}