From 0a2dfabd76a45c58d0a14567f0503369c4e6fbf3 Mon Sep 17 00:00:00 2001
From: Sadik Armagan
Date: Wed, 6 Oct 2021 16:41:44 +0100
Subject: IVGCVSW-5636 'Implement NNAPI caching functions'

* Cached serialized ArmNN model.

!armnn:6384
Signed-off-by: Sadik Armagan
Signed-off-by: Kevin May
Change-Id: I78120a7f8ea892a28c0ff25f1b54e67a4f912574
---
 ArmnnPreparedModel_1_2.cpp | 61 ++++++++++++++++++++++++++++++++++++++++------
 1 file changed, 54 insertions(+), 7 deletions(-)

diff --git a/ArmnnPreparedModel_1_2.cpp b/ArmnnPreparedModel_1_2.cpp
index 2e378801..7cc75473 100644
--- a/ArmnnPreparedModel_1_2.cpp
+++ b/ArmnnPreparedModel_1_2.cpp
@@ -159,6 +159,47 @@ ArmnnPreparedModel_1_2::ArmnnPreparedModel_1_2(armnn::NetworkId netw
     , m_RequestInputsAndOutputsDumpDir(requestInputsAndOutputsDumpDir)
     , m_GpuProfilingEnabled(gpuProfilingEnabled)
     , m_AsyncModelExecutionEnabled(asyncModelExecutionEnabled)
+    , m_PreparedFromCache(false)
+{
+    // Enable profiling if required.
+    m_Runtime->GetProfiler(m_NetworkId)->EnableProfiling(m_GpuProfilingEnabled);
+
+    if (m_AsyncModelExecutionEnabled)
+    {
+        std::vector<std::shared_ptr<armnn::IWorkingMemHandle>> memHandles;
+        for (unsigned int i=0; i < numberOfThreads; ++i)
+        {
+            memHandles.emplace_back(m_Runtime->CreateWorkingMemHandle(networkId));
+        }
+
+        if (!m_Threadpool)
+        {
+            m_Threadpool = std::make_unique<armnn::Threadpool>(numberOfThreads, runtime, memHandles);
+        }
+        else
+        {
+            m_Threadpool->LoadMemHandles(memHandles);
+        }
+
+        m_WorkingMemHandle = memHandles.back();
+    }
+}
+
+template<typename HalVersion>
+ArmnnPreparedModel_1_2<HalVersion>::ArmnnPreparedModel_1_2(armnn::NetworkId networkId,
+                                                           armnn::IRuntime* runtime,
+                                                           const std::string& requestInputsAndOutputsDumpDir,
+                                                           const bool gpuProfilingEnabled,
+                                                           const bool asyncModelExecutionEnabled,
+                                                           const unsigned int numberOfThreads,
+                                                           const bool preparedFromCache)
+    : m_NetworkId(networkId)
+    , m_Runtime(runtime)
+    , m_RequestCount(0)
+    , m_RequestInputsAndOutputsDumpDir(requestInputsAndOutputsDumpDir)
+    , m_GpuProfilingEnabled(gpuProfilingEnabled)
+    , m_AsyncModelExecutionEnabled(asyncModelExecutionEnabled)
+    , m_PreparedFromCache(preparedFromCache)
 {
     // Enable profiling if required.
     m_Runtime->GetProfiler(m_NetworkId)->EnableProfiling(m_GpuProfilingEnabled);
@@ -384,7 +425,10 @@ Return ArmnnPreparedModel_1_2::executeSynchronously(const V1_0
                                         V1_2::MeasureTiming measureTiming,
                                         executeSynchronously_cb cb)
 {
-    ALOGV("ArmnnPreparedModel_1_2::executeSynchronously(): %s", GetModelSummary(m_Model).c_str());
+    if (!m_PreparedFromCache)
+    {
+        ALOGV("ArmnnPreparedModel_1_2::executeSynchronously(): %s", GetModelSummary(m_Model).c_str());
+    }
     m_RequestCount++;
 
     if (cb == nullptr)
@@ -400,7 +444,7 @@ Return ArmnnPreparedModel_1_2::executeSynchronously(const V1_0
         driverStart = Now();
     }
 
-    if (!android::nn::validateRequest(request, m_Model))
+    if (!m_PreparedFromCache && !android::nn::validateRequest(request, m_Model))
     {
         ALOGE("ArmnnPreparedModel_1_2::executeSynchronously invalid request model");
         cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {}, g_NoTiming);
@@ -530,11 +574,11 @@ bool ArmnnPreparedModel_1_2::ExecuteGraph(
 }
 
 template<typename HalVersion>
-bool ArmnnPreparedModel_1_2<HalVersion>::ExecuteWithDummyInputs()
+bool ArmnnPreparedModel_1_2<HalVersion>::ExecuteWithDummyInputs(unsigned int numInputs, unsigned int numOutputs)
 {
     std::vector<std::vector<char>> storage;
     armnn::InputTensors inputTensors;
-    for (unsigned int i = 0; i < getMainModel(m_Model).inputIndexes.size(); i++)
+    for (unsigned int i = 0; i < numInputs; i++)
     {
         const armnn::TensorInfo inputTensorInfo = m_Runtime->GetInputTensorInfo(m_NetworkId, i);
         storage.emplace_back(inputTensorInfo.GetNumBytes());
@@ -544,7 +588,7 @@ bool ArmnnPreparedModel_1_2::ExecuteWithDummyInputs()
     }
 
     armnn::OutputTensors outputTensors;
-    for (unsigned int i = 0; i < getMainModel(m_Model).outputIndexes.size(); i++)
+    for (unsigned int i = 0; i < numOutputs; i++)
     {
         const armnn::TensorInfo outputTensorInfo = m_Runtime->GetOutputTensorInfo(m_NetworkId, i);
         storage.emplace_back(outputTensorInfo.GetNumBytes());
@@ -576,10 +620,13 @@ Return ArmnnPreparedModel_1_2::Execute(const V1_
         ctx.driverStart = Now();
     }
 
-    ALOGV("ArmnnPreparedModel_1_2::execute(): %s", GetModelSummary(m_Model).c_str());
+    if (!m_PreparedFromCache)
+    {
+        ALOGV("ArmnnPreparedModel_1_2::execute(): %s", GetModelSummary(m_Model).c_str());
+    }
     m_RequestCount++;
 
-    if (!android::nn::validateRequest(request, m_Model))
+    if (!m_PreparedFromCache && !android::nn::validateRequest(request, m_Model))
     {
         callback(V1_0::ErrorStatus::INVALID_ARGUMENT, {}, g_NoTiming, "ArmnnPreparedModel_1_2::execute");
         return V1_0::ErrorStatus::INVALID_ARGUMENT;
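A minimal sketch, outside the diff above, of how the new preparedFromCache constructor overload
and ExecuteWithDummyInputs(numInputs, numOutputs) might be exercised when a prepared model is
rebuilt from a cached, serialized ArmNN network and no HIDL model is available to validate or
summarise. The helper name LoadCachedNetwork, the hal_1_2::HalPolicy instantiation, the backend
choice, and the idea that the tensor counts are stored alongside the cached network are
assumptions for illustration; the deserialize/optimize/load calls are the standard ArmNN flow.

    #include "ArmnnPreparedModel_1_2.hpp"
    #include <armnn/ArmNN.hpp>
    #include <armnnDeserializer/IDeserializer.hpp>

    using namespace armnn_driver;

    // Hypothetical helper: rebuild a prepared model from cached bytes.
    std::unique_ptr<ArmnnPreparedModel_1_2<hal_1_2::HalPolicy>> LoadCachedNetwork(
        const std::vector<uint8_t>& cachedNetwork, // serialized ArmNN model read from the NNAPI cache
        unsigned int numInputs,                    // assumed to be stored alongside the cached network
        unsigned int numOutputs,
        armnn::IRuntime* runtime)
    {
        // Deserialize, optimize and load the network as usual.
        auto parser = armnnDeserializer::IDeserializer::Create();
        armnn::INetworkPtr network = parser->CreateNetworkFromBinary(cachedNetwork);

        std::vector<armnn::BackendId> backends = { armnn::Compute::CpuAcc };
        armnn::IOptimizedNetworkPtr optNet =
            armnn::Optimize(*network, backends, runtime->GetDeviceSpec());

        armnn::NetworkId netId = 0;
        if (runtime->LoadNetwork(netId, std::move(optNet)) != armnn::Status::Success)
        {
            return nullptr;
        }

        // preparedFromCache = true: validateRequest() and the model-summary ALOGV are skipped,
        // because m_Model is never populated on this path.
        auto preparedModel = std::make_unique<ArmnnPreparedModel_1_2<hal_1_2::HalPolicy>>(
            netId, runtime, /*requestInputsAndOutputsDumpDir=*/"",
            /*gpuProfilingEnabled=*/false, /*asyncModelExecutionEnabled=*/false,
            /*numberOfThreads=*/1, /*preparedFromCache=*/true);

        // The warm-up run now takes explicit tensor counts instead of reading them
        // from getMainModel(m_Model).
        if (!preparedModel->ExecuteWithDummyInputs(numInputs, numOutputs))
        {
            return nullptr;
        }
        return preparedModel;
    }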