diff options
Diffstat (limited to 'src/armnn/SubGraphSelector.cpp')
-rw-r--r-- | src/armnn/SubGraphSelector.cpp | 182 |
1 files changed, 182 insertions, 0 deletions
diff --git a/src/armnn/SubGraphSelector.cpp b/src/armnn/SubGraphSelector.cpp new file mode 100644 index 0000000000..b87e2b73b1 --- /dev/null +++ b/src/armnn/SubGraphSelector.cpp @@ -0,0 +1,182 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "SubGraphSelector.hpp" +#include "Graph.hpp" +#include <boost/assert.hpp> +#include <algorithm> +#include <unordered_map> + +namespace armnn +{ + +namespace +{ + +struct LayerSelectionInfo +{ + using LayerInfoContainer = std::unordered_map<Layer*, LayerSelectionInfo>; + static constexpr uint32_t InitialSplitId() { return 1; } + + LayerSelectionInfo(Layer* layer, const SubGraphSelector::LayerSelectorFunction& selector) + : m_Layer{layer} + , m_SplitId{0} + , m_IsSelected{selector(*layer)} + { + // fill topology information by storing direct children + for (auto&& slot = m_Layer->BeginOutputSlots(); slot != m_Layer->EndOutputSlots(); ++slot) + { + for (InputSlot* childLayerInputSlot : slot->GetConnections()) + { + Layer& childLayer = childLayerInputSlot->GetOwningLayer(); + m_DirectChildren.push_back(&childLayer); + } + } + } + + void MarkChildrenSplits(LayerInfoContainer& network, + uint32_t splitId, + bool prevSelected) + { + if (m_SplitId < splitId) + { + m_SplitId = splitId; + } + + // introduce a new split point at all non-selected points, but only if the + // previous point was selected. this prevents creating a new subgraph at + // every non-selected layer + if (!m_IsSelected && prevSelected) + { + ++m_SplitId; + } + + for (auto& layer : m_DirectChildren) + { + auto it = network.find(layer); + BOOST_ASSERT_MSG(it != network.end(), "All layers must be part of the topology."); + if (it != network.end()) + { + it->second.MarkChildrenSplits(network, m_SplitId, m_IsSelected); + } + } + } + + bool IsInputLayer() const + { + return m_Layer->GetType() == armnn::LayerType::Input; + } + + void CollectNonSelectedInputs(SubGraph::InputSlots& slots, + const SubGraphSelector::LayerSelectorFunction& selector) + { + for (auto&& slot = m_Layer->BeginInputSlots(); slot != m_Layer->EndInputSlots(); ++slot) + { + OutputSlot* parentLayerOutputSlot = slot->GetConnectedOutputSlot(); + BOOST_ASSERT_MSG(parentLayerOutputSlot != nullptr, "The slots must be connected here."); + if (parentLayerOutputSlot) + { + Layer& parentLayer = parentLayerOutputSlot->GetOwningLayer(); + if (selector(parentLayer) == false) + { + slots.push_back(&(*slot)); + } + } + } + } + + void CollectNonSelectedOutputSlots(SubGraph::OutputSlots& slots, + const SubGraphSelector::LayerSelectorFunction& selector) + { + for (auto&& slot = m_Layer->BeginOutputSlots(); slot != m_Layer->EndOutputSlots(); ++slot) + { + for (InputSlot* childLayerInputSlot : slot->GetConnections()) + { + Layer& childLayer = childLayerInputSlot->GetOwningLayer(); + if (selector(childLayer) == false) + { + slots.push_back(&(*slot)); + } + } + } + } + + std::vector<Layer*> m_DirectChildren; + Layer* m_Layer; + uint32_t m_SplitId; + bool m_IsSelected; +}; + +} // namespace <anonymous> + +SubGraphSelector::SubGraphs +SubGraphSelector::SelectSubGraphs(Graph& graph, + const LayerSelectorFunction& selector) +{ + LayerSelectionInfo::LayerInfoContainer layerInfo; + + for (auto& layer : graph) + { + layerInfo.emplace(layer, LayerSelectionInfo{layer, selector}); + } + + uint32_t splitNo = LayerSelectionInfo::InitialSplitId(); + for (auto& info : layerInfo) + { + if (info.second.IsInputLayer()) + { + // for each input layer we mark the graph where subgraph + // splits need to happen because of the dependency between + // the selected and non-selected nodes + info.second.MarkChildrenSplits(layerInfo, splitNo, false); + } + } + + // Collect all selected layers keyed by split id into a map + using SelectionInfoPtrs = std::vector<LayerSelectionInfo*>; + std::unordered_map<uint32_t, SelectionInfoPtrs> splitMap; + for (auto& info : layerInfo) + { + if (info.second.m_IsSelected) + { + auto it = splitMap.find(info.second.m_SplitId); + if (it == splitMap.end()) + { + splitMap.insert(std::make_pair(info.second.m_SplitId, SelectionInfoPtrs{&info.second})); + } + else + { + it->second.push_back(&info.second); + } + } + } + + // Now each non-empty split id represents a subgraph + SubGraphs result; + for (auto& splitGraph : splitMap) + { + if (splitGraph.second.empty() == false) + { + SubGraph::OutputSlots outputs; + SubGraph::InputSlots inputs; + SubGraph::Layers layers; + for (auto&& infoPtr : splitGraph.second) + { + infoPtr->CollectNonSelectedOutputSlots(outputs, selector); + infoPtr->CollectNonSelectedInputs(inputs, selector); + layers.insert(infoPtr->m_Layer); + } + result.emplace_back( + std::make_unique<SubGraph>( + std::move(inputs), + std::move(outputs), + std::move(layers))); + } + } + + return result; +} + +} // namespace armnn |