aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/LoadedNetwork.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/LoadedNetwork.hpp')
-rw-r--r--src/armnn/LoadedNetwork.hpp23
1 files changed, 21 insertions, 2 deletions
diff --git a/src/armnn/LoadedNetwork.hpp b/src/armnn/LoadedNetwork.hpp
index 360ad91170..e713be215a 100644
--- a/src/armnn/LoadedNetwork.hpp
+++ b/src/armnn/LoadedNetwork.hpp
@@ -49,13 +49,16 @@ public:
TensorInfo GetInputTensorInfo(LayerBindingId layerId) const;
TensorInfo GetOutputTensorInfo(LayerBindingId layerId) const;
+ std::vector<ImportedInputId> ImportInputs(const InputTensors& inputTensors);
+
/// Single thread execution of the loaded network
Status EnqueueWorkload(const InputTensors& inputTensors, const OutputTensors& outputTensors);
/// Thread safe execution of the loaded network
Status Execute(const InputTensors& inputTensors,
const OutputTensors& outputTensors,
- IWorkingMemHandle& workingMemHandle);
+ IWorkingMemHandle& workingMemHandle,
+ std::vector<ImportedInputId> preImportedInputs = {});
static std::unique_ptr<LoadedNetwork> MakeLoadedNetwork(std::unique_ptr<IOptimizedNetwork> net,
std::string& errorMessage,
@@ -100,7 +103,7 @@ private:
void EnqueueOutput(const BindableLayer& layer, ITensorHandle* tensorHandle, const TensorInfo& tensorInfo);
- void EnqueueInput(const BindableLayer& layer, const ConstTensor& inputTensor, WorkingMemHandle& handle);
+ void EnqueueInput(const ConstTensor& inputTensor, ITensorHandle* inputTensorHandle);
void EnqueueOutput(const BindableLayer& layer, const Tensor& outputTensor, WorkingMemHandle& handle);
@@ -130,6 +133,22 @@ private:
TensorHandleFactoryRegistry m_TensorHandleFactoryRegistry;
profiling::ProfilingService& m_ProfilingService;
+
+ struct ImportedInputHandlePin
+ {
+ ImportedInputHandlePin(LayerBindingId layerBindingId,
+ std::unique_ptr<ITensorHandle> tensorHandle)
+ : m_LayerBindingId(layerBindingId)
+ , m_TensorHandle(std::move(tensorHandle))
+ {}
+
+ LayerBindingId m_LayerBindingId;
+ std::unique_ptr<ITensorHandle> m_TensorHandle;
+ };
+
+ std::vector<ImportedInputHandlePin> m_PreImportedInputHandles;
+
+ ImportedInputId m_CurImportedInputId = 0;
};
}