diff options
Diffstat (limited to 'src/armnn/Runtime.cpp')
-rw-r--r-- | src/armnn/Runtime.cpp | 22 |
1 files changed, 18 insertions, 4 deletions
diff --git a/src/armnn/Runtime.cpp b/src/armnn/Runtime.cpp index bbcbb9f6f6..085cf2cee8 100644 --- a/src/armnn/Runtime.cpp +++ b/src/armnn/Runtime.cpp @@ -76,6 +76,12 @@ TensorInfo IRuntime::GetOutputTensorInfo(NetworkId networkId, LayerBindingId lay return pRuntimeImpl->GetOutputTensorInfo(networkId, layerId); } +std::vector<ImportedInputId> IRuntime::ImportInputs(NetworkId networkId, const InputTensors& inputTensors) +{ + return pRuntimeImpl->ImportInputs(networkId, inputTensors); +} + + Status IRuntime::EnqueueWorkload(NetworkId networkId, const InputTensors& inputTensors, const OutputTensors& outputTensors) @@ -85,9 +91,10 @@ Status IRuntime::EnqueueWorkload(NetworkId networkId, Status IRuntime::Execute(IWorkingMemHandle& workingMemHandle, const InputTensors& inputTensors, - const OutputTensors& outputTensors) + const OutputTensors& outputTensors, + std::vector<ImportedInputId> preImportedInputs) { - return pRuntimeImpl->Execute(workingMemHandle, inputTensors, outputTensors); + return pRuntimeImpl->Execute(workingMemHandle, inputTensors, outputTensors, preImportedInputs); } Status IRuntime::UnloadNetwork(NetworkId networkId) @@ -476,6 +483,12 @@ TensorInfo RuntimeImpl::GetOutputTensorInfo(NetworkId networkId, LayerBindingId return GetLoadedNetworkPtr(networkId)->GetOutputTensorInfo(layerId); } +std::vector<ImportedInputId> RuntimeImpl::ImportInputs(NetworkId networkId, const InputTensors& inputTensors) +{ + return GetLoadedNetworkPtr(networkId)->ImportInputs(inputTensors); +} + + Status RuntimeImpl::EnqueueWorkload(NetworkId networkId, const InputTensors& inputTensors, @@ -512,7 +525,8 @@ Status RuntimeImpl::EnqueueWorkload(NetworkId networkId, Status RuntimeImpl::Execute(IWorkingMemHandle& iWorkingMemHandle, const InputTensors& inputTensors, - const OutputTensors& outputTensors) + const OutputTensors& outputTensors, + std::vector<ImportedInputId> preImportedInputs) { NetworkId networkId = iWorkingMemHandle.GetNetworkId(); LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId); @@ -531,7 +545,7 @@ Status RuntimeImpl::Execute(IWorkingMemHandle& iWorkingMemHandle, ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "Execute"); - return loadedNetwork->Execute(inputTensors, outputTensors, iWorkingMemHandle); + return loadedNetwork->Execute(inputTensors, outputTensors, iWorkingMemHandle, preImportedInputs); } /// Create a new unique WorkingMemHandle object. Create multiple handles if you wish to have |