aboutsummaryrefslogtreecommitdiff
path: root/RequestThread_1_3.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'RequestThread_1_3.cpp')
-rw-r--r--RequestThread_1_3.cpp193
1 files changed, 193 insertions, 0 deletions
diff --git a/RequestThread_1_3.cpp b/RequestThread_1_3.cpp
new file mode 100644
index 00000000..59fa70ed
--- /dev/null
+++ b/RequestThread_1_3.cpp
@@ -0,0 +1,193 @@
+//
+// Copyright © 2020 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#define LOG_TAG "ArmnnDriver"
+
+#include "RequestThread_1_3.hpp"
+
+#include "ArmnnPreparedModel_1_3.hpp"
+
+#include <armnn/utility/Assert.hpp>
+
+#include <log/log.h>
+
+using namespace android;
+
+namespace armnn_driver
+{
+
+template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename CallbackContext>
+RequestThread_1_3<PreparedModel, HalVersion, CallbackContext>::RequestThread_1_3()
+{
+ ALOGV("RequestThread_1_3::RequestThread_1_3()");
+ m_Thread = std::make_unique<std::thread>(&RequestThread_1_3::Process, this);
+}
+
+template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename CallbackContext>
+RequestThread_1_3<PreparedModel, HalVersion, CallbackContext>::~RequestThread_1_3()
+{
+ ALOGV("RequestThread_1_3::~RequestThread_1_3()");
+
+ try
+ {
+ // Coverity fix: The following code may throw an exception of type std::length_error.
+
+ // This code is meant to to terminate the inner thread gracefully by posting an EXIT message
+ // to the thread's message queue. However, according to Coverity, this code could throw an exception and fail.
+ // Since only one static instance of RequestThread is used in the driver (in ArmnnPreparedModel),
+ // this destructor is called only when the application has been closed, which means that
+ // the inner thread will be terminated anyway, although abruptly, in the event that the destructor code throws.
+ // Wrapping the destructor's code with a try-catch block simply fixes the Coverity bug.
+
+ // Post an EXIT message to the thread
+ std::shared_ptr<AsyncExecuteData> nulldata(nullptr);
+ auto pMsg = std::make_shared<ThreadMsg>(ThreadMsgType::EXIT, nulldata);
+ PostMsg(pMsg);
+ // Wait for the thread to terminate, it is deleted automatically
+ m_Thread->join();
+ }
+ catch (const std::exception&) { } // Swallow any exception.
+}
+
+template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename CallbackContext>
+void RequestThread_1_3<PreparedModel, HalVersion, CallbackContext>::PostMsg(PreparedModel<HalVersion>* model,
+ std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& memPools,
+ std::shared_ptr<armnn::InputTensors>& inputTensors,
+ std::shared_ptr<armnn::OutputTensors>& outputTensors,
+ CallbackContext callbackContext)
+{
+ ALOGV("RequestThread_1_3::PostMsg(...)");
+ auto data = std::make_shared<AsyncExecuteData>(model,
+ memPools,
+ inputTensors,
+ outputTensors,
+ callbackContext);
+ auto pMsg = std::make_shared<ThreadMsg>(ThreadMsgType::REQUEST, data);
+ PostMsg(pMsg, model->GetModelPriority());
+}
+
+template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename CallbackContext>
+void RequestThread_1_3<PreparedModel, HalVersion, CallbackContext>::PostMsg(std::shared_ptr<ThreadMsg>& pMsg,
+ V1_3::Priority priority)
+{
+ ALOGV("RequestThread_1_3::PostMsg(pMsg)");
+ // Add a message to the queue and notify the request thread
+ std::unique_lock<std::mutex> lock(m_Mutex);
+ switch (priority) {
+ case V1_3::Priority::HIGH:
+ m_HighPriorityQueue.push(pMsg);
+ break;
+ case V1_3::Priority::LOW:
+ m_LowPriorityQueue.push(pMsg);
+ break;
+ case V1_3::Priority::MEDIUM:
+ default:
+ m_MediumPriorityQueue.push(pMsg);
+ }
+ m_Cv.notify_one();
+}
+
+template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename CallbackContext>
+void RequestThread_1_3<PreparedModel, HalVersion, CallbackContext>::Process()
+{
+ ALOGV("RequestThread_1_3::Process()");
+ int retireRate = RETIRE_RATE;
+ int highPriorityCount = 0;
+ int mediumPriorityCount = 0;
+ while (true)
+ {
+ std::shared_ptr<ThreadMsg> pMsg(nullptr);
+ {
+ // Wait for a message to be added to the queue
+ // This is in a separate scope to minimise the lifetime of the lock
+ std::unique_lock<std::mutex> lock(m_Mutex);
+ while (m_HighPriorityQueue.empty() && m_MediumPriorityQueue.empty() && m_LowPriorityQueue.empty())
+ {
+ m_Cv.wait(lock);
+ }
+ // Get the message to process from the front of each queue based on priority from high to low
+ // Get high priority first if it does not exceed the retire rate
+ if (!m_HighPriorityQueue.empty() && highPriorityCount < retireRate)
+ {
+ pMsg = m_HighPriorityQueue.front();
+ m_HighPriorityQueue.pop();
+ highPriorityCount += 1;
+ }
+ // If high priority queue is empty or the count exceeds the retire rate, get medium priority message
+ else if (!m_MediumPriorityQueue.empty() && mediumPriorityCount < retireRate)
+ {
+ pMsg = m_MediumPriorityQueue.front();
+ m_MediumPriorityQueue.pop();
+ mediumPriorityCount += 1;
+ // Reset high priority count
+ highPriorityCount = 0;
+ }
+ // If medium priority queue is empty or the count exceeds the retire rate, get low priority message
+ else if (!m_LowPriorityQueue.empty())
+ {
+ pMsg = m_LowPriorityQueue.front();
+ m_LowPriorityQueue.pop();
+ // Reset high and medium priority count
+ highPriorityCount = 0;
+ mediumPriorityCount = 0;
+ }
+ else
+ {
+ // Reset high and medium priority count
+ highPriorityCount = 0;
+ mediumPriorityCount = 0;
+ continue;
+ }
+ }
+
+ switch (pMsg->type)
+ {
+ case ThreadMsgType::REQUEST:
+ {
+ ALOGV("RequestThread_1_3::Process() - request");
+ // invoke the asynchronous execution method
+ PreparedModel<HalVersion>* model = pMsg->data->m_Model;
+ model->ExecuteGraph(pMsg->data->m_MemPools,
+ *(pMsg->data->m_InputTensors),
+ *(pMsg->data->m_OutputTensors),
+ pMsg->data->m_CallbackContext);
+ break;
+ }
+
+ case ThreadMsgType::EXIT:
+ {
+ ALOGV("RequestThread_1_3::Process() - exit");
+ // delete all remaining messages (there should not be any)
+ std::unique_lock<std::mutex> lock(m_Mutex);
+ while (!m_HighPriorityQueue.empty())
+ {
+ m_HighPriorityQueue.pop();
+ }
+ while (!m_MediumPriorityQueue.empty())
+ {
+ m_MediumPriorityQueue.pop();
+ }
+ while (!m_LowPriorityQueue.empty())
+ {
+ m_LowPriorityQueue.pop();
+ }
+ return;
+ }
+
+ default:
+ // this should be unreachable
+ ALOGE("RequestThread_1_3::Process() - invalid message type");
+ ARMNN_ASSERT_MSG(false, "ArmNN: RequestThread_1_3: invalid message type");
+ }
+ }
+}
+
+///
+/// Class template specializations
+///
+
+template class RequestThread_1_3<ArmnnPreparedModel_1_3, hal_1_3::HalPolicy, CallbackContext_1_3>;
+
+} // namespace armnn_driver