aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/SubgraphViewSelector.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/SubgraphViewSelector.cpp')
-rw-r--r--src/armnn/SubgraphViewSelector.cpp62
1 files changed, 36 insertions, 26 deletions
diff --git a/src/armnn/SubgraphViewSelector.cpp b/src/armnn/SubgraphViewSelector.cpp
index 21fbb7cd80..e2c5f911a0 100644
--- a/src/armnn/SubgraphViewSelector.cpp
+++ b/src/armnn/SubgraphViewSelector.cpp
@@ -176,7 +176,7 @@ private:
/// Intermediate data structure to store information associated with a particular layer.
struct LayerSelectionInfo
{
- using LayerInfoContainer = std::map<Layer*, LayerSelectionInfo>;
+ using LayerInfoContainer = std::map<IConnectableLayer*, LayerSelectionInfo>;
using LayerInfoQueue = std::queue<LayerSelectionInfo*>;
LayerSelectionInfo(Layer* layer, const SubgraphViewSelector::LayerSelectorFunction& selector)
@@ -193,9 +193,11 @@ struct LayerSelectionInfo
}
void CollectNonSelectedInputs(LayerSelectionInfo::LayerInfoContainer& layerInfos,
- SubgraphView::InputSlots& inputSlots)
+ SubgraphView::IInputSlots& inputSlots)
{
- for (auto&& slot = m_Layer->BeginInputSlots(); slot != m_Layer->EndInputSlots(); ++slot)
+ for (auto&& slot = PolymorphicDowncast<Layer*>(m_Layer)->BeginInputSlots();
+ slot != PolymorphicDowncast<Layer*>(m_Layer)->EndInputSlots();
+ ++slot)
{
OutputSlot* parentLayerOutputSlot = slot->GetConnectedOutputSlot();
ARMNN_ASSERT_MSG(parentLayerOutputSlot != nullptr, "The input slots must be connected here.");
@@ -218,9 +220,11 @@ struct LayerSelectionInfo
}
void CollectNonSelectedOutputSlots(LayerSelectionInfo::LayerInfoContainer& layerInfos,
- SubgraphView::OutputSlots& outputSlots)
+ SubgraphView::IOutputSlots& outputSlots)
{
- for (auto&& slot = m_Layer->BeginOutputSlots(); slot != m_Layer->EndOutputSlots(); ++slot)
+ for (auto&& slot = PolymorphicDowncast<Layer*>(m_Layer)->BeginOutputSlots();
+ slot != PolymorphicDowncast<Layer*>(m_Layer)->EndOutputSlots();
+ ++slot)
{
for (InputSlot* childLayerInputSlot : slot->GetConnections())
{
@@ -240,7 +244,7 @@ struct LayerSelectionInfo
}
}
- Layer* m_Layer;
+ IConnectableLayer* m_Layer;
/// Which subgraph this layer has been assigned to. Only valid once m_IsProcessed is true.
/// Two layers with different m_Subgraph pointers may in fact have been merged into the same subgraph -
/// see the description of the PartialSubgraph class.
@@ -264,7 +268,7 @@ void ForEachLayerInput(LayerSelectionInfo::LayerInfoContainer& layerInfos,
LayerSelectionInfo& layerInfo,
Delegate function)
{
- Layer& layer = *layerInfo.m_Layer;
+ Layer& layer = *PolymorphicDowncast<Layer*>(layerInfo.m_Layer);
for (auto inputSlot : layer.GetInputSlots())
{
@@ -285,7 +289,7 @@ void ForEachLayerOutput(LayerSelectionInfo::LayerInfoContainer& layerInfos,
LayerSelectionInfo& layerInfo,
Delegate function)
{
- Layer& layer= *layerInfo.m_Layer;
+ Layer& layer = *PolymorphicDowncast<Layer*>(layerInfo.m_Layer);
for (auto& outputSlot : layer.GetOutputSlots())
{
@@ -387,9 +391,11 @@ SubgraphViewSelector::SelectSubgraphs(SubgraphView& subgraph, const LayerSelecto
LayerSelectionInfo::LayerInfoContainer layerInfos;
LayerSelectionInfo::LayerInfoQueue processQueue;
- for (auto& layer : subgraph)
+ const SubgraphView::IConnectableLayers& subgraphLayers = subgraph.GetIConnectableLayers();
+ for (auto& layer : subgraphLayers)
{
- auto emplaced = layerInfos.emplace(layer, LayerSelectionInfo{layer, selector});
+
+ auto emplaced = layerInfos.emplace(layer, LayerSelectionInfo{PolymorphicDowncast<Layer*>(layer), selector});
LayerSelectionInfo& layerInfo = emplaced.first->second;
// Start with Input type layers
@@ -399,10 +405,10 @@ SubgraphViewSelector::SelectSubgraphs(SubgraphView& subgraph, const LayerSelecto
}
}
- const SubgraphView::InputSlots& subgraphInputSlots = subgraph.GetInputSlots();
+ const SubgraphView::IInputSlots& subgraphInputSlots = subgraph.GetIInputSlots();
for (auto& inputSlot : subgraphInputSlots)
{
- Layer& layer = inputSlot->GetOwningLayer();
+ Layer& layer = PolymorphicDowncast<InputSlot*>(inputSlot)->GetOwningLayer();
auto emplaced = layerInfos.emplace(&layer, LayerSelectionInfo{&layer, selector});
LayerSelectionInfo& layerInfo = emplaced.first->second;
@@ -463,9 +469,9 @@ SubgraphViewSelector::SelectSubgraphs(SubgraphView& subgraph, const LayerSelecto
Subgraphs result;
for (auto& splitGraph : splitMap)
{
- SubgraphView::InputSlots inputs;
- SubgraphView::OutputSlots outputs;
- SubgraphView::Layers layers;
+ SubgraphView::IInputSlots inputs;
+ SubgraphView::IOutputSlots outputs;
+ SubgraphView::IConnectableLayers layers;
for (auto&& infoPtr : splitGraph.second)
{
infoPtr->CollectNonSelectedInputs(layerInfos, inputs);
@@ -475,24 +481,28 @@ SubgraphViewSelector::SelectSubgraphs(SubgraphView& subgraph, const LayerSelecto
// Sort lists into deterministic order, not relying on pointer values which may be different on each execution.
// This makes debugging the optimised graph much easier as subsequent stages can also be deterministic.
- std::sort(inputs.begin(), inputs.end(), [](const InputSlot* a, const InputSlot* b)
+ std::sort(inputs.begin(), inputs.end(), [](const IInputSlot* a, const IInputSlot* b)
{
- const LayerGuid guidA = a->GetOwningLayer().GetGuid();
- const LayerGuid guidB = b->GetOwningLayer().GetGuid();
+ auto* castA = PolymorphicDowncast<const InputSlot*>(a);
+ auto* castB = PolymorphicDowncast<const InputSlot*>(b);
+ const LayerGuid guidA = castA->GetOwningLayer().GetGuid();
+ const LayerGuid guidB = castB->GetOwningLayer().GetGuid();
if (guidA < guidB)
{
return true;
}
else if (guidA == guidB)
{
- return (a->GetSlotIndex() < b->GetSlotIndex());
+ return (castA->GetSlotIndex() < castB->GetSlotIndex());
}
return false;
});
- std::sort(outputs.begin(), outputs.end(), [](const OutputSlot* a, const OutputSlot* b)
+ std::sort(outputs.begin(), outputs.end(), [](const IOutputSlot* a, const IOutputSlot* b)
{
- const LayerGuid guidA = a->GetOwningLayer().GetGuid();
- const LayerGuid guidB = b->GetOwningLayer().GetGuid();
+ auto* castA = PolymorphicDowncast<const OutputSlot*>(a);
+ auto* castB = PolymorphicDowncast<const OutputSlot*>(b);
+ const LayerGuid guidA = castA->GetOwningLayer().GetGuid();
+ const LayerGuid guidB = castB->GetOwningLayer().GetGuid();
if (guidA < guidB)
{
return true;
@@ -503,12 +513,12 @@ SubgraphViewSelector::SelectSubgraphs(SubgraphView& subgraph, const LayerSelecto
}
return false;
});
- layers.sort([](const Layer* a, const Layer* b) { return a->GetGuid() < b->GetGuid(); });
+ layers.sort([](const IConnectableLayer* a, const IConnectableLayer* b) { return a->GetGuid() < b->GetGuid(); });
// Create a new sub-graph with the new lists of input/output slots and layer
- result.emplace_back(std::make_unique<SubgraphView>(std::move(inputs),
- std::move(outputs),
- std::move(layers)));
+ result.emplace_back(std::make_unique<SubgraphView>(std::move(layers),
+ std::move(inputs),
+ std::move(outputs)));
}
return result;