diff options
Diffstat (limited to 'src/armnn/Graph.cpp')
-rw-r--r-- | src/armnn/Graph.cpp | 36 |
1 files changed, 23 insertions, 13 deletions
diff --git a/src/armnn/Graph.cpp b/src/armnn/Graph.cpp index 6d24e50bdc..cdb323432c 100644 --- a/src/armnn/Graph.cpp +++ b/src/armnn/Graph.cpp @@ -445,10 +445,13 @@ void Graph::SubstituteSubgraph(SubgraphView& subgraph, IConnectableLayer* substi void Graph::SubstituteSubgraph(SubgraphView& subgraph, const SubgraphView& substituteSubgraph) { // Look through each layer in the new subgraph and add any that are not already a member of this graph - substituteSubgraph.ForEachLayer([this](Layer* layer) + substituteSubgraph.ForEachIConnectableLayer([this](IConnectableLayer* iConnectableLayer) { - if (std::find(std::begin(m_Layers), std::end(m_Layers), layer) == std::end(m_Layers)) + if (std::find(std::begin(m_Layers), + std::end(m_Layers), + iConnectableLayer) == std::end(m_Layers)) { + auto layer = PolymorphicDowncast<Layer*>(iConnectableLayer); layer->Reparent(*this, m_Layers.end()); m_LayersInOrder = false; } @@ -461,24 +464,26 @@ void Graph::SubstituteSubgraph(SubgraphView& subgraph, const SubgraphView& subst void Graph::ReplaceSubgraphConnections(const SubgraphView& subgraph, const SubgraphView& substituteSubgraph) { - ARMNN_ASSERT_MSG(!substituteSubgraph.GetLayers().empty(), "New sub-graph used for substitution must not be empty"); + ARMNN_ASSERT_MSG(!substituteSubgraph.GetIConnectableLayers().empty(), + "New sub-graph used for substitution must not be empty"); - const SubgraphView::Layers& substituteSubgraphLayers = substituteSubgraph.GetLayers(); - std::for_each(substituteSubgraphLayers.begin(), substituteSubgraphLayers.end(), [&](Layer* layer) + const SubgraphView::IConnectableLayers& substituteSubgraphLayers = substituteSubgraph.GetIConnectableLayers(); + std::for_each(substituteSubgraphLayers.begin(), substituteSubgraphLayers.end(), [&](IConnectableLayer* layer) { IgnoreUnused(layer); + layer = PolymorphicDowncast<Layer*>(layer); ARMNN_ASSERT_MSG(std::find(m_Layers.begin(), m_Layers.end(), layer) != m_Layers.end(), "Substitute layer is not a member of graph"); }); - const SubgraphView::InputSlots& subgraphInputSlots = subgraph.GetInputSlots(); - const SubgraphView::OutputSlots& subgraphOutputSlots = subgraph.GetOutputSlots(); + const SubgraphView::IInputSlots& subgraphInputSlots = subgraph.GetIInputSlots(); + const SubgraphView::IOutputSlots& subgraphOutputSlots = subgraph.GetIOutputSlots(); unsigned int subgraphNumInputSlots = armnn::numeric_cast<unsigned int>(subgraphInputSlots.size()); unsigned int subgraphNumOutputSlots = armnn::numeric_cast<unsigned int>(subgraphOutputSlots.size()); - const SubgraphView::InputSlots& substituteSubgraphInputSlots = substituteSubgraph.GetInputSlots(); - const SubgraphView::OutputSlots& substituteSubgraphOutputSlots = substituteSubgraph.GetOutputSlots(); + const SubgraphView::IInputSlots& substituteSubgraphInputSlots = substituteSubgraph.GetIInputSlots(); + const SubgraphView::IOutputSlots& substituteSubgraphOutputSlots = substituteSubgraph.GetIOutputSlots(); ARMNN_ASSERT(subgraphNumInputSlots == substituteSubgraphInputSlots.size()); ARMNN_ASSERT(subgraphNumOutputSlots == substituteSubgraphOutputSlots.size()); @@ -488,7 +493,7 @@ void Graph::ReplaceSubgraphConnections(const SubgraphView& subgraph, const Subgr // Step 1: process input slots for (unsigned int inputSlotIdx = 0; inputSlotIdx < subgraphNumInputSlots; ++inputSlotIdx) { - InputSlot* subgraphInputSlot = subgraphInputSlots.at(inputSlotIdx); + IInputSlot* subgraphInputSlot = subgraphInputSlots.at(inputSlotIdx); ARMNN_ASSERT(subgraphInputSlot); IOutputSlot* connectedOutputSlot = subgraphInputSlot->GetConnection(); @@ -503,19 +508,24 @@ void Graph::ReplaceSubgraphConnections(const SubgraphView& subgraph, const Subgr // Step 2: process output slots for(unsigned int outputSlotIdx = 0; outputSlotIdx < subgraphNumOutputSlots; ++outputSlotIdx) { - OutputSlot* subgraphOutputSlot = subgraphOutputSlots.at(outputSlotIdx); + auto subgraphOutputSlot = + PolymorphicDowncast<OutputSlot*>(subgraphOutputSlots.at(outputSlotIdx)); ARMNN_ASSERT(subgraphOutputSlot); - OutputSlot* substituteOutputSlot = substituteSubgraphOutputSlots.at(outputSlotIdx); + auto substituteOutputSlot = + PolymorphicDowncast<OutputSlot*>(substituteSubgraphOutputSlots.at(outputSlotIdx)); ARMNN_ASSERT(substituteOutputSlot); + subgraphOutputSlot->MoveAllConnections(*substituteOutputSlot); } } void Graph::EraseSubgraphLayers(SubgraphView &subgraph) { - for (auto layer : subgraph.GetLayers()) + + for (auto iConnectableLayer : subgraph.GetIConnectableLayers()) { + auto layer = PolymorphicDowncast<Layer*>(iConnectableLayer); EraseLayer(layer); } subgraph.Clear(); |