diff options
Diffstat (limited to 'RequestThread.cpp')
-rw-r--r-- | RequestThread.cpp | 32 |
1 files changed, 22 insertions, 10 deletions
diff --git a/RequestThread.cpp b/RequestThread.cpp index abaee90c..8e44d8d2 100644 --- a/RequestThread.cpp +++ b/RequestThread.cpp @@ -17,13 +17,15 @@ using namespace android; namespace armnn_driver { -RequestThread::RequestThread() +template<typename HalVersion> +RequestThread<HalVersion>::RequestThread() { ALOGV("RequestThread::RequestThread()"); m_Thread = std::make_unique<std::thread>(&RequestThread::Process, this); } -RequestThread::~RequestThread() +template<typename HalVersion> +RequestThread<HalVersion>::~RequestThread() { ALOGV("RequestThread::~RequestThread()"); @@ -48,11 +50,12 @@ RequestThread::~RequestThread() catch (const std::exception&) { } // Swallow any exception. } -void RequestThread::PostMsg(ArmnnPreparedModel* model, - std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& memPools, - std::shared_ptr<armnn::InputTensors>& inputTensors, - std::shared_ptr<armnn::OutputTensors>& outputTensors, - const ::android::sp<IExecutionCallback>& callback) +template<typename HalVersion> +void RequestThread<HalVersion>::PostMsg(ArmnnPreparedModel<HalVersion>* model, + std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& memPools, + std::shared_ptr<armnn::InputTensors>& inputTensors, + std::shared_ptr<armnn::OutputTensors>& outputTensors, + const ::android::sp<IExecutionCallback>& callback) { ALOGV("RequestThread::PostMsg(...)"); auto data = std::make_shared<AsyncExecuteData>(model, @@ -64,7 +67,8 @@ void RequestThread::PostMsg(ArmnnPreparedModel* model, PostMsg(pMsg); } -void RequestThread::PostMsg(std::shared_ptr<ThreadMsg>& pMsg) +template<typename HalVersion> +void RequestThread<HalVersion>::PostMsg(std::shared_ptr<ThreadMsg>& pMsg) { ALOGV("RequestThread::PostMsg(pMsg)"); // Add a message to the queue and notify the request thread @@ -73,7 +77,8 @@ void RequestThread::PostMsg(std::shared_ptr<ThreadMsg>& pMsg) m_Cv.notify_one(); } -void RequestThread::Process() +template<typename HalVersion> +void RequestThread<HalVersion>::Process() { ALOGV("RequestThread::Process()"); while (true) @@ -98,7 +103,7 @@ void RequestThread::Process() { ALOGV("RequestThread::Process() - request"); // invoke the asynchronous execution method - ArmnnPreparedModel* model = pMsg->data->m_Model; + ArmnnPreparedModel<HalVersion>* model = pMsg->data->m_Model; model->ExecuteGraph(pMsg->data->m_MemPools, pMsg->data->m_InputTensors, pMsg->data->m_OutputTensors, @@ -126,5 +131,12 @@ void RequestThread::Process() } } +// Class template specializations +template class RequestThread<HalVersion_1_0>; + +#ifdef ARMNN_ANDROID_NN_V1_1 // Using ::android::hardware::neuralnetworks::V1_1. +template class RequestThread<HalVersion_1_1>; +#endif + } // namespace armnn_driver |