diff options
Diffstat (limited to 'ArmnnPreparedModel_1_2.cpp')
-rw-r--r-- | ArmnnPreparedModel_1_2.cpp | 27 |
1 files changed, 25 insertions, 2 deletions
diff --git a/ArmnnPreparedModel_1_2.cpp b/ArmnnPreparedModel_1_2.cpp index 7f35e60f..37bc3a49 100644 --- a/ArmnnPreparedModel_1_2.cpp +++ b/ArmnnPreparedModel_1_2.cpp @@ -9,6 +9,8 @@ #include "Utils.hpp" +#include <armnn/Types.hpp> + #include <log/log.h> #include <OperationsUtils.h> #include <ExecutionBurstServer.h> @@ -151,7 +153,9 @@ ArmnnPreparedModel_1_2<HalVersion>::ArmnnPreparedModel_1_2(armnn::NetworkId netw const std::string& requestInputsAndOutputsDumpDir, const bool gpuProfilingEnabled, const bool asyncModelExecutionEnabled, - const unsigned int numberOfThreads) + const unsigned int numberOfThreads, + const bool importEnabled, + const bool exportEnabled) : m_NetworkId(networkId) , m_Runtime(runtime) , m_Model(model) @@ -159,6 +163,8 @@ ArmnnPreparedModel_1_2<HalVersion>::ArmnnPreparedModel_1_2(armnn::NetworkId netw , m_RequestInputsAndOutputsDumpDir(requestInputsAndOutputsDumpDir) , m_GpuProfilingEnabled(gpuProfilingEnabled) , m_AsyncModelExecutionEnabled(asyncModelExecutionEnabled) + , m_EnableImport(importEnabled) + , m_EnableExport(exportEnabled) , m_PreparedFromCache(false) { // Enable profiling if required. @@ -192,6 +198,8 @@ ArmnnPreparedModel_1_2<HalVersion>::ArmnnPreparedModel_1_2(armnn::NetworkId netw const bool gpuProfilingEnabled, const bool asyncModelExecutionEnabled, const unsigned int numberOfThreads, + const bool importEnabled, + const bool exportEnabled, const bool preparedFromCache) : m_NetworkId(networkId) , m_Runtime(runtime) @@ -199,6 +207,8 @@ ArmnnPreparedModel_1_2<HalVersion>::ArmnnPreparedModel_1_2(armnn::NetworkId netw , m_RequestInputsAndOutputsDumpDir(requestInputsAndOutputsDumpDir) , m_GpuProfilingEnabled(gpuProfilingEnabled) , m_AsyncModelExecutionEnabled(asyncModelExecutionEnabled) + , m_EnableImport(importEnabled) + , m_EnableExport(exportEnabled) , m_PreparedFromCache(preparedFromCache) { // Enable profiling if required. @@ -531,7 +541,20 @@ bool ArmnnPreparedModel_1_2<HalVersion>::ExecuteGraph( else { ALOGW("ArmnnPreparedModel_1_2::ExecuteGraph m_AsyncModelExecutionEnabled false"); - status = m_Runtime->EnqueueWorkload(m_NetworkId, inputTensors, outputTensors); + + // Create a vector of Input and Output Ids which can be imported. An empty vector means all will be copied. + std::vector<armnn::ImportedInputId> importedInputIds; + if (m_EnableImport) + { + importedInputIds = m_Runtime->ImportInputs(m_NetworkId, inputTensors, armnn::MemorySource::Malloc); + } + std::vector<armnn::ImportedOutputId> importedOutputIds; + if (m_EnableExport) + { + importedOutputIds = m_Runtime->ImportOutputs(m_NetworkId, outputTensors, armnn::MemorySource::Malloc); + } + status = m_Runtime->EnqueueWorkload(m_NetworkId, inputTensors, outputTensors, + importedInputIds, importedOutputIds); } if (cb.ctx.measureTimings == V1_2::MeasureTiming::YES) |