aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/armnn/LoadedNetwork.cpp19
1 files changed, 16 insertions, 3 deletions
diff --git a/src/armnn/LoadedNetwork.cpp b/src/armnn/LoadedNetwork.cpp
index 4688b6eea4..d6dd5d2ee8 100644
--- a/src/armnn/LoadedNetwork.cpp
+++ b/src/armnn/LoadedNetwork.cpp
@@ -1547,7 +1547,20 @@ std::unique_ptr<IWorkingMemHandle> LoadedNetwork::CreateWorkingMemHandle(Network
auto found = m_ConstantTensorHandles.find(key);
if (found != m_ConstantTensorHandles.end())
{
- workingMemDescriptor.m_Inputs.push_back(found->second);
+ ITensorHandle* tensorHandle = found->second;
+ workingMemDescriptor.m_Inputs.push_back(tensorHandle);
+
+ // Odd case where a constant layer is connected to an output layer
+ // We will need to create a HandleInfo to track it
+ if (isOutputLayer)
+ {
+ LayerBindingId bindingId = static_cast<BindableLayer*>(layer)->GetBindingId();
+
+ HandleInfo& handleInfo = handleReferenceCounts[tensorHandle];
+ handleInfo.isOutputLayerHandle = true;
+ handleInfo.m_OutputMemDescriptorCoords.m_LayerBindingIds.push_back(bindingId);
+ handleInfo.m_OutputMemDescriptorCoords.m_InputSlotCoords.push_back({layerIndex, 0});
+ }
continue;
}
@@ -1563,9 +1576,9 @@ std::unique_ptr<IWorkingMemHandle> LoadedNetwork::CreateWorkingMemHandle(Network
{
LayerBindingId bindingId = static_cast<BindableLayer*>(layer)->GetBindingId();
handleInfo.m_OutputMemDescriptorCoords.m_LayerBindingIds.push_back(bindingId);
- handleInfo.m_OutputMemDescriptorCoords.m_InputSlotCoords.push_back({layerIndex, slot.GetSlotIndex()});
+ handleInfo.m_OutputMemDescriptorCoords.m_InputSlotCoords.push_back({layerIndex, 0});
}
- // In this case the layer is not an Output Layer but shares it's input tensorhandle with an OutputLayer
+ // In this case the layer is not an Output Layer but shares its input tensorhandle with an OutputLayer
// It will need to be updated as well, if we swap out the tensorhandle
else if (handleInfo.isOutputLayerHandle)
{