From e48bdff741568236d3c0747ad3d18a8eba5b36dd Mon Sep 17 00:00:00 2001 From: Matteo Martincigh Date: Mon, 3 Sep 2018 13:50:50 +0100 Subject: IVGCVSW-1806 Refactored Android-NN-Driver, added common "getCapabilities", "getSupportedOperations" and "prepareModel" implementations * Added common base ArmnnDriverImpl class * Added common template implementation of the driver's "getCapabilities", "getSupportedOperations" and "prepareModel" methods * Refactored ArmnnPreparedModel and RequestThread to support HAL v1.1 models * Moved "getStatus" to the common base class, as it is shared by both HAL implementations * Refactored the code where necessary Change-Id: I747334730026d63b4002662523fb93608f67c899 --- ArmnnPreparedModel.cpp | 49 ++++++++++++++++++++++++++++++++----------------- 1 file changed, 32 insertions(+), 17 deletions(-) (limited to 'ArmnnPreparedModel.cpp') diff --git a/ArmnnPreparedModel.cpp b/ArmnnPreparedModel.cpp index d338fdc8..7cbbcbcb 100644 --- a/ArmnnPreparedModel.cpp +++ b/ArmnnPreparedModel.cpp @@ -81,18 +81,20 @@ inline std::string BuildTensorName(const char* tensorNamePrefix, std::size_t ind return tensorNamePrefix + std::to_string(index); } -} +} // anonymous namespace using namespace android::hardware; namespace armnn_driver { -RequestThread ArmnnPreparedModel::m_RequestThread; +template +RequestThread ArmnnPreparedModel::m_RequestThread; +template template -void ArmnnPreparedModel::DumpTensorsIfRequired(char const* tensorNamePrefix, - const TensorBindingCollection& tensorBindings) +void ArmnnPreparedModel::DumpTensorsIfRequired(char const* tensorNamePrefix, + const TensorBindingCollection& tensorBindings) { if (!m_RequestInputsAndOutputsDumpDir.empty()) { @@ -107,11 +109,12 @@ void ArmnnPreparedModel::DumpTensorsIfRequired(char const* tensorNamePrefix, } } -ArmnnPreparedModel::ArmnnPreparedModel(armnn::NetworkId networkId, - armnn::IRuntime* runtime, - const neuralnetworks::V1_0::Model& model, - const std::string& requestInputsAndOutputsDumpDir, - const bool gpuProfilingEnabled) +template +ArmnnPreparedModel::ArmnnPreparedModel(armnn::NetworkId networkId, + armnn::IRuntime* runtime, + const HalModel& model, + const std::string& requestInputsAndOutputsDumpDir, + const bool gpuProfilingEnabled) : m_NetworkId(networkId) , m_Runtime(runtime) , m_Model(model) @@ -123,7 +126,8 @@ ArmnnPreparedModel::ArmnnPreparedModel(armnn::NetworkId networkId, m_Runtime->GetProfiler(m_NetworkId)->EnableProfiling(m_GpuProfilingEnabled); } -ArmnnPreparedModel::~ArmnnPreparedModel() +template +ArmnnPreparedModel::~ArmnnPreparedModel() { // Get a hold of the profiler used by this model. std::shared_ptr profiler = m_Runtime->GetProfiler(m_NetworkId); @@ -135,8 +139,9 @@ ArmnnPreparedModel::~ArmnnPreparedModel() DumpJsonProfilingIfRequired(m_GpuProfilingEnabled, m_RequestInputsAndOutputsDumpDir, m_NetworkId, profiler.get()); } -Return ArmnnPreparedModel::execute(const Request& request, - const ::android::sp& callback) +template +Return ArmnnPreparedModel::execute(const Request& request, + const ::android::sp& callback) { ALOGV("ArmnnPreparedModel::execute(): %s", GetModelSummary(m_Model).c_str()); m_RequestCount++; @@ -220,10 +225,12 @@ Return ArmnnPreparedModel::execute(const Request& request, return ErrorStatus::NONE; // successfully queued } -void ArmnnPreparedModel::ExecuteGraph(std::shared_ptr>& pMemPools, - std::shared_ptr& pInputTensors, - std::shared_ptr& pOutputTensors, - const ::android::sp& callback) +template +void ArmnnPreparedModel::ExecuteGraph( + std::shared_ptr>& pMemPools, + std::shared_ptr& pInputTensors, + std::shared_ptr& pOutputTensors, + const ::android::sp& callback) { ALOGV("ArmnnPreparedModel::ExecuteGraph(...)"); @@ -254,7 +261,8 @@ void ArmnnPreparedModel::ExecuteGraph(std::shared_ptr +void ArmnnPreparedModel::ExecuteWithDummyInputs() { std::vector> storage; armnn::InputTensors inputTensors; @@ -287,4 +295,11 @@ void ArmnnPreparedModel::ExecuteWithDummyInputs() } } +// Class template specializations +template class ArmnnPreparedModel; + +#ifdef ARMNN_ANDROID_NN_V1_1 // Using ::android::hardware::neuralnetworks::V1_1. +template class ArmnnPreparedModel; +#endif + } // namespace armnn_driver -- cgit v1.2.1