aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/WorkingMemHandle.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/WorkingMemHandle.hpp')
-rw-r--r--src/armnn/WorkingMemHandle.hpp33
1 files changed, 31 insertions, 2 deletions
diff --git a/src/armnn/WorkingMemHandle.hpp b/src/armnn/WorkingMemHandle.hpp
index 5e3fd66299..676d04288b 100644
--- a/src/armnn/WorkingMemHandle.hpp
+++ b/src/armnn/WorkingMemHandle.hpp
@@ -21,11 +21,23 @@ namespace armnn
namespace experimental
{
+
class WorkingMemHandle final : public IWorkingMemHandle
{
public:
+ struct InputConnectionInfo
+ {
+ LayerBindingId m_LayerBindingId;
+ unsigned int m_DescriptorIndex;
+ unsigned int m_InputIndex;
+ };
+
+ WorkingMemHandle(NetworkId networkId) : m_NetworkId(networkId){}
+
WorkingMemHandle(NetworkId networkId,
+ std::vector<std::pair<LayerBindingId, LayerGuid>> inputHandles,
+ std::vector<InputConnectionInfo> inputConnections,
std::vector<WorkingMemDescriptor> workingMemDescriptors,
std::unordered_map<LayerGuid, WorkingMemDescriptor> workingMemDescriptorMap,
std::vector<std::shared_ptr<IMemoryManager>> memoryManagers,
@@ -39,8 +51,6 @@ public:
return m_NetworkId;
}
-
-
/// Allocate the backing memory required for execution. If this is not called, then allocation will be
/// deferred to execution time. The mutex must be locked.
void Allocate() override;
@@ -75,10 +85,28 @@ public:
return m_WorkingMemDescriptors[id];
}
+ ITensorHandle* GetInputHandle(LayerBindingId layerBindingId) const
+ {
+ return m_InputHandleMap.at(layerBindingId);
+ };
+
+ const std::vector<std::vector<ITensorHandle*>::iterator>& GetInputConnections(LayerBindingId layerBindingId) const
+ {
+ return m_InputConnectionMap.at(layerBindingId);
+ };
+
+ std::unordered_map<LayerBindingId, bool> GetValidationMap() const
+ {
+ return m_ValidationMap;
+ };
+
private:
NetworkId m_NetworkId;
std::shared_ptr<ProfilerImpl> m_Profiler;
+ std::unordered_map<LayerBindingId, ITensorHandle*> m_InputHandleMap;
+ std::unordered_map<LayerBindingId, std::vector<std::vector<ITensorHandle*>::iterator>> m_InputConnectionMap;
+
std::vector<WorkingMemDescriptor> m_WorkingMemDescriptors;
std::unordered_map<LayerGuid, WorkingMemDescriptor> m_WorkingMemDescriptorMap;
@@ -88,6 +116,7 @@ private:
// constant tensor's can be shared by multiple WorkingMemHandles and so will not be stored here
std::unordered_map<LayerGuid, std::vector<std::unique_ptr<ITensorHandle> > > m_OwnedTensorHandles;
+ std::unordered_map<LayerBindingId, bool> m_ValidationMap;
bool m_IsAllocated;
std::mutex m_Mutex;
};