diff options
Diffstat (limited to 'src/armnn/Optimizer.cpp')
-rw-r--r-- | src/armnn/Optimizer.cpp | 49 |
1 files changed, 28 insertions, 21 deletions
diff --git a/src/armnn/Optimizer.cpp b/src/armnn/Optimizer.cpp index 9b76c7fa72..630aa1a27b 100644 --- a/src/armnn/Optimizer.cpp +++ b/src/armnn/Optimizer.cpp @@ -3,6 +3,7 @@ // See LICENSE file in the project root for full license information. // #include "Optimizer.hpp" +#include "Observable.hpp" #include "optimizations/All.hpp" namespace armnn @@ -10,44 +11,50 @@ namespace armnn Optimizer::Optimizer() { - // Add optimizations here - static optimizations::SquashEqualPermuteSiblings squashEqualPermuteSiblings; - static optimizations::SquashEqualReshapeSiblings squashEqualReshapeSiblings; - static optimizations::OptimizeInversePermutes optimizeInversePermutes; - static optimizations::MovePermuteUp movePermuteUp; - static optimizations::PermuteAsReshape permuteAsReshape; - static optimizations::OptimizeConsecutiveReshapes optimizeConsecutiveReshapes; - - // Set optimizations in desired order - m_Optimizations = {&squashEqualPermuteSiblings, - &squashEqualReshapeSiblings, - &optimizeInversePermutes, - &movePermuteUp, - &permuteAsReshape, - &optimizeConsecutiveReshapes, - }; } -void Optimizer::Optimize(Graph& graph) +void Optimizer::Pass(Graph& graph, const Optimizations& optimizations) { - Optimizer optimizer; + // Create observables to observe changes to the graph + AddedLayerObservable addedLayerObservable(graph); + ErasedLayerNamesObservable erasedLayerNamesObservable(graph); + + bool graphNeedsSorting = false; auto it = graph.TopologicalSort().end(); - // Call TopologicalSort() in every iteration to re-order the list in case layers where added/removed. + + // Calls TopologicalSort() for every iteration to re-order the list in case layers were added/removed. while (it != graph.TopologicalSort().begin()) { --it; - for (auto&& optimization : optimizer.m_Optimizations) + for (auto&& optimization : optimizations) { optimization->Run(graph, **it); if ((*it)->IsOutputUnconnected()) { it = graph.EraseLayer(it); + graphNeedsSorting = true; + } + + // Add the names of erased layers as related layers to the new added layers + for (auto& erasedLayerName : erasedLayerNamesObservable) + { + for (auto& addedLayer : addedLayerObservable) + { + addedLayer->AddRelatedLayerName(erasedLayerName); + } + } + + erasedLayerNamesObservable.Clear(); + addedLayerObservable.Clear(); + + if (graphNeedsSorting) + { + graphNeedsSorting = false; break; } } } } - } // namespace armnn |