diff options
author | Matteo Martincigh <matteo.martincigh@arm.com> | 2019-01-24 14:06:23 +0000 |
---|---|---|
committer | Matteo Martincigh <matteo.martincigh@arm.com> | 2019-01-30 14:03:28 +0000 |
commit | adddddb6cbcb777d92a8c464c9ad0cb9aecc76a3 (patch) | |
tree | b15de32bf9f8612f66e1ae23d2f8009e80e7d0e6 /src/armnn/SubGraphSelector.cpp | |
parent | d089b74bebbcc8518fb0f4eacb7e6569ae170199 (diff) | |
download | armnn-adddddb6cbcb777d92a8c464c9ad0cb9aecc76a3.tar.gz |
IVGCVSW-2458 Refactor the Optimize function (Network.cpp) so that
subgraphs are optimized by the backends
* Added a new method OptimizeSubGraph to the backend interface
* Refactored the Optimize function so that the backend-specific
optimization is performed by the backend itself (through the new
OptimizeSubGraph interface method)
* Added a new ApplyBackendOptimizations function to apply the new
changes
* Added some new convenient constructors to the SubGraph class
* Added AddLayer method and a pointer to the parent graph to the
SubGraph class
* Updated the sub-graph unit tests to match the changes
* Added SelectSubGraphs and ReplaceSubGraphConnections overloads
that work with sub-graphs
* Removed unused code and minor refactoring where necessary
Change-Id: I46181794c6a9e3b10558944f804e06a8f693a6d0
Diffstat (limited to 'src/armnn/SubGraphSelector.cpp')
-rw-r--r-- | src/armnn/SubGraphSelector.cpp | 37 |
1 files changed, 21 insertions, 16 deletions
diff --git a/src/armnn/SubGraphSelector.cpp b/src/armnn/SubGraphSelector.cpp index d0542fd41f..4abf01c88f 100644 --- a/src/armnn/SubGraphSelector.cpp +++ b/src/armnn/SubGraphSelector.cpp @@ -69,25 +69,25 @@ struct LayerSelectionInfo return m_Layer->GetType() == armnn::LayerType::Input; } - void CollectNonSelectedInputs(SubGraph::InputSlots& slots, + void CollectNonSelectedInputs(SubGraph::InputSlots& inputSlots, const SubGraphSelector::LayerSelectorFunction& selector) { for (auto&& slot = m_Layer->BeginInputSlots(); slot != m_Layer->EndInputSlots(); ++slot) { OutputSlot* parentLayerOutputSlot = slot->GetConnectedOutputSlot(); - BOOST_ASSERT_MSG(parentLayerOutputSlot != nullptr, "The slots must be connected here."); + BOOST_ASSERT_MSG(parentLayerOutputSlot != nullptr, "The input slots must be connected here."); if (parentLayerOutputSlot) { Layer& parentLayer = parentLayerOutputSlot->GetOwningLayer(); if (selector(parentLayer) == false) { - slots.push_back(&(*slot)); + inputSlots.push_back(&(*slot)); } } } } - void CollectNonSelectedOutputSlots(SubGraph::OutputSlots& slots, + void CollectNonSelectedOutputSlots(SubGraph::OutputSlots& outputSlots, const SubGraphSelector::LayerSelectorFunction& selector) { for (auto&& slot = m_Layer->BeginOutputSlots(); slot != m_Layer->EndOutputSlots(); ++slot) @@ -97,7 +97,7 @@ struct LayerSelectionInfo Layer& childLayer = childLayerInputSlot->GetOwningLayer(); if (selector(childLayer) == false) { - slots.push_back(&(*slot)); + outputSlots.push_back(&(*slot)); } } } @@ -112,12 +112,18 @@ struct LayerSelectionInfo } // namespace <anonymous> SubGraphSelector::SubGraphs -SubGraphSelector::SelectSubGraphs(Graph& graph, - const LayerSelectorFunction& selector) +SubGraphSelector::SelectSubGraphs(Graph& graph, const LayerSelectorFunction& selector) +{ + SubGraph subGraph(graph); + return SubGraphSelector::SelectSubGraphs(subGraph, selector); +} + +SubGraphSelector::SubGraphs +SubGraphSelector::SelectSubGraphs(SubGraph& subGraph, const LayerSelectorFunction& selector) { LayerSelectionInfo::LayerInfoContainer layerInfo; - for (auto& layer : graph) + for (auto& layer : subGraph) { layerInfo.emplace(layer, LayerSelectionInfo{layer, selector}); } @@ -127,7 +133,7 @@ SubGraphSelector::SelectSubGraphs(Graph& graph, { if (info.second.IsInputLayer()) { - // for each input layer we mark the graph where subgraph + // For each input layer we mark the graph where subgraph // splits need to happen because of the dependency between // the selected and non-selected nodes info.second.MarkChildrenSplits(layerInfo, splitNo, false); @@ -159,20 +165,19 @@ SubGraphSelector::SelectSubGraphs(Graph& graph, { if (splitGraph.second.empty() == false) { - SubGraph::OutputSlots outputs; SubGraph::InputSlots inputs; + SubGraph::OutputSlots outputs; SubGraph::Layers layers; for (auto&& infoPtr : splitGraph.second) { - infoPtr->CollectNonSelectedOutputSlots(outputs, selector); infoPtr->CollectNonSelectedInputs(inputs, selector); + infoPtr->CollectNonSelectedOutputSlots(outputs, selector); layers.push_back(infoPtr->m_Layer); } - result.emplace_back( - std::make_unique<SubGraph>( - std::move(inputs), - std::move(outputs), - std::move(layers))); + result.emplace_back(std::make_unique<SubGraph>(subGraph, + std::move(inputs), + std::move(outputs), + std::move(layers))); } } |