aboutsummaryrefslogtreecommitdiff
path: root/RequestThread.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'RequestThread.cpp')
-rw-r--r--RequestThread.cpp32
1 files changed, 22 insertions, 10 deletions
diff --git a/RequestThread.cpp b/RequestThread.cpp
index abaee90c..8e44d8d2 100644
--- a/RequestThread.cpp
+++ b/RequestThread.cpp
@@ -17,13 +17,15 @@ using namespace android;
namespace armnn_driver
{
-RequestThread::RequestThread()
+template<typename HalVersion>
+RequestThread<HalVersion>::RequestThread()
{
ALOGV("RequestThread::RequestThread()");
m_Thread = std::make_unique<std::thread>(&RequestThread::Process, this);
}
-RequestThread::~RequestThread()
+template<typename HalVersion>
+RequestThread<HalVersion>::~RequestThread()
{
ALOGV("RequestThread::~RequestThread()");
@@ -48,11 +50,12 @@ RequestThread::~RequestThread()
catch (const std::exception&) { } // Swallow any exception.
}
-void RequestThread::PostMsg(ArmnnPreparedModel* model,
- std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& memPools,
- std::shared_ptr<armnn::InputTensors>& inputTensors,
- std::shared_ptr<armnn::OutputTensors>& outputTensors,
- const ::android::sp<IExecutionCallback>& callback)
+template<typename HalVersion>
+void RequestThread<HalVersion>::PostMsg(ArmnnPreparedModel<HalVersion>* model,
+ std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& memPools,
+ std::shared_ptr<armnn::InputTensors>& inputTensors,
+ std::shared_ptr<armnn::OutputTensors>& outputTensors,
+ const ::android::sp<IExecutionCallback>& callback)
{
ALOGV("RequestThread::PostMsg(...)");
auto data = std::make_shared<AsyncExecuteData>(model,
@@ -64,7 +67,8 @@ void RequestThread::PostMsg(ArmnnPreparedModel* model,
PostMsg(pMsg);
}
-void RequestThread::PostMsg(std::shared_ptr<ThreadMsg>& pMsg)
+template<typename HalVersion>
+void RequestThread<HalVersion>::PostMsg(std::shared_ptr<ThreadMsg>& pMsg)
{
ALOGV("RequestThread::PostMsg(pMsg)");
// Add a message to the queue and notify the request thread
@@ -73,7 +77,8 @@ void RequestThread::PostMsg(std::shared_ptr<ThreadMsg>& pMsg)
m_Cv.notify_one();
}
-void RequestThread::Process()
+template<typename HalVersion>
+void RequestThread<HalVersion>::Process()
{
ALOGV("RequestThread::Process()");
while (true)
@@ -98,7 +103,7 @@ void RequestThread::Process()
{
ALOGV("RequestThread::Process() - request");
// invoke the asynchronous execution method
- ArmnnPreparedModel* model = pMsg->data->m_Model;
+ ArmnnPreparedModel<HalVersion>* model = pMsg->data->m_Model;
model->ExecuteGraph(pMsg->data->m_MemPools,
pMsg->data->m_InputTensors,
pMsg->data->m_OutputTensors,
@@ -126,5 +131,12 @@ void RequestThread::Process()
}
}
+// Class template specializations
+template class RequestThread<HalVersion_1_0>;
+
+#ifdef ARMNN_ANDROID_NN_V1_1 // Using ::android::hardware::neuralnetworks::V1_1.
+template class RequestThread<HalVersion_1_1>;
+#endif
+
} // namespace armnn_driver