aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/Optimizer.cpp
diff options
context:
space:
mode:
authortelsoa01 <telmo.soares@arm.com>2018-08-31 09:22:23 +0100
committertelsoa01 <telmo.soares@arm.com>2018-08-31 09:22:23 +0100
commitc577f2c6a3b4ddb6ba87a882723c53a248afbeba (patch)
treebd7d4c148df27f8be6649d313efb24f536b7cf34 /src/armnn/Optimizer.cpp
parent4c7098bfeab1ffe1cdc77f6c15548d3e73274746 (diff)
downloadarmnn-c577f2c6a3b4ddb6ba87a882723c53a248afbeba.tar.gz
Release 18.08
Diffstat (limited to 'src/armnn/Optimizer.cpp')
-rw-r--r--src/armnn/Optimizer.cpp49
1 files changed, 28 insertions, 21 deletions
diff --git a/src/armnn/Optimizer.cpp b/src/armnn/Optimizer.cpp
index 9b76c7fa72..630aa1a27b 100644
--- a/src/armnn/Optimizer.cpp
+++ b/src/armnn/Optimizer.cpp
@@ -3,6 +3,7 @@
// See LICENSE file in the project root for full license information.
//
#include "Optimizer.hpp"
+#include "Observable.hpp"
#include "optimizations/All.hpp"
namespace armnn
@@ -10,44 +11,50 @@ namespace armnn
Optimizer::Optimizer()
{
- // Add optimizations here
- static optimizations::SquashEqualPermuteSiblings squashEqualPermuteSiblings;
- static optimizations::SquashEqualReshapeSiblings squashEqualReshapeSiblings;
- static optimizations::OptimizeInversePermutes optimizeInversePermutes;
- static optimizations::MovePermuteUp movePermuteUp;
- static optimizations::PermuteAsReshape permuteAsReshape;
- static optimizations::OptimizeConsecutiveReshapes optimizeConsecutiveReshapes;
-
- // Set optimizations in desired order
- m_Optimizations = {&squashEqualPermuteSiblings,
- &squashEqualReshapeSiblings,
- &optimizeInversePermutes,
- &movePermuteUp,
- &permuteAsReshape,
- &optimizeConsecutiveReshapes,
- };
}
-void Optimizer::Optimize(Graph& graph)
+void Optimizer::Pass(Graph& graph, const Optimizations& optimizations)
{
- Optimizer optimizer;
+ // Create observables to observe changes to the graph
+ AddedLayerObservable addedLayerObservable(graph);
+ ErasedLayerNamesObservable erasedLayerNamesObservable(graph);
+
+ bool graphNeedsSorting = false;
auto it = graph.TopologicalSort().end();
- // Call TopologicalSort() in every iteration to re-order the list in case layers where added/removed.
+
+ // Calls TopologicalSort() for every iteration to re-order the list in case layers were added/removed.
while (it != graph.TopologicalSort().begin())
{
--it;
- for (auto&& optimization : optimizer.m_Optimizations)
+ for (auto&& optimization : optimizations)
{
optimization->Run(graph, **it);
if ((*it)->IsOutputUnconnected())
{
it = graph.EraseLayer(it);
+ graphNeedsSorting = true;
+ }
+
+ // Add the names of erased layers as related layers to the new added layers
+ for (auto& erasedLayerName : erasedLayerNamesObservable)
+ {
+ for (auto& addedLayer : addedLayerObservable)
+ {
+ addedLayer->AddRelatedLayerName(erasedLayerName);
+ }
+ }
+
+ erasedLayerNamesObservable.Clear();
+ addedLayerObservable.Clear();
+
+ if (graphNeedsSorting)
+ {
+ graphNeedsSorting = false;
break;
}
}
}
}
-
} // namespace armnn