diff options
Diffstat (limited to 'src/armnn/LoadedNetwork.hpp')
-rw-r--r-- | src/armnn/LoadedNetwork.hpp | 23 |
1 files changed, 21 insertions, 2 deletions
diff --git a/src/armnn/LoadedNetwork.hpp b/src/armnn/LoadedNetwork.hpp index 360ad91170..e713be215a 100644 --- a/src/armnn/LoadedNetwork.hpp +++ b/src/armnn/LoadedNetwork.hpp @@ -49,13 +49,16 @@ public: TensorInfo GetInputTensorInfo(LayerBindingId layerId) const; TensorInfo GetOutputTensorInfo(LayerBindingId layerId) const; + std::vector<ImportedInputId> ImportInputs(const InputTensors& inputTensors); + /// Single thread execution of the loaded network Status EnqueueWorkload(const InputTensors& inputTensors, const OutputTensors& outputTensors); /// Thread safe execution of the loaded network Status Execute(const InputTensors& inputTensors, const OutputTensors& outputTensors, - IWorkingMemHandle& workingMemHandle); + IWorkingMemHandle& workingMemHandle, + std::vector<ImportedInputId> preImportedInputs = {}); static std::unique_ptr<LoadedNetwork> MakeLoadedNetwork(std::unique_ptr<IOptimizedNetwork> net, std::string& errorMessage, @@ -100,7 +103,7 @@ private: void EnqueueOutput(const BindableLayer& layer, ITensorHandle* tensorHandle, const TensorInfo& tensorInfo); - void EnqueueInput(const BindableLayer& layer, const ConstTensor& inputTensor, WorkingMemHandle& handle); + void EnqueueInput(const ConstTensor& inputTensor, ITensorHandle* inputTensorHandle); void EnqueueOutput(const BindableLayer& layer, const Tensor& outputTensor, WorkingMemHandle& handle); @@ -130,6 +133,22 @@ private: TensorHandleFactoryRegistry m_TensorHandleFactoryRegistry; profiling::ProfilingService& m_ProfilingService; + + struct ImportedInputHandlePin + { + ImportedInputHandlePin(LayerBindingId layerBindingId, + std::unique_ptr<ITensorHandle> tensorHandle) + : m_LayerBindingId(layerBindingId) + , m_TensorHandle(std::move(tensorHandle)) + {} + + LayerBindingId m_LayerBindingId; + std::unique_ptr<ITensorHandle> m_TensorHandle; + }; + + std::vector<ImportedInputHandlePin> m_PreImportedInputHandles; + + ImportedInputId m_CurImportedInputId = 0; }; } |