aboutsummaryrefslogtreecommitdiff
path: root/src/armnnDeserializer
diff options
context:
space:
mode:
authorNattapat Chaimanowong <nattapat.chaimanowong@arm.com>2019-03-04 17:10:40 +0000
committernattapat.chaimanowong <nattapat.chaimanowong@arm.com>2019-03-07 10:34:52 +0000
commitd469faf863f4ecd3ba56f27e51884ef0dfeac7bf (patch)
tree8492bba3bdaccf140ca8ac7e25039a6f110a4a13 /src/armnnDeserializer
parentac97c8cda28f81ce76834b8b769967d42b02e2ac (diff)
downloadarmnn-d469faf863f4ecd3ba56f27e51884ef0dfeac7bf.tar.gz
IVGCVSW-2783 Fix Deserializer connections for layer with multiple outputs
Change-Id: Icb278dfd8900334665432963fa6f6341a461ef3b Signed-off-by: Nattapat Chaimanowong <nattapat.chaimanowong@arm.com>
Diffstat (limited to 'src/armnnDeserializer')
-rw-r--r--src/armnnDeserializer/Deserializer.cpp59
-rw-r--r--src/armnnDeserializer/Deserializer.hpp21
2 files changed, 38 insertions, 42 deletions
diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp
index 96879bb65a..ed110ad750 100644
--- a/src/armnnDeserializer/Deserializer.cpp
+++ b/src/armnnDeserializer/Deserializer.cpp
@@ -599,16 +599,17 @@ INetworkPtr Deserializer::CreateNetworkFromGraph(GraphPtr graph)
SetupOutputLayers(graph);
// establish the connections from the layer outputs to the inputs of the subsequent layers
- for (size_t connectionIndex = 0; connectionIndex < m_GraphConnections[0].size(); ++connectionIndex)
+ for (size_t connectionsIndex = 0; connectionsIndex < m_GraphConnections[0].size(); ++connectionsIndex)
{
- if (m_GraphConnections[0][connectionIndex].outputSlot != nullptr)
+ SlotsMap& slotsMap = m_GraphConnections[0][connectionsIndex];
+ for (unsigned int outputSlotIndex = 0; outputSlotIndex < slotsMap.outputSlots.size(); outputSlotIndex++)
{
- for (size_t inputSlotIdx = 0;
- inputSlotIdx < m_GraphConnections[0][connectionIndex].inputSlots.size();
- ++inputSlotIdx)
+ if (slotsMap.inputSlots.find(outputSlotIndex) != slotsMap.inputSlots.end())
{
- m_GraphConnections[0][connectionIndex].outputSlot->Connect(
- *(m_GraphConnections[0][connectionIndex].inputSlots[inputSlotIdx]));
+ for (armnn::IInputSlot* inputSlot : slotsMap.inputSlots[outputSlotIndex])
+ {
+ slotsMap.outputSlots[outputSlotIndex]->Connect(*inputSlot);
+ }
}
}
}
@@ -743,39 +744,35 @@ void Deserializer::RegisterInputSlots(GraphPtr graph,
for (unsigned int slotIndex = 0; slotIndex < layer->GetNumInputSlots(); ++slotIndex)
{
+ auto fbConnection = parsedLayer->inputSlots()->Get(slotIndex)->connection();
armnn::IInputSlot* slot = &(layer->GetInputSlot(slotIndex));
- uint32_t sourceLayerIndex = parsedLayer->inputSlots()->Get(slotIndex)->connection()->sourceLayerIndex();
- RegisterInputSlotOfConnection(sourceLayerIndex, slot);
- }
-}
-
-void Deserializer::RegisterInputSlotOfConnection(uint32_t connectionIndex,
- armnn::IInputSlot* slot)
-{
- BOOST_ASSERT(m_GraphConnections[0].size() > connectionIndex);
- Slots& slots = m_GraphConnections[0][connectionIndex];
- slots.inputSlots.push_back(slot);
+ RegisterInputSlotOfConnection(fbConnection->sourceLayerIndex(), fbConnection->outputSlotIndex(), slot);
+ }
}
-void Deserializer::RegisterOutputSlotOfConnection(uint32_t connectionIndex,
- armnn::IOutputSlot* slot)
+void Deserializer::RegisterInputSlotOfConnection(uint32_t sourceLayerIndex,
+ uint32_t outputSlotIndex,
+ armnn::IInputSlot* slot)
{
- BOOST_ASSERT(m_GraphConnections[0].size() > connectionIndex);
-
- Slots& slots = m_GraphConnections[0][connectionIndex];
+ BOOST_ASSERT(m_GraphConnections[0].size() > sourceLayerIndex);
- // assuming there is only one producer for that tensor
- if (slots.outputSlot != nullptr)
+ SlotsMap& slotsMap = m_GraphConnections[0][sourceLayerIndex];
+ if (slotsMap.inputSlots.find(outputSlotIndex) == slotsMap.inputSlots.end())
{
- throw ParseException(boost::str(
- boost::format("Another layer has already registered itself as the producer of "
- "connection:%1% / %2%") %
- connectionIndex %
- CHECK_LOCATION().AsString()));
+ slotsMap.inputSlots[outputSlotIndex] = {slot};
+ }
+ else
+ {
+ slotsMap.inputSlots[outputSlotIndex].push_back(slot);
}
+}
- slots.outputSlot = slot;
+void Deserializer::RegisterOutputSlotOfConnection(uint32_t sourceLayerIndex,
+ armnn::IOutputSlot* slot)
+{
+ BOOST_ASSERT(m_GraphConnections[0].size() > sourceLayerIndex);
+ m_GraphConnections[0][sourceLayerIndex].outputSlots.push_back(slot);
}
void Deserializer::ParseActivation(GraphPtr graph, unsigned int layerIndex)
diff --git a/src/armnnDeserializer/Deserializer.hpp b/src/armnnDeserializer/Deserializer.hpp
index 4d9c13818b..e837a08aa3 100644
--- a/src/armnnDeserializer/Deserializer.hpp
+++ b/src/armnnDeserializer/Deserializer.hpp
@@ -9,6 +9,8 @@
#include "armnnDeserializer/IDeserializer.hpp"
#include <ArmnnSchema_generated.h>
+#include <unordered_map>
+
namespace armnnDeserializer
{
class Deserializer : public IDeserializer
@@ -100,8 +102,8 @@ private:
void ParseStridedSlice(GraphPtr graph, unsigned int layerIndex);
void ParseSubtraction(GraphPtr graph, unsigned int layerIndex);
- void RegisterOutputSlotOfConnection(uint32_t connectionIndex, armnn::IOutputSlot* slot);
- void RegisterInputSlotOfConnection(uint32_t connectionIndex, armnn::IInputSlot* slot);
+ void RegisterOutputSlotOfConnection(uint32_t sourceLayerIndex, armnn::IOutputSlot* slot);
+ void RegisterInputSlotOfConnection(uint32_t sourceLayerIndex, uint32_t outputSlotIndex, armnn::IInputSlot* slot);
void RegisterInputSlots(GraphPtr graph, uint32_t layerIndex,
armnn::IConnectableLayer* layer);
void RegisterOutputSlots(GraphPtr graph, uint32_t layerIndex,
@@ -120,17 +122,14 @@ private:
std::vector<NameToBindingInfo> m_OutputBindings;
/// A mapping of an output slot to each of the input slots it should be connected to
- /// The outputSlot is from the layer that creates this tensor as one of its outputs
- /// The inputSlots are from the layers that use this tensor as one of their inputs
- struct Slots
+ struct SlotsMap
{
- armnn::IOutputSlot* outputSlot;
- std::vector<armnn::IInputSlot*> inputSlots;
-
- Slots() : outputSlot(nullptr) { }
+ std::vector<armnn::IOutputSlot*> outputSlots;
+ std::unordered_map<unsigned int, std::vector<armnn::IInputSlot*>> inputSlots;
};
- typedef std::vector<Slots> Connection;
- std::vector<Connection> m_GraphConnections;
+
+ typedef std::vector<SlotsMap> Connections;
+ std::vector<Connections> m_GraphConnections;
};
} //namespace armnnDeserializer