aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/optimizations
diff options
context:
space:
mode:
authorsurmeh01 <surabhi.mehta@arm.com>2018-03-29 16:29:27 +0100
committersurmeh01 <surabhi.mehta@arm.com>2018-03-29 16:29:27 +0100
commitbceff2fb3fc68bb0aa88b886900c34b77340c826 (patch)
treed867d3e090d58d3012dfbbac456e9ea8c7f789bc /src/armnn/optimizations
parent4fcda0101ec3d110c1d6d7bee5c83416b645528a (diff)
downloadarmnn-bceff2fb3fc68bb0aa88b886900c34b77340c826.tar.gz
Release 18.03
Diffstat (limited to 'src/armnn/optimizations')
-rw-r--r--src/armnn/optimizations/Optimization.hpp27
-rw-r--r--src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp4
-rw-r--r--src/armnn/optimizations/SquashEqualSiblings.hpp28
3 files changed, 34 insertions, 25 deletions
diff --git a/src/armnn/optimizations/Optimization.hpp b/src/armnn/optimizations/Optimization.hpp
index 89e03ff88d..f81071891b 100644
--- a/src/armnn/optimizations/Optimization.hpp
+++ b/src/armnn/optimizations/Optimization.hpp
@@ -13,7 +13,7 @@ namespace armnn
class Optimization
{
public:
- virtual void Run(Graph& graph, Graph::Iterator& pos) const = 0;
+ virtual void Run(Graph& graph, Layer& base) const = 0;
protected:
~Optimization() = default;
};
@@ -23,22 +23,20 @@ protected:
// (curiously recurring template pattern).
// For details, see https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern
-/// Wrapper Optimization base class that calls Wrapped::Run for every layer of type BaseType.
-/// - Wrapped class mustn't remove the base layer.
-/// - Base layer is removed if left unconnected after applying the wrapped optimization.
+/// Wrapper Optimization base class that calls Wrapped::Run() for every layer of type BaseType.
+/// - Wrapped class mustn't remove the base layer. The optimizer will remove it if left unconnected
+/// after applying each optimization.
template <typename BaseType, typename Wrapped>
class OptimizeForTypeImpl : public armnn::Optimization, public Wrapped
{
public:
using Wrapped::Wrapped;
- void Run(Graph& graph, Graph::Iterator& pos) const override
+ void Run(Graph& graph, Layer& base) const override
{
- Layer* const base = *pos;
-
- if (base->GetType() == LayerEnumOf<BaseType>())
+ if (base.GetType() == LayerEnumOf<BaseType>())
{
- Wrapped::Run(graph, *boost::polymorphic_downcast<BaseType*>(base));
+ Wrapped::Run(graph, *boost::polymorphic_downcast<BaseType*>(&base));
}
}
@@ -46,16 +44,16 @@ protected:
~OptimizeForTypeImpl() = default;
};
-/// Specialization that calls Wrapped::Run for any layer type
+/// Specialization that calls Wrapped::Run() for any layer type
template <typename Wrapped>
class OptimizeForTypeImpl<Layer, Wrapped> : public armnn::Optimization, public Wrapped
{
public:
using Wrapped::Wrapped;
- void Run(Graph& graph, Graph::Iterator& pos) const override
+ void Run(Graph& graph, Layer& base) const override
{
- Wrapped::Run(graph, **pos);
+ Wrapped::Run(graph, base);
}
protected:
@@ -70,9 +68,10 @@ public:
};
/// Wrapper Optimization class that calls Wrapped::Run for every connection BaseType -> ChildType.
-/// - Wrapped class mustn't remove the base layer.
+/// - Wrapped class mustn't remove the base layer. The optimizer will remove it if left unconnected
+/// after applying each optimization.
/// - Wrapped class mustn't affect existing connections in the same output. It might add new ones.
-/// - Base and children layers are removed if left unconnected after applying the wrapped optimization.
+/// - Children layers are removed if left unconnected after applying the wrapped optimization.
template <typename BaseType, typename ChildType, typename Wrapped>
class OptimizeForConnectionImpl : public Wrapped
{
diff --git a/src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp b/src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp
index deb49c6884..9a926a57a4 100644
--- a/src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp
+++ b/src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp
@@ -18,8 +18,8 @@ public:
/// Inserts an equivalent ReshapeLayer that bypasses both for that connection.
void Run(Graph& graph, InputSlot& connection) const
{
- auto& base = connection.GetConnectedOutputSlot()->GetOwningLayer();
- auto& child = connection.GetOwningLayer();
+ Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer();
+ Layer& child = connection.GetOwningLayer();
BOOST_ASSERT(base.GetType() == LayerType::Reshape);
BOOST_ASSERT(child.GetType() == LayerType::Reshape);
diff --git a/src/armnn/optimizations/SquashEqualSiblings.hpp b/src/armnn/optimizations/SquashEqualSiblings.hpp
index 2dfe91fdcc..c5ce28e723 100644
--- a/src/armnn/optimizations/SquashEqualSiblings.hpp
+++ b/src/armnn/optimizations/SquashEqualSiblings.hpp
@@ -26,19 +26,29 @@ public:
if (!child.IsOutputUnconnected())
{
OutputSlot& baseOutput = *connection.GetConnectedOutputSlot();
- auto& comparableChild = *boost::polymorphic_downcast<Comparable*>(&child);
- for (auto&& it : baseOutput.GetConnections())
+ if (baseOutput.GetNumConnections() > 1)
{
- Layer& sibling = it->GetOwningLayer();
- if ((&sibling != &child) && comparableChild.IsEqual(sibling))
+ auto& comparableChild = *boost::polymorphic_downcast<Comparable*>(&child);
+
+ Layer* lowestPriorityChild = &child;
+ for (auto&& it : baseOutput.GetConnections())
{
- // Bypass sibling. It will be removed as it's left unconnected.
- auto siblingOut = sibling.BeginOutputSlots();
- for (auto childOut = child.BeginOutputSlots(); childOut != child.EndOutputSlots(); ++childOut)
+ Layer* sibling = &it->GetOwningLayer();
+ if ((sibling != lowestPriorityChild) && comparableChild.IsEqual(*sibling))
{
- siblingOut->MoveAllConnections(*childOut);
- ++siblingOut;
+ if (sibling->GetPriority() < lowestPriorityChild->GetPriority())
+ {
+ std::swap(sibling, lowestPriorityChild);
+ }
+ // Bypass sibling. It will be removed as it's left unconnected.
+ auto siblingOut = sibling->BeginOutputSlots();
+ for (auto lowestPriorityChildOut = lowestPriorityChild->BeginOutputSlots();
+ lowestPriorityChildOut != lowestPriorityChild->EndOutputSlots(); ++lowestPriorityChildOut)
+ {
+ siblingOut->MoveAllConnections(*lowestPriorityChildOut);
+ ++siblingOut;
+ }
}
}
}