diff options
Diffstat (limited to 'ArmnnPreparedModel.cpp')
-rw-r--r-- | ArmnnPreparedModel.cpp | 49 |
1 files changed, 32 insertions, 17 deletions
diff --git a/ArmnnPreparedModel.cpp b/ArmnnPreparedModel.cpp index d338fdc8..7cbbcbcb 100644 --- a/ArmnnPreparedModel.cpp +++ b/ArmnnPreparedModel.cpp @@ -81,18 +81,20 @@ inline std::string BuildTensorName(const char* tensorNamePrefix, std::size_t ind return tensorNamePrefix + std::to_string(index); } -} +} // anonymous namespace using namespace android::hardware; namespace armnn_driver { -RequestThread ArmnnPreparedModel::m_RequestThread; +template<typename HalVersion> +RequestThread<HalVersion> ArmnnPreparedModel<HalVersion>::m_RequestThread; +template<typename HalVersion> template <typename TensorBindingCollection> -void ArmnnPreparedModel::DumpTensorsIfRequired(char const* tensorNamePrefix, - const TensorBindingCollection& tensorBindings) +void ArmnnPreparedModel<HalVersion>::DumpTensorsIfRequired(char const* tensorNamePrefix, + const TensorBindingCollection& tensorBindings) { if (!m_RequestInputsAndOutputsDumpDir.empty()) { @@ -107,11 +109,12 @@ void ArmnnPreparedModel::DumpTensorsIfRequired(char const* tensorNamePrefix, } } -ArmnnPreparedModel::ArmnnPreparedModel(armnn::NetworkId networkId, - armnn::IRuntime* runtime, - const neuralnetworks::V1_0::Model& model, - const std::string& requestInputsAndOutputsDumpDir, - const bool gpuProfilingEnabled) +template<typename HalVersion> +ArmnnPreparedModel<HalVersion>::ArmnnPreparedModel(armnn::NetworkId networkId, + armnn::IRuntime* runtime, + const HalModel& model, + const std::string& requestInputsAndOutputsDumpDir, + const bool gpuProfilingEnabled) : m_NetworkId(networkId) , m_Runtime(runtime) , m_Model(model) @@ -123,7 +126,8 @@ ArmnnPreparedModel::ArmnnPreparedModel(armnn::NetworkId networkId, m_Runtime->GetProfiler(m_NetworkId)->EnableProfiling(m_GpuProfilingEnabled); } -ArmnnPreparedModel::~ArmnnPreparedModel() +template<typename HalVersion> +ArmnnPreparedModel<HalVersion>::~ArmnnPreparedModel() { // Get a hold of the profiler used by this model. std::shared_ptr<armnn::IProfiler> profiler = m_Runtime->GetProfiler(m_NetworkId); @@ -135,8 +139,9 @@ ArmnnPreparedModel::~ArmnnPreparedModel() DumpJsonProfilingIfRequired(m_GpuProfilingEnabled, m_RequestInputsAndOutputsDumpDir, m_NetworkId, profiler.get()); } -Return<ErrorStatus> ArmnnPreparedModel::execute(const Request& request, - const ::android::sp<IExecutionCallback>& callback) +template<typename HalVersion> +Return<ErrorStatus> ArmnnPreparedModel<HalVersion>::execute(const Request& request, + const ::android::sp<IExecutionCallback>& callback) { ALOGV("ArmnnPreparedModel::execute(): %s", GetModelSummary(m_Model).c_str()); m_RequestCount++; @@ -220,10 +225,12 @@ Return<ErrorStatus> ArmnnPreparedModel::execute(const Request& request, return ErrorStatus::NONE; // successfully queued } -void ArmnnPreparedModel::ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools, - std::shared_ptr<armnn::InputTensors>& pInputTensors, - std::shared_ptr<armnn::OutputTensors>& pOutputTensors, - const ::android::sp<IExecutionCallback>& callback) +template<typename HalVersion> +void ArmnnPreparedModel<HalVersion>::ExecuteGraph( + std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools, + std::shared_ptr<armnn::InputTensors>& pInputTensors, + std::shared_ptr<armnn::OutputTensors>& pOutputTensors, + const ::android::sp<IExecutionCallback>& callback) { ALOGV("ArmnnPreparedModel::ExecuteGraph(...)"); @@ -254,7 +261,8 @@ void ArmnnPreparedModel::ExecuteGraph(std::shared_ptr<std::vector<::android::nn: NotifyCallbackAndCheck(callback, ErrorStatus::NONE, "ExecuteGraph"); } -void ArmnnPreparedModel::ExecuteWithDummyInputs() +template<typename HalVersion> +void ArmnnPreparedModel<HalVersion>::ExecuteWithDummyInputs() { std::vector<std::vector<char>> storage; armnn::InputTensors inputTensors; @@ -287,4 +295,11 @@ void ArmnnPreparedModel::ExecuteWithDummyInputs() } } +// Class template specializations +template class ArmnnPreparedModel<HalVersion_1_0>; + +#ifdef ARMNN_ANDROID_NN_V1_1 // Using ::android::hardware::neuralnetworks::V1_1. +template class ArmnnPreparedModel<HalVersion_1_1>; +#endif + } // namespace armnn_driver |