aboutsummaryrefslogtreecommitdiff
path: root/src/armnnDeserializer/Deserializer.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnDeserializer/Deserializer.hpp')
-rw-r--r--src/armnnDeserializer/Deserializer.hpp21
1 files changed, 10 insertions, 11 deletions
diff --git a/src/armnnDeserializer/Deserializer.hpp b/src/armnnDeserializer/Deserializer.hpp
index 4d9c13818b..e837a08aa3 100644
--- a/src/armnnDeserializer/Deserializer.hpp
+++ b/src/armnnDeserializer/Deserializer.hpp
@@ -9,6 +9,8 @@
#include "armnnDeserializer/IDeserializer.hpp"
#include <ArmnnSchema_generated.h>
+#include <unordered_map>
+
namespace armnnDeserializer
{
class Deserializer : public IDeserializer
@@ -100,8 +102,8 @@ private:
void ParseStridedSlice(GraphPtr graph, unsigned int layerIndex);
void ParseSubtraction(GraphPtr graph, unsigned int layerIndex);
- void RegisterOutputSlotOfConnection(uint32_t connectionIndex, armnn::IOutputSlot* slot);
- void RegisterInputSlotOfConnection(uint32_t connectionIndex, armnn::IInputSlot* slot);
+ void RegisterOutputSlotOfConnection(uint32_t sourceLayerIndex, armnn::IOutputSlot* slot);
+ void RegisterInputSlotOfConnection(uint32_t sourceLayerIndex, uint32_t outputSlotIndex, armnn::IInputSlot* slot);
void RegisterInputSlots(GraphPtr graph, uint32_t layerIndex,
armnn::IConnectableLayer* layer);
void RegisterOutputSlots(GraphPtr graph, uint32_t layerIndex,
@@ -120,17 +122,14 @@ private:
std::vector<NameToBindingInfo> m_OutputBindings;
/// A mapping of an output slot to each of the input slots it should be connected to
- /// The outputSlot is from the layer that creates this tensor as one of its outputs
- /// The inputSlots are from the layers that use this tensor as one of their inputs
- struct Slots
+ struct SlotsMap
{
- armnn::IOutputSlot* outputSlot;
- std::vector<armnn::IInputSlot*> inputSlots;
-
- Slots() : outputSlot(nullptr) { }
+ std::vector<armnn::IOutputSlot*> outputSlots;
+ std::unordered_map<unsigned int, std::vector<armnn::IInputSlot*>> inputSlots;
};
- typedef std::vector<Slots> Connection;
- std::vector<Connection> m_GraphConnections;
+
+ typedef std::vector<SlotsMap> Connections;
+ std::vector<Connections> m_GraphConnections;
};
} //namespace armnnDeserializer