diff options
Diffstat (limited to 'src/armnn/SubGraphSelector.cpp')
-rw-r--r-- | src/armnn/SubGraphSelector.cpp | 37 |
1 files changed, 21 insertions, 16 deletions
diff --git a/src/armnn/SubGraphSelector.cpp b/src/armnn/SubGraphSelector.cpp index d0542fd41f..4abf01c88f 100644 --- a/src/armnn/SubGraphSelector.cpp +++ b/src/armnn/SubGraphSelector.cpp @@ -69,25 +69,25 @@ struct LayerSelectionInfo return m_Layer->GetType() == armnn::LayerType::Input; } - void CollectNonSelectedInputs(SubGraph::InputSlots& slots, + void CollectNonSelectedInputs(SubGraph::InputSlots& inputSlots, const SubGraphSelector::LayerSelectorFunction& selector) { for (auto&& slot = m_Layer->BeginInputSlots(); slot != m_Layer->EndInputSlots(); ++slot) { OutputSlot* parentLayerOutputSlot = slot->GetConnectedOutputSlot(); - BOOST_ASSERT_MSG(parentLayerOutputSlot != nullptr, "The slots must be connected here."); + BOOST_ASSERT_MSG(parentLayerOutputSlot != nullptr, "The input slots must be connected here."); if (parentLayerOutputSlot) { Layer& parentLayer = parentLayerOutputSlot->GetOwningLayer(); if (selector(parentLayer) == false) { - slots.push_back(&(*slot)); + inputSlots.push_back(&(*slot)); } } } } - void CollectNonSelectedOutputSlots(SubGraph::OutputSlots& slots, + void CollectNonSelectedOutputSlots(SubGraph::OutputSlots& outputSlots, const SubGraphSelector::LayerSelectorFunction& selector) { for (auto&& slot = m_Layer->BeginOutputSlots(); slot != m_Layer->EndOutputSlots(); ++slot) @@ -97,7 +97,7 @@ struct LayerSelectionInfo Layer& childLayer = childLayerInputSlot->GetOwningLayer(); if (selector(childLayer) == false) { - slots.push_back(&(*slot)); + outputSlots.push_back(&(*slot)); } } } @@ -112,12 +112,18 @@ struct LayerSelectionInfo } // namespace <anonymous> SubGraphSelector::SubGraphs -SubGraphSelector::SelectSubGraphs(Graph& graph, - const LayerSelectorFunction& selector) +SubGraphSelector::SelectSubGraphs(Graph& graph, const LayerSelectorFunction& selector) +{ + SubGraph subGraph(graph); + return SubGraphSelector::SelectSubGraphs(subGraph, selector); +} + +SubGraphSelector::SubGraphs +SubGraphSelector::SelectSubGraphs(SubGraph& subGraph, const LayerSelectorFunction& selector) { LayerSelectionInfo::LayerInfoContainer layerInfo; - for (auto& layer : graph) + for (auto& layer : subGraph) { layerInfo.emplace(layer, LayerSelectionInfo{layer, selector}); } @@ -127,7 +133,7 @@ SubGraphSelector::SelectSubGraphs(Graph& graph, { if (info.second.IsInputLayer()) { - // for each input layer we mark the graph where subgraph + // For each input layer we mark the graph where subgraph // splits need to happen because of the dependency between // the selected and non-selected nodes info.second.MarkChildrenSplits(layerInfo, splitNo, false); @@ -159,20 +165,19 @@ SubGraphSelector::SelectSubGraphs(Graph& graph, { if (splitGraph.second.empty() == false) { - SubGraph::OutputSlots outputs; SubGraph::InputSlots inputs; + SubGraph::OutputSlots outputs; SubGraph::Layers layers; for (auto&& infoPtr : splitGraph.second) { - infoPtr->CollectNonSelectedOutputSlots(outputs, selector); infoPtr->CollectNonSelectedInputs(inputs, selector); + infoPtr->CollectNonSelectedOutputSlots(outputs, selector); layers.push_back(infoPtr->m_Layer); } - result.emplace_back( - std::make_unique<SubGraph>( - std::move(inputs), - std::move(outputs), - std::move(layers))); + result.emplace_back(std::make_unique<SubGraph>(subGraph, + std::move(inputs), + std::move(outputs), + std::move(layers))); } } |