diff options
Diffstat (limited to 'src/armnnDeserializer/Deserializer.hpp')
-rw-r--r-- | src/armnnDeserializer/Deserializer.hpp | 25 |
1 files changed, 16 insertions, 9 deletions
diff --git a/src/armnnDeserializer/Deserializer.hpp b/src/armnnDeserializer/Deserializer.hpp index c647ac3639..38a6b602fc 100644 --- a/src/armnnDeserializer/Deserializer.hpp +++ b/src/armnnDeserializer/Deserializer.hpp @@ -51,8 +51,6 @@ public: static GraphPtr LoadGraphFromBinary(const uint8_t* binaryContent, size_t len); static TensorRawPtrVector GetInputs(const GraphPtr& graph, unsigned int layerIndex); static TensorRawPtrVector GetOutputs(const GraphPtr& graph, unsigned int layerIndex); - static LayerBaseRawPtrVector GetGraphInputs(const GraphPtr& graphPtr); - static LayerBaseRawPtrVector GetGraphOutputs(const GraphPtr& graphPtr); static LayerBaseRawPtr GetBaseLayer(const GraphPtr& graphPtr, unsigned int layerIndex); static int32_t GetBindingLayerInfo(const GraphPtr& graphPtr, unsigned int layerIndex); static std::string GetLayerName(const GraphPtr& graph, unsigned int index); @@ -116,17 +114,23 @@ private: void ParseSubtraction(GraphPtr graph, unsigned int layerIndex); void ParseSwitch(GraphPtr graph, unsigned int layerIndex); - void RegisterOutputSlotOfConnection(uint32_t sourceLayerIndex, armnn::IOutputSlot* slot); - void RegisterInputSlotOfConnection(uint32_t sourceLayerIndex, uint32_t outputSlotIndex, armnn::IInputSlot* slot); void RegisterInputSlots(GraphPtr graph, uint32_t layerIndex, armnn::IConnectableLayer* layer); void RegisterOutputSlots(GraphPtr graph, uint32_t layerIndex, armnn::IConnectableLayer* layer); + + // NOTE index here must be from flatbuffer object index property + void RegisterOutputSlotOfConnection(uint32_t sourceLayerIndex, uint32_t outputSlotIndex, armnn::IOutputSlot* slot); + void RegisterInputSlotOfConnection(uint32_t sourceLayerIndex, uint32_t outputSlotIndex, armnn::IInputSlot* slot); + void ResetParser(); void SetupInputLayers(GraphPtr graphPtr); void SetupOutputLayers(GraphPtr graphPtr); + /// Helper to get the index of the layer in the flatbuffer vector from its index property + unsigned int GetLayerIndexInVector(GraphPtr graph, unsigned int index); + /// The network we're building. Gets cleared after it is passed to the user armnn::INetworkPtr m_Network; std::vector<LayerParsingFunction> m_ParserFunctions; @@ -135,15 +139,18 @@ private: std::vector<NameToBindingInfo> m_InputBindings; std::vector<NameToBindingInfo> m_OutputBindings; - /// A mapping of an output slot to each of the input slots it should be connected to - struct SlotsMap + /// This struct describe connections for each layer + struct Connections { - std::vector<armnn::IOutputSlot*> outputSlots; + // Maps output slot index (property in flatbuffer object) to IOutputSlot pointer + std::unordered_map<unsigned int, armnn::IOutputSlot*> outputSlots; + + // Maps output slot index to IInputSlot pointer the output slot should be connected to std::unordered_map<unsigned int, std::vector<armnn::IInputSlot*>> inputSlots; }; - typedef std::vector<SlotsMap> Connections; - std::vector<Connections> m_GraphConnections; + /// Maps layer index (index property in flatbuffer object) to Connections for each layer + std::unordered_map<unsigned int, Connections> m_GraphConnections; }; } //namespace armnnDeserializer |