diff options
Diffstat (limited to 'src/armnn/WorkingMemHandle.hpp')
-rw-r--r-- | src/armnn/WorkingMemHandle.hpp | 33 |
1 files changed, 31 insertions, 2 deletions
diff --git a/src/armnn/WorkingMemHandle.hpp b/src/armnn/WorkingMemHandle.hpp index 5e3fd66299..676d04288b 100644 --- a/src/armnn/WorkingMemHandle.hpp +++ b/src/armnn/WorkingMemHandle.hpp @@ -21,11 +21,23 @@ namespace armnn namespace experimental { + class WorkingMemHandle final : public IWorkingMemHandle { public: + struct InputConnectionInfo + { + LayerBindingId m_LayerBindingId; + unsigned int m_DescriptorIndex; + unsigned int m_InputIndex; + }; + + WorkingMemHandle(NetworkId networkId) : m_NetworkId(networkId){} + WorkingMemHandle(NetworkId networkId, + std::vector<std::pair<LayerBindingId, LayerGuid>> inputHandles, + std::vector<InputConnectionInfo> inputConnections, std::vector<WorkingMemDescriptor> workingMemDescriptors, std::unordered_map<LayerGuid, WorkingMemDescriptor> workingMemDescriptorMap, std::vector<std::shared_ptr<IMemoryManager>> memoryManagers, @@ -39,8 +51,6 @@ public: return m_NetworkId; } - - /// Allocate the backing memory required for execution. If this is not called, then allocation will be /// deferred to execution time. The mutex must be locked. void Allocate() override; @@ -75,10 +85,28 @@ public: return m_WorkingMemDescriptors[id]; } + ITensorHandle* GetInputHandle(LayerBindingId layerBindingId) const + { + return m_InputHandleMap.at(layerBindingId); + }; + + const std::vector<std::vector<ITensorHandle*>::iterator>& GetInputConnections(LayerBindingId layerBindingId) const + { + return m_InputConnectionMap.at(layerBindingId); + }; + + std::unordered_map<LayerBindingId, bool> GetValidationMap() const + { + return m_ValidationMap; + }; + private: NetworkId m_NetworkId; std::shared_ptr<ProfilerImpl> m_Profiler; + std::unordered_map<LayerBindingId, ITensorHandle*> m_InputHandleMap; + std::unordered_map<LayerBindingId, std::vector<std::vector<ITensorHandle*>::iterator>> m_InputConnectionMap; + std::vector<WorkingMemDescriptor> m_WorkingMemDescriptors; std::unordered_map<LayerGuid, WorkingMemDescriptor> m_WorkingMemDescriptorMap; @@ -88,6 +116,7 @@ private: // constant tensor's can be shared by multiple WorkingMemHandles and so will not be stored here std::unordered_map<LayerGuid, std::vector<std::unique_ptr<ITensorHandle> > > m_OwnedTensorHandles; + std::unordered_map<LayerBindingId, bool> m_ValidationMap; bool m_IsAllocated; std::mutex m_Mutex; }; |