From 160001cdc876e2e8b9fdc4453dc72081c809ed2d Mon Sep 17 00:00:00 2001 From: Davide Grohmann Date: Thu, 24 Mar 2022 15:38:27 +0100 Subject: Add support for cancelling enqueued inferences If an enqueued inference is cancelled it is simply removed from the queue. Both a cancel inference response with status done and an inference response with status cancelled are sent back. Also override the new operators to call in the FreeRTOS allocator instead of malloc/free. Change-Id: I243e678aa6b996084c9b9be1d1b00ffcecc75bc9 --- applications/message_handler/main.cpp | 77 ++++++++++++++++----- applications/message_handler/message_handler.cpp | 39 ++++++----- applications/message_handler/message_handler.hpp | 86 ++++++++++++++++++++++-- 3 files changed, 164 insertions(+), 38 deletions(-) (limited to 'applications') diff --git a/applications/message_handler/main.cpp b/applications/message_handler/main.cpp index dde5dc5..b527840 100644 --- a/applications/message_handler/main.cpp +++ b/applications/message_handler/main.cpp @@ -71,10 +71,6 @@ __attribute__((section("ethosu_core_out_queue"))) MessageQueue::Queue<1000> outp namespace { -SemaphoreHandle_t messageNotify; -QueueHandle_t inferenceInputQueue; -QueueHandle_t inferenceOutputQueue; - // Mailbox driver #ifdef MHU_V2 Mailbox::MHUv2 mailbox(MHU_TX_BASE_ADDRESS, MHU_RX_BASE_ADDRESS); // txBase, rxBase @@ -86,6 +82,26 @@ Mailbox::MHUDummy mailbox; } // namespace +/**************************************************************************** + * Override new operators to call in FreeRTOS allocator + ****************************************************************************/ + +void *operator new(size_t size) { + return pvPortMalloc(size); +} + +void *operator new[](size_t size) { + return pvPortMalloc(size); +} + +void operator delete(void *ptr) { + vPortFree(ptr); +} + +void operator delete[](void *ptr) { + vPortFree(ptr); +} + /**************************************************************************** * Mutex & Semaphore 
****************************************************************************/ @@ -150,6 +166,24 @@ int ethosu_semaphore_give(void *sem) { * Application ****************************************************************************/ +struct TaskParams { + TaskParams() : + messageNotify(xSemaphoreCreateBinary()), + inferenceInputQueue(std::make_shared>()), + inferenceOutputQueue(xQueueCreate(10, sizeof(ethosu_core_inference_rsp))) {} + + SemaphoreHandle_t messageNotify; + // Used to pass inference requests to the inference runner task + std::shared_ptr> inferenceInputQueue; + // Queue for message responses to the remote host + QueueHandle_t inferenceOutputQueue; +}; + +struct InferenceTaskParams { + TaskParams *taskParams; + uint8_t *arena; +}; + namespace { #ifdef MHU_IRQ @@ -160,21 +194,27 @@ void mailboxIrqHandler() { void inferenceTask(void *pvParameters) { printf("Starting inference task\n"); + InferenceTaskParams *params = reinterpret_cast(pvParameters); + + InferenceHandler process(params->arena, + arenaSize, + params->taskParams->inferenceInputQueue, + params->taskParams->inferenceOutputQueue, + params->taskParams->messageNotify); - uint8_t *arena = reinterpret_cast(pvParameters); - InferenceHandler process(arena, arenaSize, inferenceInputQueue, inferenceOutputQueue, messageNotify); process.run(); } -void messageTask(void *) { - printf("Starting input message task\n"); +void messageTask(void *pvParameters) { + printf("Starting message task\n"); + TaskParams *params = reinterpret_cast(pvParameters); IncomingMessageHandler process(*inputMessageQueue.toQueue(), *outputMessageQueue.toQueue(), mailbox, - inferenceInputQueue, - inferenceOutputQueue, - messageNotify); + params->inferenceInputQueue, + params->inferenceOutputQueue, + params->messageNotify); #ifdef MHU_IRQ // Register mailbox interrupt handler @@ -196,21 +236,22 @@ int main() { return 1; } - // Create message queues for inter process communication - messageNotify = xSemaphoreCreateBinary(); - 
inferenceInputQueue = xQueueCreate(10, sizeof(ethosu_core_inference_req)); - inferenceOutputQueue = xQueueCreate(10, sizeof(ethosu_core_inference_rsp)); + TaskParams taskParams; - // Task for handling incoming messages from the remote host - ret = xTaskCreate(messageTask, "messageTask", 1024, nullptr, 2, nullptr); + // Task for handling incoming/outgoing messages from the remote host + ret = xTaskCreate(messageTask, "messageTask", 1024, &taskParams, 2, nullptr); if (ret != pdPASS) { printf("Failed to create 'messageTask'\n"); return ret; } + InferenceTaskParams infParams[NUM_PARALLEL_TASKS]; + // One inference task for each NPU for (size_t n = 0; n < NUM_PARALLEL_TASKS; n++) { - ret = xTaskCreate(inferenceTask, "inferenceTask", 8 * 1024, &tensorArena[n], 3, nullptr); + infParams[n].taskParams = &taskParams; + infParams[n].arena = reinterpret_cast(&tensorArena[n]); + ret = xTaskCreate(inferenceTask, "inferenceTask", 8 * 1024, &infParams[n], 3, nullptr); if (ret != pdPASS) { printf("Failed to create 'inferenceTask%d'\n", n); return ret; diff --git a/applications/message_handler/message_handler.cpp b/applications/message_handler/message_handler.cpp index f06d056..f109dc8 100644 --- a/applications/message_handler/message_handler.cpp +++ b/applications/message_handler/message_handler.cpp @@ -138,12 +138,13 @@ bool getNetwork(const ethosu_core_network_buffer &buffer, void *&data, size_t &s }; // namespace -IncomingMessageHandler::IncomingMessageHandler(EthosU::ethosu_core_queue &_inputMessageQueue, - EthosU::ethosu_core_queue &_outputMessageQueue, - Mailbox::Mailbox &_mailbox, - QueueHandle_t _inferenceInputQueue, - QueueHandle_t _inferenceOutputQueue, - SemaphoreHandle_t _messageNotify) : +IncomingMessageHandler::IncomingMessageHandler( + EthosU::ethosu_core_queue &_inputMessageQueue, + EthosU::ethosu_core_queue &_outputMessageQueue, + Mailbox::Mailbox &_mailbox, + std::shared_ptr> _inferenceInputQueue, + QueueHandle_t _inferenceOutputQueue, + SemaphoreHandle_t 
_messageNotify) : inputMessageQueue(_inputMessageQueue), outputMessageQueue(_outputMessageQueue), mailbox(_mailbox), inferenceInputQueue(_inferenceInputQueue), inferenceOutputQueue(_inferenceOutputQueue), messageNotify(_messageNotify) { @@ -166,7 +167,7 @@ void IncomingMessageHandler::handleIrq(void *userArg) { return; } IncomingMessageHandler *_this = reinterpret_cast(userArg); - xSemaphoreGive(_this->messageNotify); + xSemaphoreGiveFromISR(_this->messageNotify, nullptr); } void IncomingMessageHandler::sendErrorAndResetQueue(ethosu_core_msg_err_type type, const char *message) { @@ -287,7 +288,7 @@ bool IncomingMessageHandler::handleMessage() { } printf("]\n"); - if (pdTRUE != xQueueSend(inferenceInputQueue, &req, 0)) { + if (!inferenceInputQueue->push(req)) { printf("Msg: Inference queue full. Rejecting inference user_arg=0x%" PRIx64 "\n", req.user_arg); sendFailedInferenceRsp(req.user_arg, ETHOSU_CORE_STATUS_REJECTED); } @@ -303,7 +304,15 @@ bool IncomingMessageHandler::handleMessage() { req.user_arg, req.inference_handle); - sendCancelInferenceRsp(req.user_arg, ETHOSU_CORE_STATUS_ERROR); + bool found = + inferenceInputQueue->erase([req](auto &inf_req) { return inf_req.user_arg == req.inference_handle; }); + + // NOTE: send an inference response with status ABORTED if the inference has been dropped from the queue + if (found) { + sendFailedInferenceRsp(req.inference_handle, ETHOSU_CORE_STATUS_ABORTED); + } + + sendCancelInferenceRsp(req.user_arg, found ? 
ETHOSU_CORE_STATUS_OK : ETHOSU_CORE_STATUS_ERROR); break; } case ETHOSU_CORE_MSG_NETWORK_INFO_REQ: { @@ -450,22 +459,20 @@ void IncomingMessageHandler::readCapabilties(ethosu_core_msg_capabilities_rsp &r * InferenceHandler ****************************************************************************/ -InferenceHandler::InferenceHandler(uint8_t *_tensorArena, - size_t _arenaSize, - QueueHandle_t _inferenceInputQueue, +InferenceHandler::InferenceHandler(uint8_t *tensorArena, + size_t arenaSize, + std::shared_ptr> _inferenceInputQueue, QueueHandle_t _inferenceOutputQueue, SemaphoreHandle_t _messageNotify) : inferenceInputQueue(_inferenceInputQueue), - inferenceOutputQueue(_inferenceOutputQueue), messageNotify(_messageNotify), inference(_tensorArena, _arenaSize) {} + inferenceOutputQueue(_inferenceOutputQueue), messageNotify(_messageNotify), inference(tensorArena, arenaSize) {} void InferenceHandler::run() { ethosu_core_inference_req req; ethosu_core_inference_rsp rsp; while (true) { - if (pdTRUE != xQueueReceive(inferenceInputQueue, &req, portMAX_DELAY)) { - continue; - } + inferenceInputQueue->pop(req); runInference(req, rsp); diff --git a/applications/message_handler/message_handler.hpp b/applications/message_handler/message_handler.hpp index fa79205..dd05059 100644 --- a/applications/message_handler/message_handler.hpp +++ b/applications/message_handler/message_handler.hpp @@ -24,6 +24,7 @@ #include "semphr.h" #include "message_queue.hpp" +#include #if defined(ETHOSU) #include #endif @@ -31,20 +32,97 @@ #include #include +#include #include #include +#include #include namespace MessageHandler { +template +class Queue { +public: + using Predicate = std::function; + + Queue() { + mutex = xSemaphoreCreateMutex(); + size = xSemaphoreCreateCounting(capacity, 0u); + + if (mutex == nullptr || size == nullptr) { + printf("Error: failed to allocate memory for inference queue\n"); + } + } + + ~Queue() { + vSemaphoreDelete(mutex); + vSemaphoreDelete(size); + } + + bool 
push(const T &data) { + xSemaphoreTake(mutex, portMAX_DELAY); + if (list.size() >= capacity) { + xSemaphoreGive(mutex); + return false; + } + + list.push_back(data); + xSemaphoreGive(mutex); + + // increase number of available inferences to pop + xSemaphoreGive(size); + return true; + } + + void pop(T &data) { + // decrease the number of available inferences to pop + xSemaphoreTake(size, portMAX_DELAY); + + xSemaphoreTake(mutex, portMAX_DELAY); + data = list.front(); + list.pop_front(); + xSemaphoreGive(mutex); + } + + bool erase(Predicate pred) { + // let's optimistically assume we are removing an inference, so decrease pop + if (pdFALSE == xSemaphoreTake(size, 0)) { + // if there are no inferences return immediately + return false; + } + + xSemaphoreTake(mutex, portMAX_DELAY); + auto found = std::find_if(list.begin(), list.end(), pred); + bool erased = found != list.end(); + if (erased) { + list.erase(found); + } + xSemaphoreGive(mutex); + + if (!erased) { + // no inference erased, so let's put the size count back + xSemaphoreGive(size); + } + + return erased; + } + +private: + std::list list; + + SemaphoreHandle_t mutex; + SemaphoreHandle_t size; +}; + class IncomingMessageHandler { public: IncomingMessageHandler(EthosU::ethosu_core_queue &inputMessageQueue, EthosU::ethosu_core_queue &outputMessageQueue, Mailbox::Mailbox &mailbox, - QueueHandle_t inferenceInputQueue, + std::shared_ptr> inferenceInputQueue, QueueHandle_t inferenceOutputQueue, SemaphoreHandle_t messageNotify); + void run(); private: @@ -66,7 +144,7 @@ private: MessageQueue::QueueImpl outputMessageQueue; Mailbox::Mailbox &mailbox; InferenceProcess::InferenceParser parser; - QueueHandle_t inferenceInputQueue; + std::shared_ptr> inferenceInputQueue; QueueHandle_t inferenceOutputQueue; SemaphoreHandle_t messageNotify; EthosU::ethosu_core_msg_capabilities_rsp capabilities; @@ -76,7 +154,7 @@ class InferenceHandler { public: InferenceHandler(uint8_t *tensorArena, size_t arenaSize, - QueueHandle_t 
inferenceInputQueue, + std::shared_ptr> inferenceInputQueue, QueueHandle_t inferenceOutputQueue, SemaphoreHandle_t messageNotify); @@ -90,7 +168,7 @@ private: friend void ::ethosu_inference_begin(struct ethosu_driver *drv, void *userArg); friend void ::ethosu_inference_end(struct ethosu_driver *drv, void *userArg); #endif - QueueHandle_t inferenceInputQueue; + std::shared_ptr> inferenceInputQueue; QueueHandle_t inferenceOutputQueue; SemaphoreHandle_t messageNotify; InferenceProcess::InferenceProcess inference; -- cgit v1.2.1