diff options
Diffstat (limited to 'RequestThread.cpp')
-rw-r--r-- | RequestThread.cpp | 44 |
1 files changed, 27 insertions, 17 deletions
diff --git a/RequestThread.cpp b/RequestThread.cpp index bc1eccc0..4b646034 100644 --- a/RequestThread.cpp +++ b/RequestThread.cpp @@ -8,6 +8,10 @@ #include "RequestThread.hpp" #include "ArmnnPreparedModel.hpp" +#ifdef ARMNN_ANDROID_NN_V1_2 +#include "ArmnnPreparedModel_1_2.hpp" +#endif + #include <boost/assert.hpp> #include <log/log.h> @@ -17,15 +21,15 @@ using namespace android; namespace armnn_driver { -template<typename HalVersion> -RequestThread<HalVersion>::RequestThread() +template <template <typename HalVersion> class PreparedModel, typename HalVersion> +RequestThread<PreparedModel, HalVersion>::RequestThread() { ALOGV("RequestThread::RequestThread()"); m_Thread = std::make_unique<std::thread>(&RequestThread::Process, this); } -template<typename HalVersion> -RequestThread<HalVersion>::~RequestThread() +template <template <typename HalVersion> class PreparedModel, typename HalVersion> +RequestThread<PreparedModel, HalVersion>::~RequestThread() { ALOGV("RequestThread::~RequestThread()"); @@ -50,12 +54,12 @@ RequestThread<HalVersion>::~RequestThread() catch (const std::exception&) { } // Swallow any exception. } -template<typename HalVersion> -void RequestThread<HalVersion>::PostMsg(ArmnnPreparedModel<HalVersion>* model, - std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& memPools, - std::shared_ptr<armnn::InputTensors>& inputTensors, - std::shared_ptr<armnn::OutputTensors>& outputTensors, - const ::android::sp<V1_0::IExecutionCallback>& callback) +template <template <typename HalVersion> class PreparedModel, typename HalVersion> +void RequestThread<PreparedModel, HalVersion>::PostMsg(PreparedModel<HalVersion>* model, + std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& memPools, + std::shared_ptr<armnn::InputTensors>& inputTensors, + std::shared_ptr<armnn::OutputTensors>& outputTensors, + const ::android::sp<V1_0::IExecutionCallback>& callback) { ALOGV("RequestThread::PostMsg(...)"); auto data = std::make_shared<AsyncExecuteData>(model, @@ -67,8 +71,8 @@ void RequestThread<HalVersion>::PostMsg(ArmnnPreparedModel<HalVersion>* model, PostMsg(pMsg); } -template<typename HalVersion> -void RequestThread<HalVersion>::PostMsg(std::shared_ptr<ThreadMsg>& pMsg) +template <template <typename HalVersion> class PreparedModel, typename HalVersion> +void RequestThread<PreparedModel, HalVersion>::PostMsg(std::shared_ptr<ThreadMsg>& pMsg) { ALOGV("RequestThread::PostMsg(pMsg)"); // Add a message to the queue and notify the request thread @@ -77,8 +81,8 @@ void RequestThread<HalVersion>::PostMsg(std::shared_ptr<ThreadMsg>& pMsg) m_Cv.notify_one(); } -template<typename HalVersion> -void RequestThread<HalVersion>::Process() +template <template <typename HalVersion> class PreparedModel, typename HalVersion> +void RequestThread<PreparedModel, HalVersion>::Process() { ALOGV("RequestThread::Process()"); while (true) @@ -103,7 +107,7 @@ void RequestThread<HalVersion>::Process() { ALOGV("RequestThread::Process() - request"); // invoke the asynchronous execution method - ArmnnPreparedModel<HalVersion>* model = pMsg->data->m_Model; + PreparedModel<HalVersion>* model = pMsg->data->m_Model; model->ExecuteGraph(pMsg->data->m_MemPools, pMsg->data->m_InputTensors, pMsg->data->m_OutputTensors, @@ -135,10 +139,16 @@ void RequestThread<HalVersion>::Process() /// Class template specializations /// -template class RequestThread<hal_1_0::HalPolicy>; +template class RequestThread<ArmnnPreparedModel, hal_1_0::HalPolicy>; #ifdef ARMNN_ANDROID_NN_V1_1 -template class RequestThread<hal_1_1::HalPolicy>; +template class RequestThread<armnn_driver::ArmnnPreparedModel, hal_1_1::HalPolicy>; +#endif + +#ifdef ARMNN_ANDROID_NN_V1_2 +template class RequestThread<ArmnnPreparedModel, hal_1_1::HalPolicy>; +template class RequestThread<ArmnnPreparedModel, hal_1_2::HalPolicy>; +template class RequestThread<ArmnnPreparedModel_1_2, hal_1_2::HalPolicy>; #endif } // namespace armnn_driver |