diff options
Diffstat (limited to 'applications/message_handler/message_handler.cpp')
-rw-r--r-- | applications/message_handler/message_handler.cpp | 417 |
1 files changed, 189 insertions, 228 deletions
diff --git a/applications/message_handler/message_handler.cpp b/applications/message_handler/message_handler.cpp index a9c7df7..a634c16 100644 --- a/applications/message_handler/message_handler.cpp +++ b/applications/message_handler/message_handler.cpp @@ -138,55 +138,76 @@ bool getNetwork(const ethosu_core_network_buffer &buffer, void *&data, size_t &s }; // namespace -IncomingMessageHandler::IncomingMessageHandler(ethosu_core_queue &_messageQueue, +IncomingMessageHandler::IncomingMessageHandler(EthosU::ethosu_core_queue &_inputMessageQueue, + EthosU::ethosu_core_queue &_outputMessageQueue, Mailbox::Mailbox &_mailbox, - QueueHandle_t _inferenceQueue, - QueueHandle_t _outputQueue) : - messageQueue(_messageQueue), - mailbox(_mailbox), inferenceQueue(_inferenceQueue), outputQueue(_outputQueue) { + QueueHandle_t _inferenceInputQueue, + QueueHandle_t _inferenceOutputQueue, + SemaphoreHandle_t _messageNotify) : + inputMessageQueue(_inputMessageQueue), + outputMessageQueue(_outputMessageQueue), mailbox(_mailbox), inferenceInputQueue(_inferenceInputQueue), + inferenceOutputQueue(_inferenceOutputQueue), messageNotify(_messageNotify) { mailbox.registerCallback(handleIrq, reinterpret_cast<void *>(this)); - semaphore = xSemaphoreCreateBinary(); + readCapabilties(capabilities); } void IncomingMessageHandler::run() { while (true) { // Wait for event - xSemaphoreTake(semaphore, portMAX_DELAY); + xSemaphoreTake(messageNotify, portMAX_DELAY); - // Handle all messages in queue - while (handleMessage()) {} + // Handle all inference outputs and all messages in queue + while (handleInferenceOutput() || handleMessage()) {} } } void IncomingMessageHandler::handleIrq(void *userArg) { + if (userArg == nullptr) { + return; + } IncomingMessageHandler *_this = reinterpret_cast<IncomingMessageHandler *>(userArg); - xSemaphoreGive(_this->semaphore); + xSemaphoreGive(_this->messageNotify); } -void IncomingMessageHandler::queueErrorAndResetQueue(ethosu_core_msg_err_type type, const char *message) { - OutputMessage msg(ETHOSU_CORE_MSG_ERR); - msg.data.error.type = type; +void IncomingMessageHandler::sendErrorAndResetQueue(ethosu_core_msg_err_type type, const char *message) { + ethosu_core_msg_err error; + error.type = type; - for (size_t i = 0; i < sizeof(msg.data.error.msg) && message[i]; i++) { - msg.data.error.msg[i] = message[i]; + for (size_t i = 0; i < sizeof(error.msg) && message[i]; i++) { + error.msg[i] = message[i]; } + printf("ERROR: Msg: \"%s\"\n", error.msg); - xQueueSend(outputQueue, &msg, portMAX_DELAY); - messageQueue.reset(); + if (!outputMessageQueue.write(ETHOSU_CORE_MSG_ERR, error)) { + printf("ERROR: Msg: Failed to write error response. No mailbox message sent\n"); + } else { + mailbox.sendMessage(); + } + inputMessageQueue.reset(); +} + +bool IncomingMessageHandler::handleInferenceOutput() { + struct ethosu_core_inference_rsp rsp; + if (pdTRUE != xQueueReceive(inferenceOutputQueue, &rsp, 0)) { + return false; + } + + sendInferenceRsp(rsp); + return true; } bool IncomingMessageHandler::handleMessage() { struct ethosu_core_msg msg; - if (messageQueue.available() == 0) { + if (inputMessageQueue.available() == 0) { return false; } // Read msg header // Only process a complete message header, else send error message // and reset queue - if (!messageQueue.read(msg)) { - queueErrorAndResetQueue(ETHOSU_CORE_MSG_ERR_INVALID_SIZE, "Failed to read a complete header"); + if (!inputMessageQueue.read(msg)) { + sendErrorAndResetQueue(ETHOSU_CORE_MSG_ERR_INVALID_SIZE, "Failed to read a complete header"); return false; } @@ -194,136 +215,104 @@ bool IncomingMessageHandler::handleMessage() { if (msg.magic != ETHOSU_CORE_MSG_MAGIC) { printf("Msg: Invalid Magic\n"); - queueErrorAndResetQueue(ETHOSU_CORE_MSG_ERR_INVALID_MAGIC, "Invalid magic"); + sendErrorAndResetQueue(ETHOSU_CORE_MSG_ERR_INVALID_MAGIC, "Invalid magic"); return false; } switch (msg.type) { case ETHOSU_CORE_MSG_PING: { printf("Msg: Ping\n"); - - OutputMessage message(ETHOSU_CORE_MSG_PONG); - xQueueSend(outputQueue, &message, portMAX_DELAY); + sendPong(); break; } case ETHOSU_CORE_MSG_ERR: { ethosu_core_msg_err error; - - if (!messageQueue.read(error)) { + if (!inputMessageQueue.read(error)) { printf("ERROR: Msg: Failed to receive error message\n"); } else { printf("Msg: Received an error response, type=%" PRIu32 ", msg=\"%s\"\n", error.type, error.msg); } - messageQueue.reset(); + inputMessageQueue.reset(); return false; } case ETHOSU_CORE_MSG_VERSION_REQ: { printf("Msg: Version request\n"); - - OutputMessage message(ETHOSU_CORE_MSG_VERSION_RSP); - xQueueSend(outputQueue, &message, portMAX_DELAY); + sendVersionRsp(); break; } case ETHOSU_CORE_MSG_CAPABILITIES_REQ: { - ethosu_core_capabilities_req capabilities; - - if (!messageQueue.read(capabilities)) { - queueErrorAndResetQueue(ETHOSU_CORE_MSG_ERR_INVALID_PAYLOAD, "CapabilitiesReq. Failed to read payload"); + ethosu_core_capabilities_req req; + if (!inputMessageQueue.read(req)) { + sendErrorAndResetQueue(ETHOSU_CORE_MSG_ERR_INVALID_PAYLOAD, "CapabilitiesReq. Failed to read payload"); break; } - printf("Msg: Capabilities request.user_arg=0x%" PRIx64 "\n", capabilities.user_arg); - - OutputMessage message(ETHOSU_CORE_MSG_CAPABILITIES_RSP); - message.data.userArg = capabilities.user_arg; - xQueueSend(outputQueue, &message, portMAX_DELAY); + printf("Msg: Capabilities request.user_arg=0x%" PRIx64 "\n", req.user_arg); + sendCapabilitiesRsp(req.user_arg); break; } case ETHOSU_CORE_MSG_INFERENCE_REQ: { - ethosu_core_inference_req inference; - - if (!messageQueue.read(inference)) { - queueErrorAndResetQueue(ETHOSU_CORE_MSG_ERR_INVALID_PAYLOAD, "InferenceReq. Failed to read payload"); + ethosu_core_inference_req req; + if (!inputMessageQueue.read(req)) { + sendErrorAndResetQueue(ETHOSU_CORE_MSG_ERR_INVALID_PAYLOAD, "InferenceReq. Failed to read payload"); break; } - printf("Msg: InferenceReq. user_arg=0x%" PRIx64 ", network_type=%" PRIu32 ", ", - inference.user_arg, - inference.network.type); + printf("Msg: InferenceReq. user_arg=0x%" PRIx64 ", network_type=%" PRIu32 ", ", req.user_arg, req.network.type); - if (inference.network.type == ETHOSU_CORE_NETWORK_BUFFER) { - printf("network.buffer={0x%" PRIx32 ", %" PRIu32 "},\n", - inference.network.buffer.ptr, - inference.network.buffer.size); + if (req.network.type == ETHOSU_CORE_NETWORK_BUFFER) { + printf("network.buffer={0x%" PRIx32 ", %" PRIu32 "},\n", req.network.buffer.ptr, req.network.buffer.size); } else { - printf("network.index=%" PRIu32 ",\n", inference.network.index); + printf("network.index=%" PRIu32 ",\n", req.network.index); } - printf("ifm_count=%" PRIu32 ", ifm=[", inference.ifm_count); - for (uint32_t i = 0; i < inference.ifm_count; ++i) { + printf("ifm_count=%" PRIu32 ", ifm=[", req.ifm_count); + for (uint32_t i = 0; i < req.ifm_count; ++i) { if (i > 0) { printf(", "); } - printf("{0x%" PRIx32 ", %" PRIu32 "}", inference.ifm[i].ptr, inference.ifm[i].size); + printf("{0x%" PRIx32 ", %" PRIu32 "}", req.ifm[i].ptr, req.ifm[i].size); } printf("]"); - printf(", ofm_count=%" PRIu32 ", ofm=[", inference.ofm_count); - for (uint32_t i = 0; i < inference.ofm_count; ++i) { + printf(", ofm_count=%" PRIu32 ", ofm=[", req.ofm_count); + for (uint32_t i = 0; i < req.ofm_count; ++i) { if (i > 0) { printf(", "); } - printf("{0x%" PRIx32 ", %" PRIu32 "}", inference.ofm[i].ptr, inference.ofm[i].size); + printf("{0x%" PRIx32 ", %" PRIu32 "}", req.ofm[i].ptr, req.ofm[i].size); } printf("]\n"); - xQueueSend(inferenceQueue, &inference, portMAX_DELAY); + if (pdTRUE != xQueueSend(inferenceInputQueue, &req, 0)) { + printf("Msg: Inference queue full. Rejecting inference user_arg=0x%" PRIx64 "\n", req.user_arg); + sendFailedInferenceRsp(req.user_arg, ETHOSU_CORE_STATUS_REJECTED); + } break; } case ETHOSU_CORE_MSG_NETWORK_INFO_REQ: { ethosu_core_network_info_req req; - - if (!messageQueue.read(req)) { - queueErrorAndResetQueue(ETHOSU_CORE_MSG_ERR_INVALID_PAYLOAD, "NetworkInfoReq. Failed to read payload"); + if (!inputMessageQueue.read(req)) { + sendErrorAndResetQueue(ETHOSU_CORE_MSG_ERR_INVALID_PAYLOAD, "NetworkInfoReq. Failed to read payload"); break; } printf("Msg: NetworkInfoReq. user_arg=0x%" PRIx64 "\n", req.user_arg); - - OutputMessage message(ETHOSU_CORE_MSG_NETWORK_INFO_RSP); - ethosu_core_network_info_rsp &rsp = message.data.networkInfo; - rsp.user_arg = req.user_arg; - rsp.ifm_count = 0; - rsp.ofm_count = 0; - - void *buffer; - size_t size; - getNetwork(req.network, buffer, size); - bool failed = - parser.parseModel(buffer, - rsp.desc, - InferenceProcess::makeArray(rsp.ifm_size, rsp.ifm_count, ETHOSU_CORE_BUFFER_MAX), - InferenceProcess::makeArray(rsp.ofm_size, rsp.ofm_count, ETHOSU_CORE_BUFFER_MAX)); - - rsp.status = failed ? ETHOSU_CORE_STATUS_ERROR : ETHOSU_CORE_STATUS_OK; - - xQueueSend(outputQueue, &message, portMAX_DELAY); + sendNetworkInfoRsp(req.user_arg, req.network); break; } default: { char errMsg[128]; - snprintf(&errMsg[0], sizeof(errMsg), "Msg: Unknown type: %" PRIu32 " with payload length %" PRIu32 " bytes\n", msg.type, msg.length); - queueErrorAndResetQueue(ETHOSU_CORE_MSG_ERR_UNSUPPORTED_TYPE, errMsg); - + sendErrorAndResetQueue(ETHOSU_CORE_MSG_ERR_UNSUPPORTED_TYPE, errMsg); return false; } } @@ -331,29 +320,134 @@ bool IncomingMessageHandler::handleMessage() { return true; } +void IncomingMessageHandler::sendPong() { + if (!outputMessageQueue.write(ETHOSU_CORE_MSG_PONG)) { + printf("ERROR: Msg: Failed to write pong response. No mailbox message sent\n"); + } else { + mailbox.sendMessage(); + } +} + +void IncomingMessageHandler::sendVersionRsp() { + ethosu_core_msg_version version = { + ETHOSU_CORE_MSG_VERSION_MAJOR, + ETHOSU_CORE_MSG_VERSION_MINOR, + ETHOSU_CORE_MSG_VERSION_PATCH, + 0, + }; + + if (!outputMessageQueue.write(ETHOSU_CORE_MSG_VERSION_RSP, version)) { + printf("ERROR: Failed to write version response. No mailbox message sent\n"); + } else { + mailbox.sendMessage(); + } +} + +void IncomingMessageHandler::sendCapabilitiesRsp(uint64_t userArg) { + capabilities.user_arg = userArg; + + if (!outputMessageQueue.write(ETHOSU_CORE_MSG_CAPABILITIES_RSP, capabilities)) { + printf("ERROR: Failed to write capabilities response. No mailbox message sent\n"); + } else { + mailbox.sendMessage(); + } +} + +void IncomingMessageHandler::sendNetworkInfoRsp(uint64_t userArg, ethosu_core_network_buffer &network) { + ethosu_core_network_info_rsp rsp; + rsp.user_arg = userArg; + rsp.ifm_count = 0; + rsp.ofm_count = 0; + + void *buffer; + size_t size; + getNetwork(network, buffer, size); + bool failed = parser.parseModel(buffer, + rsp.desc, + InferenceProcess::makeArray(rsp.ifm_size, rsp.ifm_count, ETHOSU_CORE_BUFFER_MAX), + InferenceProcess::makeArray(rsp.ofm_size, rsp.ofm_count, ETHOSU_CORE_BUFFER_MAX)); + rsp.status = failed ? ETHOSU_CORE_STATUS_ERROR : ETHOSU_CORE_STATUS_OK; + + if (!outputMessageQueue.write(ETHOSU_CORE_MSG_NETWORK_INFO_RSP, rsp)) { + printf("ERROR: Msg: Failed to write network info response. No mailbox message sent\n"); + } else { + mailbox.sendMessage(); + } +} + +void IncomingMessageHandler::sendInferenceRsp(ethosu_core_inference_rsp &rsp) { + if (!outputMessageQueue.write(ETHOSU_CORE_MSG_INFERENCE_RSP, rsp)) { + printf("ERROR: Msg: Failed to write inference response. No mailbox message sent\n"); + } else { + mailbox.sendMessage(); + } +} + +void IncomingMessageHandler::sendFailedInferenceRsp(uint64_t userArg, uint32_t status) { + ethosu_core_inference_rsp rsp; + rsp.user_arg = userArg; + rsp.status = status; + if (!outputMessageQueue.write(ETHOSU_CORE_MSG_INFERENCE_RSP, rsp)) { + printf("ERROR: Msg: Failed to write inference response. No mailbox message sent\n"); + } else { + mailbox.sendMessage(); + } +} + +void IncomingMessageHandler::readCapabilties(ethosu_core_msg_capabilities_rsp &rsp) { + rsp = {}; + +#ifdef ETHOSU + struct ethosu_driver_version version; + ethosu_get_driver_version(&version); + + struct ethosu_hw_info info; + struct ethosu_driver *drv = ethosu_reserve_driver(); + ethosu_get_hw_info(drv, &info); + ethosu_release_driver(drv); + + rsp.user_arg = 0; + rsp.version_status = info.version.version_status; + rsp.version_minor = info.version.version_minor; + rsp.version_major = info.version.version_major; + rsp.product_major = info.version.product_major; + rsp.arch_patch_rev = info.version.arch_patch_rev; + rsp.arch_minor_rev = info.version.arch_minor_rev; + rsp.arch_major_rev = info.version.arch_major_rev; + rsp.driver_patch_rev = version.patch; + rsp.driver_minor_rev = version.minor; + rsp.driver_major_rev = version.major; + rsp.macs_per_cc = info.cfg.macs_per_cc; + rsp.cmd_stream_version = info.cfg.cmd_stream_version; + rsp.custom_dma = info.cfg.custom_dma; +#endif +} + /**************************************************************************** * InferenceHandler ****************************************************************************/ -InferenceHandler::InferenceHandler(uint8_t *tensorArena, - size_t arenaSize, - QueueHandle_t _inferenceQueue, - QueueHandle_t _outputQueue) : - inferenceQueue(_inferenceQueue), - outputQueue(_outputQueue), inference(tensorArena, arenaSize) {} +InferenceHandler::InferenceHandler(uint8_t *_tensorArena, + size_t _arenaSize, + QueueHandle_t _inferenceInputQueue, + QueueHandle_t _inferenceOutputQueue, + SemaphoreHandle_t _messageNotify) : + inferenceInputQueue(_inferenceInputQueue), + inferenceOutputQueue(_inferenceOutputQueue), messageNotify(_messageNotify), inference(_tensorArena, _arenaSize) {} void InferenceHandler::run() { - while (true) { - ethosu_core_inference_req req; + ethosu_core_inference_req req; + ethosu_core_inference_rsp rsp; - if (pdTRUE != xQueueReceive(inferenceQueue, &req, portMAX_DELAY)) { + while (true) { + if (pdTRUE != xQueueReceive(inferenceInputQueue, &req, portMAX_DELAY)) { continue; } - OutputMessage msg(ETHOSU_CORE_MSG_INFERENCE_RSP); - runInference(req, msg.data.inference); + runInference(req, rsp); - xQueueSend(outputQueue, &msg, portMAX_DELAY); + xQueueSend(inferenceOutputQueue, &rsp, portMAX_DELAY); + xSemaphoreGive(messageNotify); } } @@ -427,139 +521,6 @@ bool InferenceHandler::getInferenceJob(const ethosu_core_inference_req &req, Inf return false; } -/**************************************************************************** - * OutgoingMessageHandler - ****************************************************************************/ - -OutgoingMessageHandler::OutgoingMessageHandler(ethosu_core_queue &_messageQueue, - Mailbox::Mailbox &_mailbox, - QueueHandle_t _outputQueue) : - messageQueue(_messageQueue), - mailbox(_mailbox), outputQueue(_outputQueue) { - readCapabilties(capabilities); -} - -void OutgoingMessageHandler::run() { - while (true) { - OutputMessage message; - if (pdTRUE != xQueueReceive(outputQueue, &message, portMAX_DELAY)) { - continue; - } - - switch (message.type) { - case ETHOSU_CORE_MSG_INFERENCE_RSP: - sendInferenceRsp(message.data.inference); - break; - case ETHOSU_CORE_MSG_CAPABILITIES_RSP: - sendCapabilitiesRsp(message.data.userArg); - break; - case ETHOSU_CORE_MSG_VERSION_RSP: - sendVersionRsp(); - break; - case ETHOSU_CORE_MSG_PONG: - sendPong(); - break; - case ETHOSU_CORE_MSG_ERR: - sendErrorRsp(message.data.error); - break; - case ETHOSU_CORE_MSG_NETWORK_INFO_RSP: - sendNetworkInfoRsp(message.data.networkInfo); - break; - default: - printf("Dropping unknown outcome of type %d\n", message.type); - break; - } - } -} - -void OutgoingMessageHandler::sendPong() { - if (!messageQueue.write(ETHOSU_CORE_MSG_PONG)) { - printf("ERROR: Msg: Failed to write pong response. No mailbox message sent\n"); - } else { - mailbox.sendMessage(); - } -} - -void OutgoingMessageHandler::sendVersionRsp() { - ethosu_core_msg_version version = { - ETHOSU_CORE_MSG_VERSION_MAJOR, - ETHOSU_CORE_MSG_VERSION_MINOR, - ETHOSU_CORE_MSG_VERSION_PATCH, - 0, - }; - - if (!messageQueue.write(ETHOSU_CORE_MSG_VERSION_RSP, version)) { - printf("ERROR: Failed to write version response. No mailbox message sent\n"); - } else { - mailbox.sendMessage(); - } -} - -void OutgoingMessageHandler::sendCapabilitiesRsp(uint64_t userArg) { - capabilities.user_arg = userArg; - - if (!messageQueue.write(ETHOSU_CORE_MSG_CAPABILITIES_RSP, capabilities)) { - printf("ERROR: Failed to write capabilities response. No mailbox message sent\n"); - } else { - mailbox.sendMessage(); - } -} - -void OutgoingMessageHandler::sendInferenceRsp(ethosu_core_inference_rsp &inference) { - if (!messageQueue.write(ETHOSU_CORE_MSG_INFERENCE_RSP, inference)) { - printf("ERROR: Msg: Failed to write inference response. No mailbox message sent\n"); - } else { - mailbox.sendMessage(); - } -} - -void OutgoingMessageHandler::sendNetworkInfoRsp(EthosU::ethosu_core_network_info_rsp &networkInfo) { - if (!messageQueue.write(ETHOSU_CORE_MSG_NETWORK_INFO_RSP, networkInfo)) { - printf("ERROR: Msg: Failed to write network info response. No mailbox message sent\n"); - } else { - mailbox.sendMessage(); - } -} - -void OutgoingMessageHandler::sendErrorRsp(ethosu_core_msg_err &error) { - printf("ERROR: Msg: \"%s\"\n", error.msg); - - if (!messageQueue.write(ETHOSU_CORE_MSG_ERR, error)) { - printf("ERROR: Msg: Failed to write error response. No mailbox message sent\n"); - } else { - mailbox.sendMessage(); - } -} - -void OutgoingMessageHandler::readCapabilties(ethosu_core_msg_capabilities_rsp &rsp) { - rsp = {}; - -#ifdef ETHOSU - struct ethosu_driver_version version; - ethosu_get_driver_version(&version); - - struct ethosu_hw_info info; - struct ethosu_driver *drv = ethosu_reserve_driver(); - ethosu_get_hw_info(drv, &info); - ethosu_release_driver(drv); - - rsp.user_arg = 0; - rsp.version_status = info.version.version_status; - rsp.version_minor = info.version.version_minor; - rsp.version_major = info.version.version_major; - rsp.product_major = info.version.product_major; - rsp.arch_patch_rev = info.version.arch_patch_rev; - rsp.arch_minor_rev = info.version.arch_minor_rev; - rsp.arch_major_rev = info.version.arch_major_rev; - rsp.driver_patch_rev = version.patch; - rsp.driver_minor_rev = version.minor; - rsp.driver_major_rev = version.major; - rsp.macs_per_cc = info.cfg.macs_per_cc; - rsp.cmd_stream_version = info.cfg.cmd_stream_version; - rsp.custom_dma = info.cfg.custom_dma; -#endif -} - } // namespace MessageHandler #if defined(ETHOSU) |