diff options
Diffstat (limited to 'src/armnn/SubgraphViewSelector.cpp')
-rw-r--r-- | src/armnn/SubgraphViewSelector.cpp | 62 |
1 files changed, 36 insertions, 26 deletions
diff --git a/src/armnn/SubgraphViewSelector.cpp b/src/armnn/SubgraphViewSelector.cpp index 21fbb7cd80..e2c5f911a0 100644 --- a/src/armnn/SubgraphViewSelector.cpp +++ b/src/armnn/SubgraphViewSelector.cpp @@ -176,7 +176,7 @@ private: /// Intermediate data structure to store information associated with a particular layer. struct LayerSelectionInfo { - using LayerInfoContainer = std::map<Layer*, LayerSelectionInfo>; + using LayerInfoContainer = std::map<IConnectableLayer*, LayerSelectionInfo>; using LayerInfoQueue = std::queue<LayerSelectionInfo*>; LayerSelectionInfo(Layer* layer, const SubgraphViewSelector::LayerSelectorFunction& selector) @@ -193,9 +193,11 @@ struct LayerSelectionInfo } void CollectNonSelectedInputs(LayerSelectionInfo::LayerInfoContainer& layerInfos, - SubgraphView::InputSlots& inputSlots) + SubgraphView::IInputSlots& inputSlots) { - for (auto&& slot = m_Layer->BeginInputSlots(); slot != m_Layer->EndInputSlots(); ++slot) + for (auto&& slot = PolymorphicDowncast<Layer*>(m_Layer)->BeginInputSlots(); + slot != PolymorphicDowncast<Layer*>(m_Layer)->EndInputSlots(); + ++slot) { OutputSlot* parentLayerOutputSlot = slot->GetConnectedOutputSlot(); ARMNN_ASSERT_MSG(parentLayerOutputSlot != nullptr, "The input slots must be connected here."); @@ -218,9 +220,11 @@ struct LayerSelectionInfo } void CollectNonSelectedOutputSlots(LayerSelectionInfo::LayerInfoContainer& layerInfos, - SubgraphView::OutputSlots& outputSlots) + SubgraphView::IOutputSlots& outputSlots) { - for (auto&& slot = m_Layer->BeginOutputSlots(); slot != m_Layer->EndOutputSlots(); ++slot) + for (auto&& slot = PolymorphicDowncast<Layer*>(m_Layer)->BeginOutputSlots(); + slot != PolymorphicDowncast<Layer*>(m_Layer)->EndOutputSlots(); + ++slot) { for (InputSlot* childLayerInputSlot : slot->GetConnections()) { @@ -240,7 +244,7 @@ struct LayerSelectionInfo } } - Layer* m_Layer; + IConnectableLayer* m_Layer; /// Which subgraph this layer has been assigned to. Only valid once m_IsProcessed is true. /// Two layers with different m_Subgraph pointers may in fact have been merged into the same subgraph - /// see the description of the PartialSubgraph class. @@ -264,7 +268,7 @@ void ForEachLayerInput(LayerSelectionInfo::LayerInfoContainer& layerInfos, LayerSelectionInfo& layerInfo, Delegate function) { - Layer& layer = *layerInfo.m_Layer; + Layer& layer = *PolymorphicDowncast<Layer*>(layerInfo.m_Layer); for (auto inputSlot : layer.GetInputSlots()) { @@ -285,7 +289,7 @@ void ForEachLayerOutput(LayerSelectionInfo::LayerInfoContainer& layerInfos, LayerSelectionInfo& layerInfo, Delegate function) { - Layer& layer= *layerInfo.m_Layer; + Layer& layer = *PolymorphicDowncast<Layer*>(layerInfo.m_Layer); for (auto& outputSlot : layer.GetOutputSlots()) { @@ -387,9 +391,11 @@ SubgraphViewSelector::SelectSubgraphs(SubgraphView& subgraph, const LayerSelecto LayerSelectionInfo::LayerInfoContainer layerInfos; LayerSelectionInfo::LayerInfoQueue processQueue; - for (auto& layer : subgraph) + const SubgraphView::IConnectableLayers& subgraphLayers = subgraph.GetIConnectableLayers(); + for (auto& layer : subgraphLayers) { - auto emplaced = layerInfos.emplace(layer, LayerSelectionInfo{layer, selector}); + + auto emplaced = layerInfos.emplace(layer, LayerSelectionInfo{PolymorphicDowncast<Layer*>(layer), selector}); LayerSelectionInfo& layerInfo = emplaced.first->second; // Start with Input type layers @@ -399,10 +405,10 @@ SubgraphViewSelector::SelectSubgraphs(SubgraphView& subgraph, const LayerSelecto } } - const SubgraphView::InputSlots& subgraphInputSlots = subgraph.GetInputSlots(); + const SubgraphView::IInputSlots& subgraphInputSlots = subgraph.GetIInputSlots(); for (auto& inputSlot : subgraphInputSlots) { - Layer& layer = inputSlot->GetOwningLayer(); + Layer& layer = PolymorphicDowncast<InputSlot*>(inputSlot)->GetOwningLayer(); auto emplaced = layerInfos.emplace(&layer, LayerSelectionInfo{&layer, selector}); LayerSelectionInfo& layerInfo = emplaced.first->second; @@ -463,9 +469,9 @@ SubgraphViewSelector::SelectSubgraphs(SubgraphView& subgraph, const LayerSelecto Subgraphs result; for (auto& splitGraph : splitMap) { - SubgraphView::InputSlots inputs; - SubgraphView::OutputSlots outputs; - SubgraphView::Layers layers; + SubgraphView::IInputSlots inputs; + SubgraphView::IOutputSlots outputs; + SubgraphView::IConnectableLayers layers; for (auto&& infoPtr : splitGraph.second) { infoPtr->CollectNonSelectedInputs(layerInfos, inputs); @@ -475,24 +481,28 @@ SubgraphViewSelector::SelectSubgraphs(SubgraphView& subgraph, const LayerSelecto // Sort lists into deterministic order, not relying on pointer values which may be different on each execution. // This makes debugging the optimised graph much easier as subsequent stages can also be deterministic. - std::sort(inputs.begin(), inputs.end(), [](const InputSlot* a, const InputSlot* b) + std::sort(inputs.begin(), inputs.end(), [](const IInputSlot* a, const IInputSlot* b) { - const LayerGuid guidA = a->GetOwningLayer().GetGuid(); - const LayerGuid guidB = b->GetOwningLayer().GetGuid(); + auto* castA = PolymorphicDowncast<const InputSlot*>(a); + auto* castB = PolymorphicDowncast<const InputSlot*>(b); + const LayerGuid guidA = castA->GetOwningLayer().GetGuid(); + const LayerGuid guidB = castB->GetOwningLayer().GetGuid(); if (guidA < guidB) { return true; } else if (guidA == guidB) { - return (a->GetSlotIndex() < b->GetSlotIndex()); + return (castA->GetSlotIndex() < castB->GetSlotIndex()); } return false; }); - std::sort(outputs.begin(), outputs.end(), [](const OutputSlot* a, const OutputSlot* b) + std::sort(outputs.begin(), outputs.end(), [](const IOutputSlot* a, const IOutputSlot* b) { - const LayerGuid guidA = a->GetOwningLayer().GetGuid(); - const LayerGuid guidB = b->GetOwningLayer().GetGuid(); + auto* castA = PolymorphicDowncast<const OutputSlot*>(a); + auto* castB = PolymorphicDowncast<const OutputSlot*>(b); + const LayerGuid guidA = castA->GetOwningLayer().GetGuid(); + const LayerGuid guidB = castB->GetOwningLayer().GetGuid(); if (guidA < guidB) { return true; @@ -503,12 +513,12 @@ SubgraphViewSelector::SelectSubgraphs(SubgraphView& subgraph, const LayerSelecto } return false; }); - layers.sort([](const Layer* a, const Layer* b) { return a->GetGuid() < b->GetGuid(); }); + layers.sort([](const IConnectableLayer* a, const IConnectableLayer* b) { return a->GetGuid() < b->GetGuid(); }); // Create a new sub-graph with the new lists of input/output slots and layer - result.emplace_back(std::make_unique<SubgraphView>(std::move(inputs), - std::move(outputs), - std::move(layers))); + result.emplace_back(std::make_unique<SubgraphView>(std::move(layers), + std::move(inputs), + std::move(outputs))); } return result; |