aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/SubGraphSelector.cpp
diff options
context:
space:
mode:
authorMatteo Martincigh <matteo.martincigh@arm.com>2019-01-24 14:06:23 +0000
committerMatteo Martincigh <matteo.martincigh@arm.com>2019-01-30 14:03:28 +0000
commitadddddb6cbcb777d92a8c464c9ad0cb9aecc76a3 (patch)
treeb15de32bf9f8612f66e1ae23d2f8009e80e7d0e6 /src/armnn/SubGraphSelector.cpp
parentd089b74bebbcc8518fb0f4eacb7e6569ae170199 (diff)
downloadarmnn-adddddb6cbcb777d92a8c464c9ad0cb9aecc76a3.tar.gz
IVGCVSW-2458 Refactor the Optimize function (Network.cpp) so that
subgraphs are optimized by the backends * Added a new method OptimizeSubGraph to the backend interface * Refactored the Optimize function so that the backend-specific optimization is performed by the backend itself (through the new OptimizeSubGraph interface method) * Added a new ApplyBackendOptimizations function to apply the new changes * Added some new convenient constructors to the SubGraph class * Added AddLayer method and a pointer to the parent graph to the SubGraph class * Updated the sub-graph unit tests to match the changes * Added SelectSubGraphs and ReplaceSubGraphConnections overloads that work with sub-graphs * Removed unused code and minor refactoring where necessary Change-Id: I46181794c6a9e3b10558944f804e06a8f693a6d0
Diffstat (limited to 'src/armnn/SubGraphSelector.cpp')
-rw-r--r--src/armnn/SubGraphSelector.cpp37
1 files changed, 21 insertions, 16 deletions
diff --git a/src/armnn/SubGraphSelector.cpp b/src/armnn/SubGraphSelector.cpp
index d0542fd41f..4abf01c88f 100644
--- a/src/armnn/SubGraphSelector.cpp
+++ b/src/armnn/SubGraphSelector.cpp
@@ -69,25 +69,25 @@ struct LayerSelectionInfo
return m_Layer->GetType() == armnn::LayerType::Input;
}
- void CollectNonSelectedInputs(SubGraph::InputSlots& slots,
+ void CollectNonSelectedInputs(SubGraph::InputSlots& inputSlots,
const SubGraphSelector::LayerSelectorFunction& selector)
{
for (auto&& slot = m_Layer->BeginInputSlots(); slot != m_Layer->EndInputSlots(); ++slot)
{
OutputSlot* parentLayerOutputSlot = slot->GetConnectedOutputSlot();
- BOOST_ASSERT_MSG(parentLayerOutputSlot != nullptr, "The slots must be connected here.");
+ BOOST_ASSERT_MSG(parentLayerOutputSlot != nullptr, "The input slots must be connected here.");
if (parentLayerOutputSlot)
{
Layer& parentLayer = parentLayerOutputSlot->GetOwningLayer();
if (selector(parentLayer) == false)
{
- slots.push_back(&(*slot));
+ inputSlots.push_back(&(*slot));
}
}
}
}
- void CollectNonSelectedOutputSlots(SubGraph::OutputSlots& slots,
+ void CollectNonSelectedOutputSlots(SubGraph::OutputSlots& outputSlots,
const SubGraphSelector::LayerSelectorFunction& selector)
{
for (auto&& slot = m_Layer->BeginOutputSlots(); slot != m_Layer->EndOutputSlots(); ++slot)
@@ -97,7 +97,7 @@ struct LayerSelectionInfo
Layer& childLayer = childLayerInputSlot->GetOwningLayer();
if (selector(childLayer) == false)
{
- slots.push_back(&(*slot));
+ outputSlots.push_back(&(*slot));
}
}
}
@@ -112,12 +112,18 @@ struct LayerSelectionInfo
} // namespace <anonymous>
SubGraphSelector::SubGraphs
-SubGraphSelector::SelectSubGraphs(Graph& graph,
- const LayerSelectorFunction& selector)
+SubGraphSelector::SelectSubGraphs(Graph& graph, const LayerSelectorFunction& selector)
+{
+ SubGraph subGraph(graph);
+ return SubGraphSelector::SelectSubGraphs(subGraph, selector);
+}
+
+SubGraphSelector::SubGraphs
+SubGraphSelector::SelectSubGraphs(SubGraph& subGraph, const LayerSelectorFunction& selector)
{
LayerSelectionInfo::LayerInfoContainer layerInfo;
- for (auto& layer : graph)
+ for (auto& layer : subGraph)
{
layerInfo.emplace(layer, LayerSelectionInfo{layer, selector});
}
@@ -127,7 +133,7 @@ SubGraphSelector::SelectSubGraphs(Graph& graph,
{
if (info.second.IsInputLayer())
{
- // for each input layer we mark the graph where subgraph
+ // For each input layer we mark the graph where subgraph
// splits need to happen because of the dependency between
// the selected and non-selected nodes
info.second.MarkChildrenSplits(layerInfo, splitNo, false);
@@ -159,20 +165,19 @@ SubGraphSelector::SelectSubGraphs(Graph& graph,
{
if (splitGraph.second.empty() == false)
{
- SubGraph::OutputSlots outputs;
SubGraph::InputSlots inputs;
+ SubGraph::OutputSlots outputs;
SubGraph::Layers layers;
for (auto&& infoPtr : splitGraph.second)
{
- infoPtr->CollectNonSelectedOutputSlots(outputs, selector);
infoPtr->CollectNonSelectedInputs(inputs, selector);
+ infoPtr->CollectNonSelectedOutputSlots(outputs, selector);
layers.push_back(infoPtr->m_Layer);
}
- result.emplace_back(
- std::make_unique<SubGraph>(
- std::move(inputs),
- std::move(outputs),
- std::move(layers)));
+ result.emplace_back(std::make_unique<SubGraph>(subGraph,
+ std::move(inputs),
+ std::move(outputs),
+ std::move(layers)));
}
}