diff options
Diffstat (limited to 'src/armnn/WorkingMemHandle.hpp')
-rw-r--r-- | src/armnn/WorkingMemHandle.hpp | 64 |
1 files changed, 44 insertions, 20 deletions
diff --git a/src/armnn/WorkingMemHandle.hpp b/src/armnn/WorkingMemHandle.hpp index 676d04288b..aaa9d593ee 100644 --- a/src/armnn/WorkingMemHandle.hpp +++ b/src/armnn/WorkingMemHandle.hpp @@ -26,18 +26,26 @@ class WorkingMemHandle final : public IWorkingMemHandle { public: - struct InputConnectionInfo + struct InputMemDescriptorCoords { LayerBindingId m_LayerBindingId; - unsigned int m_DescriptorIndex; - unsigned int m_InputIndex; + + std::vector<std::pair<unsigned int, unsigned int>> m_InputSlotCoords; + }; + + struct OutputMemDescriptorCoords + { + std::vector<LayerBindingId> m_LayerBindingIds; + + std::pair<unsigned int, unsigned int> m_OutputSlotCoords; + std::vector<std::pair<unsigned int, unsigned int>> m_InputSlotCoords; }; WorkingMemHandle(NetworkId networkId) : m_NetworkId(networkId){} WorkingMemHandle(NetworkId networkId, - std::vector<std::pair<LayerBindingId, LayerGuid>> inputHandles, - std::vector<InputConnectionInfo> inputConnections, + std::vector<InputMemDescriptorCoords> inputLayerInfo, + std::vector<OutputMemDescriptorCoords> ouputLayerInfo, std::vector<WorkingMemDescriptor> workingMemDescriptors, std::unordered_map<LayerGuid, WorkingMemDescriptor> workingMemDescriptorMap, std::vector<std::shared_ptr<IMemoryManager>> memoryManagers, @@ -52,25 +60,19 @@ public: } /// Allocate the backing memory required for execution. If this is not called, then allocation will be - /// deferred to execution time. The mutex must be locked. + /// deferred to execution time. void Allocate() override; - /// Free the backing memory required for execution. The mutex must be locked. + /// Free the backing memory required for execution. void Free() override; - /// IsAllocated returns true if the backing memory is currently allocated. The mutex must be locked. + /// IsAllocated returns true if the backing memory is currently allocated. bool IsAllocated() override { return m_IsAllocated; } - /// Get a mutex which can be used for synchronizing access to the WorkingMemHandle object. - std::mutex& GetMutex() override - { - return m_Mutex; - } - - /// Get the WorkingMemDescriptor for a Layer. The mutex must be locked. + /// Get the WorkingMemDescriptor for a Layer. WorkingMemDescriptor& GetWorkingMemDescriptor(LayerGuid id) override { auto result = m_WorkingMemDescriptorMap.find(id); @@ -79,7 +81,7 @@ public: } /// Get the WorkingMemDescriptor at an index. The WorkingMemDescriptors are stored in the same order as - /// the Workloads in a topologically sorted graph. The mutex must be locked. + /// the Workloads in a topologically sorted graph. WorkingMemDescriptor& GetWorkingMemDescriptorAt(unsigned int id) override { return m_WorkingMemDescriptors[id]; @@ -90,22 +92,39 @@ public: return m_InputHandleMap.at(layerBindingId); }; + ITensorHandle* GetOutputHandle(LayerBindingId layerBindingId) const + { + return m_OutputHandleMap.at(layerBindingId); + }; + const std::vector<std::vector<ITensorHandle*>::iterator>& GetInputConnections(LayerBindingId layerBindingId) const { return m_InputConnectionMap.at(layerBindingId); }; - std::unordered_map<LayerBindingId, bool> GetValidationMap() const + const std::vector<std::vector<ITensorHandle*>::iterator>& GetOutputConnection(LayerBindingId layerBindingId) const + { + return m_OutputConnectionMap.at(layerBindingId); + }; + + void MemSyncOutputs(); + + std::vector<LayerBindingId>& GetBindingIdVector() { - return m_ValidationMap; + return m_BindingIdVec; }; + void ValidateBindingIds(); + private: + using DifferenceType = std::vector<ITensorHandle*>::difference_type; NetworkId m_NetworkId; std::shared_ptr<ProfilerImpl> m_Profiler; std::unordered_map<LayerBindingId, ITensorHandle*> m_InputHandleMap; + std::unordered_map<LayerBindingId, ITensorHandle*> m_OutputHandleMap; std::unordered_map<LayerBindingId, std::vector<std::vector<ITensorHandle*>::iterator>> m_InputConnectionMap; + std::unordered_map<LayerBindingId, std::vector<std::vector<ITensorHandle*>::iterator>> m_OutputConnectionMap; std::vector<WorkingMemDescriptor> m_WorkingMemDescriptors; std::unordered_map<LayerGuid, WorkingMemDescriptor> m_WorkingMemDescriptorMap; @@ -116,9 +135,14 @@ private: // constant tensor's can be shared by multiple WorkingMemHandles and so will not be stored here std::unordered_map<LayerGuid, std::vector<std::unique_ptr<ITensorHandle> > > m_OwnedTensorHandles; - std::unordered_map<LayerBindingId, bool> m_ValidationMap; + std::unordered_map<LayerBindingId, bool> m_InputValidationMap; + std::unordered_map<LayerBindingId, bool> m_OutputValidationMap; + + std::vector<LayerBindingId> m_BindingIdVec; + + DifferenceType m_InputSize; + bool m_IsAllocated; - std::mutex m_Mutex; }; } // end experimental namespace |