aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/Graph.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/Graph.cpp')
-rw-r--r--src/armnn/Graph.cpp36
1 files changed, 23 insertions, 13 deletions
diff --git a/src/armnn/Graph.cpp b/src/armnn/Graph.cpp
index 6d24e50bdc..cdb323432c 100644
--- a/src/armnn/Graph.cpp
+++ b/src/armnn/Graph.cpp
@@ -445,10 +445,13 @@ void Graph::SubstituteSubgraph(SubgraphView& subgraph, IConnectableLayer* substi
void Graph::SubstituteSubgraph(SubgraphView& subgraph, const SubgraphView& substituteSubgraph)
{
// Look through each layer in the new subgraph and add any that are not already a member of this graph
- substituteSubgraph.ForEachLayer([this](Layer* layer)
+ substituteSubgraph.ForEachIConnectableLayer([this](IConnectableLayer* iConnectableLayer)
{
- if (std::find(std::begin(m_Layers), std::end(m_Layers), layer) == std::end(m_Layers))
+ if (std::find(std::begin(m_Layers),
+ std::end(m_Layers),
+ iConnectableLayer) == std::end(m_Layers))
{
+ auto layer = PolymorphicDowncast<Layer*>(iConnectableLayer);
layer->Reparent(*this, m_Layers.end());
m_LayersInOrder = false;
}
@@ -461,24 +464,26 @@ void Graph::SubstituteSubgraph(SubgraphView& subgraph, const SubgraphView& subst
void Graph::ReplaceSubgraphConnections(const SubgraphView& subgraph, const SubgraphView& substituteSubgraph)
{
- ARMNN_ASSERT_MSG(!substituteSubgraph.GetLayers().empty(), "New sub-graph used for substitution must not be empty");
+ ARMNN_ASSERT_MSG(!substituteSubgraph.GetIConnectableLayers().empty(),
+ "New sub-graph used for substitution must not be empty");
- const SubgraphView::Layers& substituteSubgraphLayers = substituteSubgraph.GetLayers();
- std::for_each(substituteSubgraphLayers.begin(), substituteSubgraphLayers.end(), [&](Layer* layer)
+ const SubgraphView::IConnectableLayers& substituteSubgraphLayers = substituteSubgraph.GetIConnectableLayers();
+ std::for_each(substituteSubgraphLayers.begin(), substituteSubgraphLayers.end(), [&](IConnectableLayer* layer)
{
IgnoreUnused(layer);
+ layer = PolymorphicDowncast<Layer*>(layer);
ARMNN_ASSERT_MSG(std::find(m_Layers.begin(), m_Layers.end(), layer) != m_Layers.end(),
"Substitute layer is not a member of graph");
});
- const SubgraphView::InputSlots& subgraphInputSlots = subgraph.GetInputSlots();
- const SubgraphView::OutputSlots& subgraphOutputSlots = subgraph.GetOutputSlots();
+ const SubgraphView::IInputSlots& subgraphInputSlots = subgraph.GetIInputSlots();
+ const SubgraphView::IOutputSlots& subgraphOutputSlots = subgraph.GetIOutputSlots();
unsigned int subgraphNumInputSlots = armnn::numeric_cast<unsigned int>(subgraphInputSlots.size());
unsigned int subgraphNumOutputSlots = armnn::numeric_cast<unsigned int>(subgraphOutputSlots.size());
- const SubgraphView::InputSlots& substituteSubgraphInputSlots = substituteSubgraph.GetInputSlots();
- const SubgraphView::OutputSlots& substituteSubgraphOutputSlots = substituteSubgraph.GetOutputSlots();
+ const SubgraphView::IInputSlots& substituteSubgraphInputSlots = substituteSubgraph.GetIInputSlots();
+ const SubgraphView::IOutputSlots& substituteSubgraphOutputSlots = substituteSubgraph.GetIOutputSlots();
ARMNN_ASSERT(subgraphNumInputSlots == substituteSubgraphInputSlots.size());
ARMNN_ASSERT(subgraphNumOutputSlots == substituteSubgraphOutputSlots.size());
@@ -488,7 +493,7 @@ void Graph::ReplaceSubgraphConnections(const SubgraphView& subgraph, const Subgr
// Step 1: process input slots
for (unsigned int inputSlotIdx = 0; inputSlotIdx < subgraphNumInputSlots; ++inputSlotIdx)
{
- InputSlot* subgraphInputSlot = subgraphInputSlots.at(inputSlotIdx);
+ IInputSlot* subgraphInputSlot = subgraphInputSlots.at(inputSlotIdx);
ARMNN_ASSERT(subgraphInputSlot);
IOutputSlot* connectedOutputSlot = subgraphInputSlot->GetConnection();
@@ -503,19 +508,24 @@ void Graph::ReplaceSubgraphConnections(const SubgraphView& subgraph, const Subgr
// Step 2: process output slots
for(unsigned int outputSlotIdx = 0; outputSlotIdx < subgraphNumOutputSlots; ++outputSlotIdx)
{
- OutputSlot* subgraphOutputSlot = subgraphOutputSlots.at(outputSlotIdx);
+ auto subgraphOutputSlot =
+ PolymorphicDowncast<OutputSlot*>(subgraphOutputSlots.at(outputSlotIdx));
ARMNN_ASSERT(subgraphOutputSlot);
- OutputSlot* substituteOutputSlot = substituteSubgraphOutputSlots.at(outputSlotIdx);
+ auto substituteOutputSlot =
+ PolymorphicDowncast<OutputSlot*>(substituteSubgraphOutputSlots.at(outputSlotIdx));
ARMNN_ASSERT(substituteOutputSlot);
+
subgraphOutputSlot->MoveAllConnections(*substituteOutputSlot);
}
}
void Graph::EraseSubgraphLayers(SubgraphView &subgraph)
{
- for (auto layer : subgraph.GetLayers())
+
+ for (auto iConnectableLayer : subgraph.GetIConnectableLayers())
{
+ auto layer = PolymorphicDowncast<Layer*>(iConnectableLayer);
EraseLayer(layer);
}
subgraph.Clear();