diff options
-rw-r--r-- | src/armnnDeserializer/Deserializer.cpp | 59 | ||||
-rw-r--r-- | src/armnnDeserializer/Deserializer.hpp | 21 |
2 files changed, 38 insertions, 42 deletions
diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp index 96879bb65a..ed110ad750 100644 --- a/src/armnnDeserializer/Deserializer.cpp +++ b/src/armnnDeserializer/Deserializer.cpp @@ -599,16 +599,17 @@ INetworkPtr Deserializer::CreateNetworkFromGraph(GraphPtr graph) SetupOutputLayers(graph); // establish the connections from the layer outputs to the inputs of the subsequent layers - for (size_t connectionIndex = 0; connectionIndex < m_GraphConnections[0].size(); ++connectionIndex) + for (size_t connectionsIndex = 0; connectionsIndex < m_GraphConnections[0].size(); ++connectionsIndex) { - if (m_GraphConnections[0][connectionIndex].outputSlot != nullptr) + SlotsMap& slotsMap = m_GraphConnections[0][connectionsIndex]; + for (unsigned int outputSlotIndex = 0; outputSlotIndex < slotsMap.outputSlots.size(); outputSlotIndex++) { - for (size_t inputSlotIdx = 0; - inputSlotIdx < m_GraphConnections[0][connectionIndex].inputSlots.size(); - ++inputSlotIdx) + if (slotsMap.inputSlots.find(outputSlotIndex) != slotsMap.inputSlots.end()) { - m_GraphConnections[0][connectionIndex].outputSlot->Connect( - *(m_GraphConnections[0][connectionIndex].inputSlots[inputSlotIdx])); + for (armnn::IInputSlot* inputSlot : slotsMap.inputSlots[outputSlotIndex]) + { + slotsMap.outputSlots[outputSlotIndex]->Connect(*inputSlot); + } } } } @@ -743,39 +744,35 @@ void Deserializer::RegisterInputSlots(GraphPtr graph, for (unsigned int slotIndex = 0; slotIndex < layer->GetNumInputSlots(); ++slotIndex) { + auto fbConnection = parsedLayer->inputSlots()->Get(slotIndex)->connection(); armnn::IInputSlot* slot = &(layer->GetInputSlot(slotIndex)); - uint32_t sourceLayerIndex = parsedLayer->inputSlots()->Get(slotIndex)->connection()->sourceLayerIndex(); - RegisterInputSlotOfConnection(sourceLayerIndex, slot); - } -} - -void Deserializer::RegisterInputSlotOfConnection(uint32_t connectionIndex, - armnn::IInputSlot* slot) -{ - BOOST_ASSERT(m_GraphConnections[0].size() > connectionIndex); - Slots& slots = m_GraphConnections[0][connectionIndex]; - slots.inputSlots.push_back(slot); + RegisterInputSlotOfConnection(fbConnection->sourceLayerIndex(), fbConnection->outputSlotIndex(), slot); + } } -void Deserializer::RegisterOutputSlotOfConnection(uint32_t connectionIndex, - armnn::IOutputSlot* slot) +void Deserializer::RegisterInputSlotOfConnection(uint32_t sourceLayerIndex, + uint32_t outputSlotIndex, + armnn::IInputSlot* slot) { - BOOST_ASSERT(m_GraphConnections[0].size() > connectionIndex); - - Slots& slots = m_GraphConnections[0][connectionIndex]; + BOOST_ASSERT(m_GraphConnections[0].size() > sourceLayerIndex); - // assuming there is only one producer for that tensor - if (slots.outputSlot != nullptr) + SlotsMap& slotsMap = m_GraphConnections[0][sourceLayerIndex]; + if (slotsMap.inputSlots.find(outputSlotIndex) == slotsMap.inputSlots.end()) { - throw ParseException(boost::str( - boost::format("Another layer has already registered itself as the producer of " - "connection:%1% / %2%") % - connectionIndex % - CHECK_LOCATION().AsString())); + slotsMap.inputSlots[outputSlotIndex] = {slot}; + } + else + { + slotsMap.inputSlots[outputSlotIndex].push_back(slot); } +} - slots.outputSlot = slot; +void Deserializer::RegisterOutputSlotOfConnection(uint32_t sourceLayerIndex, + armnn::IOutputSlot* slot) +{ + BOOST_ASSERT(m_GraphConnections[0].size() > sourceLayerIndex); + m_GraphConnections[0][sourceLayerIndex].outputSlots.push_back(slot); } void Deserializer::ParseActivation(GraphPtr graph, unsigned int layerIndex) diff --git a/src/armnnDeserializer/Deserializer.hpp b/src/armnnDeserializer/Deserializer.hpp index 4d9c13818b..e837a08aa3 100644 --- a/src/armnnDeserializer/Deserializer.hpp +++ b/src/armnnDeserializer/Deserializer.hpp @@ -9,6 +9,8 @@ #include "armnnDeserializer/IDeserializer.hpp" #include <ArmnnSchema_generated.h> +#include <unordered_map> + namespace armnnDeserializer { class Deserializer : public IDeserializer @@ -100,8 +102,8 @@ private: void ParseStridedSlice(GraphPtr graph, unsigned int layerIndex); void ParseSubtraction(GraphPtr graph, unsigned int layerIndex); - void RegisterOutputSlotOfConnection(uint32_t connectionIndex, armnn::IOutputSlot* slot); - void RegisterInputSlotOfConnection(uint32_t connectionIndex, armnn::IInputSlot* slot); + void RegisterOutputSlotOfConnection(uint32_t sourceLayerIndex, armnn::IOutputSlot* slot); + void RegisterInputSlotOfConnection(uint32_t sourceLayerIndex, uint32_t outputSlotIndex, armnn::IInputSlot* slot); void RegisterInputSlots(GraphPtr graph, uint32_t layerIndex, armnn::IConnectableLayer* layer); void RegisterOutputSlots(GraphPtr graph, uint32_t layerIndex, @@ -120,17 +122,14 @@ private: std::vector<NameToBindingInfo> m_OutputBindings; /// A mapping of an output slot to each of the input slots it should be connected to - /// The outputSlot is from the layer that creates this tensor as one of its outputs - /// The inputSlots are from the layers that use this tensor as one of their inputs - struct Slots + struct SlotsMap { - armnn::IOutputSlot* outputSlot; - std::vector<armnn::IInputSlot*> inputSlots; - - Slots() : outputSlot(nullptr) { } + std::vector<armnn::IOutputSlot*> outputSlots; + std::unordered_map<unsigned int, std::vector<armnn::IInputSlot*>> inputSlots; }; - typedef std::vector<Slots> Connection; - std::vector<Connection> m_GraphConnections; + + typedef std::vector<SlotsMap> Connections; + std::vector<Connections> m_GraphConnections; }; } //namespace armnnDeserializer |