diff options
Diffstat (limited to 'src/armnn/Graph.cpp')
-rw-r--r-- | src/armnn/Graph.cpp | 16 |
1 files changed, 14 insertions, 2 deletions
diff --git a/src/armnn/Graph.cpp b/src/armnn/Graph.cpp index e521623737..9e00f5ec01 100644 --- a/src/armnn/Graph.cpp +++ b/src/armnn/Graph.cpp @@ -285,12 +285,13 @@ void Graph::AddCopyLayers(std::map<BackendId, std::unique_ptr<IBackendInternal>> { OutputSlot& srcOutputSlot = srcLayer->GetOutputSlot(srcOutputIndex); const std::vector<InputSlot*> srcConnections = srcOutputSlot.GetConnections(); + const std::vector<MemoryStrategy> srcMemoryStrategies = srcOutputSlot.GetMemoryStrategies(); for (unsigned int srcConnectionIndex = 0; srcConnectionIndex < srcConnections.size(); srcConnectionIndex++) { InputSlot* dstInputSlot = srcConnections[srcConnectionIndex]; BOOST_ASSERT(dstInputSlot); - auto strategy = srcOutputSlot.GetMemoryStrategyForConnection(srcConnectionIndex); + MemoryStrategy strategy = srcMemoryStrategies[srcConnectionIndex]; BOOST_ASSERT_MSG(strategy != MemoryStrategy::Undefined, "Undefined memory strategy found while adding copy layers for compatibility"); @@ -339,8 +340,19 @@ void Graph::AddCopyLayers(std::map<BackendId, std::unique_ptr<IBackendInternal>> copyOutputSlot.SetTensorHandleFactory(ITensorHandleFactory::LegacyFactoryId); } + // The output strategy of a copy layer is always DirectCompatibility. copyOutputSlot.SetMemoryStrategy(0, MemoryStrategy::DirectCompatibility); - srcOutputSlot.SetMemoryStrategy(srcConnectionIndex, MemoryStrategy::DirectCompatibility); + + // Recalculate the connection index on the previous layer as we have just inserted into it. + const std::vector<InputSlot*>& newSourceConnections = srcOutputSlot.GetConnections(); + long newSrcConnectionIndex = std::distance(newSourceConnections.begin(), + std::find(newSourceConnections.begin(), + newSourceConnections.end(), + ©Layer->GetInputSlot(0))); + + // The input strategy of a copy layer is always DirectCompatibilty. + srcOutputSlot.SetMemoryStrategy(boost::numeric_cast<unsigned int>(newSrcConnectionIndex), + MemoryStrategy::DirectCompatibility); } } } |