From f37b970ff96b98310309e78aeea8a2e9df27b15a Mon Sep 17 00:00:00 2001 From: Finn Williams Date: Wed, 1 Sep 2021 18:06:04 +0100 Subject: IVGCVSW-6312 Support pre-importing inputs Signed-off-by: Finn Williams Change-Id: Ifc5e6f2e36767cb2a5cbf281d40ec9989b581abc --- src/armnn/WorkingMemHandle.hpp | 33 +++++++++++++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) (limited to 'src/armnn/WorkingMemHandle.hpp') diff --git a/src/armnn/WorkingMemHandle.hpp b/src/armnn/WorkingMemHandle.hpp index 5e3fd66299..676d04288b 100644 --- a/src/armnn/WorkingMemHandle.hpp +++ b/src/armnn/WorkingMemHandle.hpp @@ -21,11 +21,23 @@ namespace armnn namespace experimental { + class WorkingMemHandle final : public IWorkingMemHandle { public: + struct InputConnectionInfo + { + LayerBindingId m_LayerBindingId; + unsigned int m_DescriptorIndex; + unsigned int m_InputIndex; + }; + + WorkingMemHandle(NetworkId networkId) : m_NetworkId(networkId){} + WorkingMemHandle(NetworkId networkId, + std::vector> inputHandles, + std::vector inputConnections, std::vector workingMemDescriptors, std::unordered_map workingMemDescriptorMap, std::vector> memoryManagers, @@ -39,8 +51,6 @@ public: return m_NetworkId; } - - /// Allocate the backing memory required for execution. If this is not called, then allocation will be /// deferred to execution time. The mutex must be locked. void Allocate() override; @@ -75,10 +85,28 @@ public: return m_WorkingMemDescriptors[id]; } + ITensorHandle* GetInputHandle(LayerBindingId layerBindingId) const + { + return m_InputHandleMap.at(layerBindingId); + }; + + const std::vector::iterator>& GetInputConnections(LayerBindingId layerBindingId) const + { + return m_InputConnectionMap.at(layerBindingId); + }; + + std::unordered_map GetValidationMap() const + { + return m_ValidationMap; + }; + private: NetworkId m_NetworkId; std::shared_ptr m_Profiler; + std::unordered_map m_InputHandleMap; + std::unordered_map::iterator>> m_InputConnectionMap; + std::vector m_WorkingMemDescriptors; std::unordered_map m_WorkingMemDescriptorMap; @@ -88,6 +116,7 @@ private: // constant tensor's can be shared by multiple WorkingMemHandles and so will not be stored here std::unordered_map > > m_OwnedTensorHandles; + std::unordered_map m_ValidationMap; bool m_IsAllocated; std::mutex m_Mutex; }; -- cgit v1.2.1