aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/Graph.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/Graph.cpp')
-rw-r--r--src/armnn/Graph.cpp16
1 files changed, 14 insertions, 2 deletions
diff --git a/src/armnn/Graph.cpp b/src/armnn/Graph.cpp
index e521623737..9e00f5ec01 100644
--- a/src/armnn/Graph.cpp
+++ b/src/armnn/Graph.cpp
@@ -285,12 +285,13 @@ void Graph::AddCopyLayers(std::map<BackendId, std::unique_ptr<IBackendInternal>>
{
OutputSlot& srcOutputSlot = srcLayer->GetOutputSlot(srcOutputIndex);
const std::vector<InputSlot*> srcConnections = srcOutputSlot.GetConnections();
+ const std::vector<MemoryStrategy> srcMemoryStrategies = srcOutputSlot.GetMemoryStrategies();
for (unsigned int srcConnectionIndex = 0; srcConnectionIndex < srcConnections.size(); srcConnectionIndex++)
{
InputSlot* dstInputSlot = srcConnections[srcConnectionIndex];
BOOST_ASSERT(dstInputSlot);
- auto strategy = srcOutputSlot.GetMemoryStrategyForConnection(srcConnectionIndex);
+ MemoryStrategy strategy = srcMemoryStrategies[srcConnectionIndex];
BOOST_ASSERT_MSG(strategy != MemoryStrategy::Undefined,
"Undefined memory strategy found while adding copy layers for compatibility");
@@ -339,8 +340,19 @@ void Graph::AddCopyLayers(std::map<BackendId, std::unique_ptr<IBackendInternal>>
copyOutputSlot.SetTensorHandleFactory(ITensorHandleFactory::LegacyFactoryId);
}
+ // The output strategy of a copy layer is always DirectCompatibility.
copyOutputSlot.SetMemoryStrategy(0, MemoryStrategy::DirectCompatibility);
- srcOutputSlot.SetMemoryStrategy(srcConnectionIndex, MemoryStrategy::DirectCompatibility);
+
+ // Recalculate the connection index on the previous layer as we have just inserted into it.
+ const std::vector<InputSlot*>& newSourceConnections = srcOutputSlot.GetConnections();
+ long newSrcConnectionIndex = std::distance(newSourceConnections.begin(),
+ std::find(newSourceConnections.begin(),
+ newSourceConnections.end(),
+ &copyLayer->GetInputSlot(0)));
+
+ // The input strategy of a copy layer is always DirectCompatibilty.
+ srcOutputSlot.SetMemoryStrategy(boost::numeric_cast<unsigned int>(newSrcConnectionIndex),
+ MemoryStrategy::DirectCompatibility);
}
}
}