aboutsummaryrefslogtreecommitdiff
path: root/src/armnnDeserializer/Deserializer.cpp
diff options
context:
space:
mode:
authorNattapat Chaimanowong <nattapat.chaimanowong@arm.com>2019-03-04 17:10:40 +0000
committernattapat.chaimanowong <nattapat.chaimanowong@arm.com>2019-03-07 10:34:52 +0000
commitd469faf863f4ecd3ba56f27e51884ef0dfeac7bf (patch)
tree8492bba3bdaccf140ca8ac7e25039a6f110a4a13 /src/armnnDeserializer/Deserializer.cpp
parentac97c8cda28f81ce76834b8b769967d42b02e2ac (diff)
downloadarmnn-d469faf863f4ecd3ba56f27e51884ef0dfeac7bf.tar.gz
IVGCVSW-2783 Fix Deserializer connections for layer with multiple outputs
Change-Id: Icb278dfd8900334665432963fa6f6341a461ef3b Signed-off-by: Nattapat Chaimanowong <nattapat.chaimanowong@arm.com>
Diffstat (limited to 'src/armnnDeserializer/Deserializer.cpp')
-rw-r--r--src/armnnDeserializer/Deserializer.cpp59
1 files changed, 28 insertions, 31 deletions
diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp
index 96879bb65a..ed110ad750 100644
--- a/src/armnnDeserializer/Deserializer.cpp
+++ b/src/armnnDeserializer/Deserializer.cpp
@@ -599,16 +599,17 @@ INetworkPtr Deserializer::CreateNetworkFromGraph(GraphPtr graph)
SetupOutputLayers(graph);
// establish the connections from the layer outputs to the inputs of the subsequent layers
- for (size_t connectionIndex = 0; connectionIndex < m_GraphConnections[0].size(); ++connectionIndex)
+ for (size_t connectionsIndex = 0; connectionsIndex < m_GraphConnections[0].size(); ++connectionsIndex)
{
- if (m_GraphConnections[0][connectionIndex].outputSlot != nullptr)
+ SlotsMap& slotsMap = m_GraphConnections[0][connectionsIndex];
+ for (unsigned int outputSlotIndex = 0; outputSlotIndex < slotsMap.outputSlots.size(); outputSlotIndex++)
{
- for (size_t inputSlotIdx = 0;
- inputSlotIdx < m_GraphConnections[0][connectionIndex].inputSlots.size();
- ++inputSlotIdx)
+ if (slotsMap.inputSlots.find(outputSlotIndex) != slotsMap.inputSlots.end())
{
- m_GraphConnections[0][connectionIndex].outputSlot->Connect(
- *(m_GraphConnections[0][connectionIndex].inputSlots[inputSlotIdx]));
+ for (armnn::IInputSlot* inputSlot : slotsMap.inputSlots[outputSlotIndex])
+ {
+ slotsMap.outputSlots[outputSlotIndex]->Connect(*inputSlot);
+ }
}
}
}
@@ -743,39 +744,35 @@ void Deserializer::RegisterInputSlots(GraphPtr graph,
for (unsigned int slotIndex = 0; slotIndex < layer->GetNumInputSlots(); ++slotIndex)
{
+ auto fbConnection = parsedLayer->inputSlots()->Get(slotIndex)->connection();
armnn::IInputSlot* slot = &(layer->GetInputSlot(slotIndex));
- uint32_t sourceLayerIndex = parsedLayer->inputSlots()->Get(slotIndex)->connection()->sourceLayerIndex();
- RegisterInputSlotOfConnection(sourceLayerIndex, slot);
- }
-}
-
-void Deserializer::RegisterInputSlotOfConnection(uint32_t connectionIndex,
- armnn::IInputSlot* slot)
-{
- BOOST_ASSERT(m_GraphConnections[0].size() > connectionIndex);
- Slots& slots = m_GraphConnections[0][connectionIndex];
- slots.inputSlots.push_back(slot);
+ RegisterInputSlotOfConnection(fbConnection->sourceLayerIndex(), fbConnection->outputSlotIndex(), slot);
+ }
}
-void Deserializer::RegisterOutputSlotOfConnection(uint32_t connectionIndex,
- armnn::IOutputSlot* slot)
+void Deserializer::RegisterInputSlotOfConnection(uint32_t sourceLayerIndex,
+ uint32_t outputSlotIndex,
+ armnn::IInputSlot* slot)
{
- BOOST_ASSERT(m_GraphConnections[0].size() > connectionIndex);
-
- Slots& slots = m_GraphConnections[0][connectionIndex];
+ BOOST_ASSERT(m_GraphConnections[0].size() > sourceLayerIndex);
- // assuming there is only one producer for that tensor
- if (slots.outputSlot != nullptr)
+ SlotsMap& slotsMap = m_GraphConnections[0][sourceLayerIndex];
+ if (slotsMap.inputSlots.find(outputSlotIndex) == slotsMap.inputSlots.end())
{
- throw ParseException(boost::str(
- boost::format("Another layer has already registered itself as the producer of "
- "connection:%1% / %2%") %
- connectionIndex %
- CHECK_LOCATION().AsString()));
+ slotsMap.inputSlots[outputSlotIndex] = {slot};
+ }
+ else
+ {
+ slotsMap.inputSlots[outputSlotIndex].push_back(slot);
}
+}
- slots.outputSlot = slot;
+void Deserializer::RegisterOutputSlotOfConnection(uint32_t sourceLayerIndex,
+ armnn::IOutputSlot* slot)
+{
+ BOOST_ASSERT(m_GraphConnections[0].size() > sourceLayerIndex);
+ m_GraphConnections[0][sourceLayerIndex].outputSlots.push_back(slot);
}
void Deserializer::ParseActivation(GraphPtr graph, unsigned int layerIndex)