aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavide Grohmann <davide.grohmann@arm.com>2022-03-24 15:38:27 +0100
committerDavide Grohmann <davide.grohmann@arm.com>2022-05-09 09:36:57 +0200
commit160001cdc876e2e8b9fdc4453dc72081c809ed2d (patch)
tree583506efd1571715abe66c1a52dc8d9517df7ec9
parent8b53aad76ea95dc1f4c8ce64b6f8dc14f727b46f (diff)
downloadethos-u-core-platform-160001cdc876e2e8b9fdc4453dc72081c809ed2d.tar.gz
Add support for cancelling enqueued inferences
If an enqueued inference is cancelled it is simply removed from the queue. Both a cancel inference response with status done and a inference response with status cancelled are sent back. Also override the new operators to call in the FreeRTOS allocator instead of malloc/free. Change-Id: I243e678aa6b996084c9b9be1d1b00ffcecc75bc9
-rw-r--r--applications/message_handler/main.cpp77
-rw-r--r--applications/message_handler/message_handler.cpp39
-rw-r--r--applications/message_handler/message_handler.hpp86
3 files changed, 164 insertions, 38 deletions
diff --git a/applications/message_handler/main.cpp b/applications/message_handler/main.cpp
index dde5dc5..b527840 100644
--- a/applications/message_handler/main.cpp
+++ b/applications/message_handler/main.cpp
@@ -71,10 +71,6 @@ __attribute__((section("ethosu_core_out_queue"))) MessageQueue::Queue<1000> outp
namespace {
-SemaphoreHandle_t messageNotify;
-QueueHandle_t inferenceInputQueue;
-QueueHandle_t inferenceOutputQueue;
-
// Mailbox driver
#ifdef MHU_V2
Mailbox::MHUv2 mailbox(MHU_TX_BASE_ADDRESS, MHU_RX_BASE_ADDRESS); // txBase, rxBase
@@ -87,6 +83,26 @@ Mailbox::MHUDummy mailbox;
} // namespace
/****************************************************************************
+ * Override new operators to call in FreeRTOS allocator
+ ****************************************************************************/
+
+void *operator new(size_t size) {
+ return pvPortMalloc(size);
+}
+
+void *operator new[](size_t size) {
+ return pvPortMalloc(size);
+}
+
+void operator delete(void *ptr) {
+ vPortFree(ptr);
+}
+
+void operator delete[](void *ptr) {
+ vPortFree(ptr);
+}
+
+/****************************************************************************
* Mutex & Semaphore
****************************************************************************/
@@ -150,6 +166,24 @@ int ethosu_semaphore_give(void *sem) {
* Application
****************************************************************************/
+struct TaskParams {
+ TaskParams() :
+ messageNotify(xSemaphoreCreateBinary()),
+ inferenceInputQueue(std::make_shared<Queue<ethosu_core_inference_req>>()),
+ inferenceOutputQueue(xQueueCreate(10, sizeof(ethosu_core_inference_rsp))) {}
+
+ SemaphoreHandle_t messageNotify;
+ // Used to pass inference requests to the inference runner task
+ std::shared_ptr<Queue<ethosu_core_inference_req>> inferenceInputQueue;
+ // Queue for message responses to the remote host
+ QueueHandle_t inferenceOutputQueue;
+};
+
+struct InferenceTaskParams {
+ TaskParams *taskParams;
+ uint8_t *arena;
+};
+
namespace {
#ifdef MHU_IRQ
@@ -160,21 +194,27 @@ void mailboxIrqHandler() {
void inferenceTask(void *pvParameters) {
printf("Starting inference task\n");
+ InferenceTaskParams *params = reinterpret_cast<InferenceTaskParams *>(pvParameters);
+
+ InferenceHandler process(params->arena,
+ arenaSize,
+ params->taskParams->inferenceInputQueue,
+ params->taskParams->inferenceOutputQueue,
+ params->taskParams->messageNotify);
- uint8_t *arena = reinterpret_cast<uint8_t *>(pvParameters);
- InferenceHandler process(arena, arenaSize, inferenceInputQueue, inferenceOutputQueue, messageNotify);
process.run();
}
-void messageTask(void *) {
- printf("Starting input message task\n");
+void messageTask(void *pvParameters) {
+ printf("Starting message task\n");
+ TaskParams *params = reinterpret_cast<TaskParams *>(pvParameters);
IncomingMessageHandler process(*inputMessageQueue.toQueue(),
*outputMessageQueue.toQueue(),
mailbox,
- inferenceInputQueue,
- inferenceOutputQueue,
- messageNotify);
+ params->inferenceInputQueue,
+ params->inferenceOutputQueue,
+ params->messageNotify);
#ifdef MHU_IRQ
// Register mailbox interrupt handler
@@ -196,21 +236,22 @@ int main() {
return 1;
}
- // Create message queues for inter process communication
- messageNotify = xSemaphoreCreateBinary();
- inferenceInputQueue = xQueueCreate(10, sizeof(ethosu_core_inference_req));
- inferenceOutputQueue = xQueueCreate(10, sizeof(ethosu_core_inference_rsp));
+ TaskParams taskParams;
- // Task for handling incoming messages from the remote host
- ret = xTaskCreate(messageTask, "messageTask", 1024, nullptr, 2, nullptr);
+ // Task for handling incoming /outgoing messages from the remote host
+ ret = xTaskCreate(messageTask, "messageTask", 1024, &taskParams, 2, nullptr);
if (ret != pdPASS) {
printf("Failed to create 'messageTask'\n");
return ret;
}
+ InferenceTaskParams infParams[NUM_PARALLEL_TASKS];
+
// One inference task for each NPU
for (size_t n = 0; n < NUM_PARALLEL_TASKS; n++) {
- ret = xTaskCreate(inferenceTask, "inferenceTask", 8 * 1024, &tensorArena[n], 3, nullptr);
+ infParams[n].taskParams = &taskParams;
+ infParams[n].arena = reinterpret_cast<uint8_t *>(&tensorArena[n]);
+ ret = xTaskCreate(inferenceTask, "inferenceTask", 8 * 1024, &infParams[n], 3, nullptr);
if (ret != pdPASS) {
printf("Failed to create 'inferenceTask%d'\n", n);
return ret;
diff --git a/applications/message_handler/message_handler.cpp b/applications/message_handler/message_handler.cpp
index f06d056..f109dc8 100644
--- a/applications/message_handler/message_handler.cpp
+++ b/applications/message_handler/message_handler.cpp
@@ -138,12 +138,13 @@ bool getNetwork(const ethosu_core_network_buffer &buffer, void *&data, size_t &s
}; // namespace
-IncomingMessageHandler::IncomingMessageHandler(EthosU::ethosu_core_queue &_inputMessageQueue,
- EthosU::ethosu_core_queue &_outputMessageQueue,
- Mailbox::Mailbox &_mailbox,
- QueueHandle_t _inferenceInputQueue,
- QueueHandle_t _inferenceOutputQueue,
- SemaphoreHandle_t _messageNotify) :
+IncomingMessageHandler::IncomingMessageHandler(
+ EthosU::ethosu_core_queue &_inputMessageQueue,
+ EthosU::ethosu_core_queue &_outputMessageQueue,
+ Mailbox::Mailbox &_mailbox,
+ std::shared_ptr<Queue<EthosU::ethosu_core_inference_req>> _inferenceInputQueue,
+ QueueHandle_t _inferenceOutputQueue,
+ SemaphoreHandle_t _messageNotify) :
inputMessageQueue(_inputMessageQueue),
outputMessageQueue(_outputMessageQueue), mailbox(_mailbox), inferenceInputQueue(_inferenceInputQueue),
inferenceOutputQueue(_inferenceOutputQueue), messageNotify(_messageNotify) {
@@ -166,7 +167,7 @@ void IncomingMessageHandler::handleIrq(void *userArg) {
return;
}
IncomingMessageHandler *_this = reinterpret_cast<IncomingMessageHandler *>(userArg);
- xSemaphoreGive(_this->messageNotify);
+ xSemaphoreGiveFromISR(_this->messageNotify, nullptr);
}
void IncomingMessageHandler::sendErrorAndResetQueue(ethosu_core_msg_err_type type, const char *message) {
@@ -287,7 +288,7 @@ bool IncomingMessageHandler::handleMessage() {
}
printf("]\n");
- if (pdTRUE != xQueueSend(inferenceInputQueue, &req, 0)) {
+ if (!inferenceInputQueue->push(req)) {
printf("Msg: Inference queue full. Rejecting inference user_arg=0x%" PRIx64 "\n", req.user_arg);
sendFailedInferenceRsp(req.user_arg, ETHOSU_CORE_STATUS_REJECTED);
}
@@ -303,7 +304,15 @@ bool IncomingMessageHandler::handleMessage() {
req.user_arg,
req.inference_handle);
- sendCancelInferenceRsp(req.user_arg, ETHOSU_CORE_STATUS_ERROR);
+ bool found =
+ inferenceInputQueue->erase([req](auto &inf_req) { return inf_req.user_arg == req.inference_handle; });
+
+ // NOTE: send an inference response with status ABORTED if the inference has been droped from the queue
+ if (found) {
+ sendFailedInferenceRsp(req.inference_handle, ETHOSU_CORE_STATUS_ABORTED);
+ }
+
+ sendCancelInferenceRsp(req.user_arg, found ? ETHOSU_CORE_STATUS_OK : ETHOSU_CORE_STATUS_ERROR);
break;
}
case ETHOSU_CORE_MSG_NETWORK_INFO_REQ: {
@@ -450,22 +459,20 @@ void IncomingMessageHandler::readCapabilties(ethosu_core_msg_capabilities_rsp &r
* InferenceHandler
****************************************************************************/
-InferenceHandler::InferenceHandler(uint8_t *_tensorArena,
- size_t _arenaSize,
- QueueHandle_t _inferenceInputQueue,
+InferenceHandler::InferenceHandler(uint8_t *tensorArena,
+ size_t arenaSize,
+ std::shared_ptr<Queue<EthosU::ethosu_core_inference_req>> _inferenceInputQueue,
QueueHandle_t _inferenceOutputQueue,
SemaphoreHandle_t _messageNotify) :
inferenceInputQueue(_inferenceInputQueue),
- inferenceOutputQueue(_inferenceOutputQueue), messageNotify(_messageNotify), inference(_tensorArena, _arenaSize) {}
+ inferenceOutputQueue(_inferenceOutputQueue), messageNotify(_messageNotify), inference(tensorArena, arenaSize) {}
void InferenceHandler::run() {
ethosu_core_inference_req req;
ethosu_core_inference_rsp rsp;
while (true) {
- if (pdTRUE != xQueueReceive(inferenceInputQueue, &req, portMAX_DELAY)) {
- continue;
- }
+ inferenceInputQueue->pop(req);
runInference(req, rsp);
diff --git a/applications/message_handler/message_handler.hpp b/applications/message_handler/message_handler.hpp
index fa79205..dd05059 100644
--- a/applications/message_handler/message_handler.hpp
+++ b/applications/message_handler/message_handler.hpp
@@ -24,6 +24,7 @@
#include "semphr.h"
#include "message_queue.hpp"
+#include <ethosu_core_interface.h>
#if defined(ETHOSU)
#include <ethosu_driver.h>
#endif
@@ -31,20 +32,97 @@
#include <inference_process.hpp>
#include <mailbox.hpp>
+#include <algorithm>
#include <cstddef>
#include <cstdio>
+#include <list>
#include <vector>
namespace MessageHandler {
+template <typename T, size_t capacity = 10>
+class Queue {
+public:
+ using Predicate = std::function<bool(const T &data)>;
+
+ Queue() {
+ mutex = xSemaphoreCreateMutex();
+ size = xSemaphoreCreateCounting(capacity, 0u);
+
+ if (mutex == nullptr || size == nullptr) {
+ printf("Error: failed to allocate memory for inference queue\n");
+ }
+ }
+
+ ~Queue() {
+ vSemaphoreDelete(mutex);
+ vSemaphoreDelete(size);
+ }
+
+ bool push(const T &data) {
+ xSemaphoreTake(mutex, portMAX_DELAY);
+ if (list.size() >= capacity) {
+ xSemaphoreGive(mutex);
+ return false;
+ }
+
+ list.push_back(data);
+ xSemaphoreGive(mutex);
+
+ // increase number of available inferences to pop
+ xSemaphoreGive(size);
+ return true;
+ }
+
+ void pop(T &data) {
+ // decrease the number of available inferences to pop
+ xSemaphoreTake(size, portMAX_DELAY);
+
+ xSemaphoreTake(mutex, portMAX_DELAY);
+ data = list.front();
+ list.pop_front();
+ xSemaphoreGive(mutex);
+ }
+
+ bool erase(Predicate pred) {
+ // let's optimistically assume we are removing an inference, so decrease pop
+ if (pdFALSE == xSemaphoreTake(size, 0)) {
+ // if there are no inferences return immediately
+ return false;
+ }
+
+ xSemaphoreTake(mutex, portMAX_DELAY);
+ auto found = std::find_if(list.begin(), list.end(), pred);
+ bool erased = found != list.end();
+ if (erased) {
+ list.erase(found);
+ }
+ xSemaphoreGive(mutex);
+
+ if (!erased) {
+ // no inference erased, so let's put the size count back
+ xSemaphoreGive(size);
+ }
+
+ return erased;
+ }
+
+private:
+ std::list<T> list;
+
+ SemaphoreHandle_t mutex;
+ SemaphoreHandle_t size;
+};
+
class IncomingMessageHandler {
public:
IncomingMessageHandler(EthosU::ethosu_core_queue &inputMessageQueue,
EthosU::ethosu_core_queue &outputMessageQueue,
Mailbox::Mailbox &mailbox,
- QueueHandle_t inferenceInputQueue,
+ std::shared_ptr<Queue<EthosU::ethosu_core_inference_req>> inferenceInputQueue,
QueueHandle_t inferenceOutputQueue,
SemaphoreHandle_t messageNotify);
+
void run();
private:
@@ -66,7 +144,7 @@ private:
MessageQueue::QueueImpl outputMessageQueue;
Mailbox::Mailbox &mailbox;
InferenceProcess::InferenceParser parser;
- QueueHandle_t inferenceInputQueue;
+ std::shared_ptr<Queue<EthosU::ethosu_core_inference_req>> inferenceInputQueue;
QueueHandle_t inferenceOutputQueue;
SemaphoreHandle_t messageNotify;
EthosU::ethosu_core_msg_capabilities_rsp capabilities;
@@ -76,7 +154,7 @@ class InferenceHandler {
public:
InferenceHandler(uint8_t *tensorArena,
size_t arenaSize,
- QueueHandle_t inferenceInputQueue,
+ std::shared_ptr<Queue<EthosU::ethosu_core_inference_req>> inferenceInputQueue,
QueueHandle_t inferenceOutputQueue,
SemaphoreHandle_t messageNotify);
@@ -90,7 +168,7 @@ private:
friend void ::ethosu_inference_begin(struct ethosu_driver *drv, void *userArg);
friend void ::ethosu_inference_end(struct ethosu_driver *drv, void *userArg);
#endif
- QueueHandle_t inferenceInputQueue;
+ std::shared_ptr<Queue<EthosU::ethosu_core_inference_req>> inferenceInputQueue;
QueueHandle_t inferenceOutputQueue;
SemaphoreHandle_t messageNotify;
InferenceProcess::InferenceProcess inference;