diff options
Diffstat (limited to 'src/armnnDeserializer/Deserializer.hpp')
-rw-r--r-- | src/armnnDeserializer/Deserializer.hpp | 21 |
1 files changed, 10 insertions, 11 deletions
diff --git a/src/armnnDeserializer/Deserializer.hpp b/src/armnnDeserializer/Deserializer.hpp index 4d9c13818b..e837a08aa3 100644 --- a/src/armnnDeserializer/Deserializer.hpp +++ b/src/armnnDeserializer/Deserializer.hpp @@ -9,6 +9,8 @@ #include "armnnDeserializer/IDeserializer.hpp" #include <ArmnnSchema_generated.h> +#include <unordered_map> + namespace armnnDeserializer { class Deserializer : public IDeserializer @@ -100,8 +102,8 @@ private: void ParseStridedSlice(GraphPtr graph, unsigned int layerIndex); void ParseSubtraction(GraphPtr graph, unsigned int layerIndex); - void RegisterOutputSlotOfConnection(uint32_t connectionIndex, armnn::IOutputSlot* slot); - void RegisterInputSlotOfConnection(uint32_t connectionIndex, armnn::IInputSlot* slot); + void RegisterOutputSlotOfConnection(uint32_t sourceLayerIndex, armnn::IOutputSlot* slot); + void RegisterInputSlotOfConnection(uint32_t sourceLayerIndex, uint32_t outputSlotIndex, armnn::IInputSlot* slot); void RegisterInputSlots(GraphPtr graph, uint32_t layerIndex, armnn::IConnectableLayer* layer); void RegisterOutputSlots(GraphPtr graph, uint32_t layerIndex, @@ -120,17 +122,14 @@ private: std::vector<NameToBindingInfo> m_OutputBindings; /// A mapping of an output slot to each of the input slots it should be connected to - /// The outputSlot is from the layer that creates this tensor as one of its outputs - /// The inputSlots are from the layers that use this tensor as one of their inputs - struct Slots + struct SlotsMap { - armnn::IOutputSlot* outputSlot; - std::vector<armnn::IInputSlot*> inputSlots; - - Slots() : outputSlot(nullptr) { } + std::vector<armnn::IOutputSlot*> outputSlots; + std::unordered_map<unsigned int, std::vector<armnn::IInputSlot*>> inputSlots; }; - typedef std::vector<Slots> Connection; - std::vector<Connection> m_GraphConnections; + + typedef std::vector<SlotsMap> Connections; + std::vector<Connections> m_GraphConnections; }; } //namespace armnnDeserializer |