aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/SubgraphViewSelector.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/SubgraphViewSelector.cpp')
-rw-r--r--src/armnn/SubgraphViewSelector.cpp32
1 files changed, 25 insertions, 7 deletions
diff --git a/src/armnn/SubgraphViewSelector.cpp b/src/armnn/SubgraphViewSelector.cpp
index cc821ec956..8e4de0b5f8 100644
--- a/src/armnn/SubgraphViewSelector.cpp
+++ b/src/armnn/SubgraphViewSelector.cpp
@@ -7,7 +7,7 @@
#include "Graph.hpp"
#include <boost/assert.hpp>
#include <algorithm>
-#include <unordered_map>
+#include <map>
#include <queue>
namespace armnn
@@ -19,7 +19,7 @@ namespace
struct LayerSelectionInfo
{
using SplitId = uint32_t;
- using LayerInfoContainer = std::unordered_map<Layer*, LayerSelectionInfo>;
+ using LayerInfoContainer = std::map<Layer*, LayerSelectionInfo>;
using LayerInfoQueue = std::queue<LayerSelectionInfo*>;
static constexpr uint32_t InitialSplitId() { return 1; }
@@ -56,7 +56,8 @@ struct LayerSelectionInfo
{
Layer& parentLayer = parentLayerOutputSlot->GetOwningLayer();
auto parentInfo = layerInfos.find(&parentLayer);
- if (m_SplitId != parentInfo->second.m_SplitId)
+ if (parentInfo == layerInfos.end() ||
+ m_SplitId != parentInfo->second.m_SplitId)
{
inputSlots.push_back(&(*slot));
}
@@ -73,7 +74,8 @@ struct LayerSelectionInfo
{
Layer& childLayer = childLayerInputSlot->GetOwningLayer();
auto childInfo = layerInfos.find(&childLayer);
- if (m_SplitId != childInfo->second.m_SplitId)
+ if (childInfo == layerInfos.end() ||
+ m_SplitId != childInfo->second.m_SplitId)
{
outputSlots.push_back(&(*slot));
}
@@ -112,7 +114,10 @@ void ForEachLayerInput(LayerSelectionInfo::LayerInfoContainer& layerInfos,
Layer& inputLayer = connectedInput->GetOwningLayer();
auto parentInfo = layerInfos.find(&inputLayer);
- function(parentInfo->second);
+ if (parentInfo != layerInfos.end())
+ {
+ function(parentInfo->second);
+ }
}
}
@@ -130,7 +135,10 @@ void ForEachLayerOutput(LayerSelectionInfo::LayerInfoContainer& layerInfos,
Layer& childLayer = output->GetOwningLayer();
auto childInfo = layerInfos.find(&childLayer);
- function(childInfo->second);
+ if (childInfo != layerInfos.end())
+ {
+ function(childInfo->second);
+ }
}
}
}
@@ -213,6 +221,16 @@ SubgraphViewSelector::SelectSubgraphs(SubgraphView& subgraph, const LayerSelecto
}
}
+ const SubgraphView::InputSlots& subgraphInputSlots = subgraph.GetInputSlots();
+ for (auto& inputSlot : subgraphInputSlots)
+ {
+ Layer& layer = inputSlot->GetOwningLayer();
+ auto emplaced = layerInfos.emplace(&layer, LayerSelectionInfo{&layer, selector});
+ LayerSelectionInfo& layerInfo = emplaced.first->second;
+
+ processQueue.push(&layerInfo);
+ }
+
while (!processQueue.empty())
{
LayerSelectionInfo& layerInfo = *processQueue.front();
@@ -246,7 +264,7 @@ SubgraphViewSelector::SelectSubgraphs(SubgraphView& subgraph, const LayerSelecto
// Collect all selected layers keyed by split id into a map
using SelectionInfoPtrs = std::vector<LayerSelectionInfo*>;
- std::unordered_map<uint32_t, SelectionInfoPtrs> splitMap;
+ std::map<uint32_t, SelectionInfoPtrs> splitMap;
for (auto& info : layerInfos)
{
if (info.second.m_IsSelected)