diff options
Diffstat (limited to 'applications/message_handler/main.cpp')
-rw-r--r-- | applications/message_handler/main.cpp | 77 |
1 files changed, 59 insertions, 18 deletions
diff --git a/applications/message_handler/main.cpp b/applications/message_handler/main.cpp index dde5dc5..b527840 100644 --- a/applications/message_handler/main.cpp +++ b/applications/message_handler/main.cpp @@ -71,10 +71,6 @@ __attribute__((section("ethosu_core_out_queue"))) MessageQueue::Queue<1000> outp namespace { -SemaphoreHandle_t messageNotify; -QueueHandle_t inferenceInputQueue; -QueueHandle_t inferenceOutputQueue; - // Mailbox driver #ifdef MHU_V2 Mailbox::MHUv2 mailbox(MHU_TX_BASE_ADDRESS, MHU_RX_BASE_ADDRESS); // txBase, rxBase @@ -87,6 +83,26 @@ Mailbox::MHUDummy mailbox; } // namespace /**************************************************************************** + * Override new operators to call in FreeRTOS allocator + ****************************************************************************/ + +void *operator new(size_t size) { + return pvPortMalloc(size); +} + +void *operator new[](size_t size) { + return pvPortMalloc(size); +} + +void operator delete(void *ptr) { + vPortFree(ptr); +} + +void operator delete[](void *ptr) { + vPortFree(ptr); +} + +/**************************************************************************** * Mutex & Semaphore ****************************************************************************/ @@ -150,6 +166,24 @@ int ethosu_semaphore_give(void *sem) { * Application ****************************************************************************/ +struct TaskParams { + TaskParams() : + messageNotify(xSemaphoreCreateBinary()), + inferenceInputQueue(std::make_shared<Queue<ethosu_core_inference_req>>()), + inferenceOutputQueue(xQueueCreate(10, sizeof(ethosu_core_inference_rsp))) {} + + SemaphoreHandle_t messageNotify; + // Used to pass inference requests to the inference runner task + std::shared_ptr<Queue<ethosu_core_inference_req>> inferenceInputQueue; + // Queue for message responses to the remote host + QueueHandle_t inferenceOutputQueue; +}; + +struct InferenceTaskParams { + TaskParams *taskParams; + uint8_t *arena; +}; + namespace { #ifdef MHU_IRQ @@ -160,21 +194,27 @@ void mailboxIrqHandler() { void inferenceTask(void *pvParameters) { printf("Starting inference task\n"); + InferenceTaskParams *params = reinterpret_cast<InferenceTaskParams *>(pvParameters); + + InferenceHandler process(params->arena, + arenaSize, + params->taskParams->inferenceInputQueue, + params->taskParams->inferenceOutputQueue, + params->taskParams->messageNotify); - uint8_t *arena = reinterpret_cast<uint8_t *>(pvParameters); - InferenceHandler process(arena, arenaSize, inferenceInputQueue, inferenceOutputQueue, messageNotify); process.run(); } -void messageTask(void *) { - printf("Starting input message task\n"); +void messageTask(void *pvParameters) { + printf("Starting message task\n"); + TaskParams *params = reinterpret_cast<TaskParams *>(pvParameters); IncomingMessageHandler process(*inputMessageQueue.toQueue(), *outputMessageQueue.toQueue(), mailbox, - inferenceInputQueue, - inferenceOutputQueue, - messageNotify); + params->inferenceInputQueue, + params->inferenceOutputQueue, + params->messageNotify); #ifdef MHU_IRQ // Register mailbox interrupt handler @@ -196,21 +236,22 @@ int main() { return 1; } - // Create message queues for inter process communication - messageNotify = xSemaphoreCreateBinary(); - inferenceInputQueue = xQueueCreate(10, sizeof(ethosu_core_inference_req)); - inferenceOutputQueue = xQueueCreate(10, sizeof(ethosu_core_inference_rsp)); + TaskParams taskParams; - // Task for handling incoming messages from the remote host - ret = xTaskCreate(messageTask, "messageTask", 1024, nullptr, 2, nullptr); + // Task for handling incoming /outgoing messages from the remote host + ret = xTaskCreate(messageTask, "messageTask", 1024, &taskParams, 2, nullptr); if (ret != pdPASS) { printf("Failed to create 'messageTask'\n"); return ret; } + InferenceTaskParams infParams[NUM_PARALLEL_TASKS]; + // One inference task for each NPU for (size_t n = 0; n < NUM_PARALLEL_TASKS; n++) { - ret = xTaskCreate(inferenceTask, "inferenceTask", 8 * 1024, &tensorArena[n], 3, nullptr); + infParams[n].taskParams = &taskParams; + infParams[n].arena = reinterpret_cast<uint8_t *>(&tensorArena[n]); + ret = xTaskCreate(inferenceTask, "inferenceTask", 8 * 1024, &infParams[n], 3, nullptr); if (ret != pdPASS) { printf("Failed to create 'inferenceTask%d'\n", n); return ret; |