diff options
Diffstat (limited to 'src/armnn/SubgraphView.cpp')
-rw-r--r-- | src/armnn/SubgraphView.cpp | 91 |
1 files changed, 34 insertions, 57 deletions
diff --git a/src/armnn/SubgraphView.cpp b/src/armnn/SubgraphView.cpp index 5f972a9767..804ff731fb 100644 --- a/src/armnn/SubgraphView.cpp +++ b/src/armnn/SubgraphView.cpp @@ -416,7 +416,7 @@ public: }; -SubgraphView SubgraphView::GetWorkingCopy() +SubgraphView SubgraphView::GetWorkingCopy() const { if (p_WorkingCopyImpl) { @@ -426,79 +426,63 @@ SubgraphView SubgraphView::GetWorkingCopy() // Create a cut down SubgraphView with underlying graph containing only the relevant layers. // It needs its own underlying layers so that they can be replaced safely. - Graph newGraph = Graph(); + auto ptr = std::make_shared<SubgraphViewWorkingCopy>(Graph()); + std::unordered_map<const IConnectableLayer*, IConnectableLayer*> originalToClonedLayerMap; std::list<armnn::IConnectableLayer*> originalSubgraphLayers = GetIConnectableLayers(); - auto ptr = std::make_shared<SubgraphViewWorkingCopy>(std::move(newGraph)); - SubgraphView::IInputSlots workingCopyInputs; - for (auto&& originalLayer : originalSubgraphLayers) { Layer* const layer = PolymorphicDowncast<const Layer*>(originalLayer)->Clone(ptr->m_Graph); originalToClonedLayerMap.emplace(originalLayer, layer); } + SubgraphView::IInputSlots workingCopyInputs; // Add IInputSlots to workingCopy - std::vector<const IConnectableLayer*> processed; for (auto originalSubgraphInputSlot : GetIInputSlots()) { const IConnectableLayer& originalSubgraphLayer = PolymorphicDowncast<InputSlot*>(originalSubgraphInputSlot)->GetOwningLayer(); - // Only need process Slots of layer once - if (std::find(processed.begin(), processed.end(), &originalSubgraphLayer) == processed.end()) - { - IConnectableLayer* clonedLayer = originalToClonedLayerMap[&originalSubgraphLayer]; + auto* clonedLayer = originalToClonedLayerMap[&originalSubgraphLayer]; - // Add the InputSlot to WorkingCopy InputSlots - for (unsigned int i = 0; i < clonedLayer->GetNumInputSlots(); i++) - { - workingCopyInputs.push_back(&clonedLayer->GetInputSlot(i)); - } - processed.push_back(&originalSubgraphLayer); - } + workingCopyInputs.push_back(&clonedLayer->GetInputSlot(originalSubgraphInputSlot->GetSlotIndex())); } - // Empty processed - processed.clear(); for (auto originalSubgraphLayer : originalSubgraphLayers) { IConnectableLayer* const clonedLayer = originalToClonedLayerMap[originalSubgraphLayer]; - // connect all cloned layers as per original subgraph - for (unsigned int i = 0; i < clonedLayer->GetNumOutputSlots(); i++) + // OutputLayers have no OutputSlots to be connected + if (clonedLayer->GetType() != LayerType::Output) { - // OutputLayers have no OutputSlots to be connected - if (clonedLayer->GetType() != LayerType::Output) + // connect all cloned layers as per original subgraph + for (unsigned int i = 0; i < clonedLayer->GetNumOutputSlots(); i++) { - auto& outputSlot = clonedLayer->GetOutputSlot(i); - for (unsigned int k = 0; k < originalSubgraphLayer->GetNumOutputSlots(); k++) + auto& originalOutputSlot = originalSubgraphLayer->GetOutputSlot(i); + auto& clonedOutputSlot = clonedLayer->GetOutputSlot(i); + for (unsigned int j = 0; j < originalOutputSlot.GetNumConnections(); j++) { - auto& originalOutputSlot = originalSubgraphLayer->GetOutputSlot(k); - for (unsigned int j = 0; j < originalOutputSlot.GetNumConnections(); j++) + // nextLayer is the layer with IInputSlot connected to IOutputSlot we are working on + const IConnectableLayer& nextLayerOnOriginalSubgraph = + originalOutputSlot.GetConnection(j)->GetOwningIConnectableLayer(); + + // Check the layer is in our map and so has a clonedLayer + if (originalToClonedLayerMap.find(&nextLayerOnOriginalSubgraph) != originalToClonedLayerMap.end()) { - // nextLayer is the layer with IInputSlot connected to IOutputSlot we are working on - const IConnectableLayer& nextLayer = - originalOutputSlot.GetConnection(j)->GetOwningIConnectableLayer(); - - // Check the layer is in our map and so has a clonedLayer - if (originalToClonedLayerMap.find(&nextLayer) != originalToClonedLayerMap.end()) - { - IConnectableLayer* newGraphTargetLayer = originalToClonedLayerMap[&nextLayer]; - - IInputSlot& inputSlot = - newGraphTargetLayer->GetInputSlot( - PolymorphicDowncast<OutputSlot*>( - &originalOutputSlot)->GetConnection(j)->GetSlotIndex()); - - // Then make the connection - outputSlot.Connect(inputSlot); - } + auto* nextLayerOnClonedSubgraph = originalToClonedLayerMap[&nextLayerOnOriginalSubgraph]; + + auto index = PolymorphicDowncast<OutputSlot*>( + &originalOutputSlot)->GetConnection(j)->GetSlotIndex(); + + IInputSlot& inputSlot = nextLayerOnClonedSubgraph->GetInputSlot(index); + + // Then make the connection + clonedOutputSlot.Connect(inputSlot); } - // Copy the tensorInfo to the clonedOutputSlot - outputSlot.SetTensorInfo(originalOutputSlot.GetTensorInfo()); } + // Copy the tensorInfo to the clonedOutputSlot + clonedOutputSlot.SetTensorInfo(originalOutputSlot.GetTensorInfo()); } } } @@ -508,25 +492,18 @@ SubgraphView SubgraphView::GetWorkingCopy() // Add IOutputSlots to workingCopy for (auto outputSlot : GetIOutputSlots()) { - + auto outputSlotIndex = outputSlot->CalculateIndexOnOwner(); const IConnectableLayer& originalSubgraphLayer = outputSlot->GetOwningIConnectableLayer(); // OutputLayers have no OutputSlots to be connected - // Only need process Slots of layer once - if (originalSubgraphLayer.GetType() != LayerType::Output && - std::find(processed.begin(), processed.end(), &originalSubgraphLayer) == processed.end()) + if (originalSubgraphLayer.GetType() != LayerType::Output) { IConnectableLayer* clonedLayer = originalToClonedLayerMap[&originalSubgraphLayer]; - // Add the OutputSlot to WorkingCopy InputSlots - for (unsigned int i = 0; i < clonedLayer->GetNumOutputSlots(); i++) - { - workingCopyOutputs.push_back(&clonedLayer->GetOutputSlot(i)); - } - processed.push_back(&originalSubgraphLayer); + // Add the OutputSlot of clonedLayer to WorkingCopy OutputSlots + workingCopyOutputs.push_back(&clonedLayer->GetOutputSlot(outputSlotIndex)); } } - processed.clear(); SubgraphView::IConnectableLayers workingCopyLayers; for (auto& pair : originalToClonedLayerMap) |