diff options
Diffstat (limited to 'applications/message_handler/message_handler.cpp')
-rw-r--r-- | applications/message_handler/message_handler.cpp | 406 |
1 files changed, 406 insertions, 0 deletions
diff --git a/applications/message_handler/message_handler.cpp b/applications/message_handler/message_handler.cpp new file mode 100644 index 0000000..7401546 --- /dev/null +++ b/applications/message_handler/message_handler.cpp @@ -0,0 +1,406 @@ +/* + * Copyright (c) 2020-2021 Arm Limited. All rights reserved. + * + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the License); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "message_handler.hpp" + +#include "cmsis_compiler.h" + +#ifdef ETHOSU +#include <ethosu_driver.h> +#endif + +#include "FreeRTOS.h" +#include "queue.h" +#include "semphr.h" + +#include <cstring> +#include <inttypes.h> + +using namespace EthosU; +using namespace MessageQueue; + +namespace MessageHandler { + +/**************************************************************************** + * IncomingMessageHandler + ****************************************************************************/ + +IncomingMessageHandler::IncomingMessageHandler(ethosu_core_queue &_messageQueue, + Mailbox::Mailbox &_mailbox, + QueueHandle_t _inferenceQueue, + QueueHandle_t _outputQueue) : + messageQueue(_messageQueue), + mailbox(_mailbox), inferenceQueue(_inferenceQueue), outputQueue(_outputQueue) { + mailbox.registerCallback(handleIrq, reinterpret_cast<void *>(this)); + semaphore = xSemaphoreCreateBinary(); +} + +void IncomingMessageHandler::run() { + while (true) { + // Wait for event + xSemaphoreTake(semaphore, portMAX_DELAY); + + // Handle all messages in queue + while (handleMessage()) {} + } +} + +void IncomingMessageHandler::handleIrq(void *userArg) { + IncomingMessageHandler *_this = reinterpret_cast<IncomingMessageHandler *>(userArg); + xSemaphoreGive(_this->semaphore); +} + +void IncomingMessageHandler::queueErrorAndResetQueue(ethosu_core_msg_err_type type, const char *message) { + OutputMessage msg(ETHOSU_CORE_MSG_ERR); + msg.data.error.type = type; + + for (size_t i = 0; i < sizeof(msg.data.error.msg) && message[i]; i++) { + msg.data.error.msg[i] = message[i]; + } + + xQueueSend(outputQueue, &msg, portMAX_DELAY); + messageQueue.reset(); +} + +bool IncomingMessageHandler::handleMessage() { + struct ethosu_core_msg msg; + + if (messageQueue.available() == 0) { + return false; + } + + // Read msg header + // Only process a complete message header, else send error message + // and reset queue + if (!messageQueue.read(msg)) { + queueErrorAndResetQueue(ETHOSU_CORE_MSG_ERR_INVALID_SIZE, "Failed to read a complete header"); + return false; + } + + printf("Msg: header magic=%" PRIX32 ", type=%" PRIu32 ", length=%" PRIu32 "\n", msg.magic, msg.type, msg.length); + + if (msg.magic != ETHOSU_CORE_MSG_MAGIC) { + printf("Msg: Invalid Magic\n"); + queueErrorAndResetQueue(ETHOSU_CORE_MSG_ERR_INVALID_MAGIC, "Invalid magic"); + return false; + } + + switch (msg.type) { + case ETHOSU_CORE_MSG_PING: { + printf("Msg: Ping\n"); + + OutputMessage message(ETHOSU_CORE_MSG_PONG); + xQueueSend(outputQueue, &message, portMAX_DELAY); + break; + } + case ETHOSU_CORE_MSG_ERR: { + ethosu_core_msg_err error; + + if (!messageQueue.read(error)) { + printf("ERROR: Msg: Failed to receive error message\n"); + } else { + printf("Msg: Received an error response, type=%" PRIu32 ", msg=\"%s\"\n", error.type, error.msg); + } + + messageQueue.reset(); + return false; + } + case ETHOSU_CORE_MSG_VERSION_REQ: { + printf("Msg: Version request\n"); + + OutputMessage message(ETHOSU_CORE_MSG_VERSION_RSP); + xQueueSend(outputQueue, &message, portMAX_DELAY); + break; + } + case ETHOSU_CORE_MSG_CAPABILITIES_REQ: { + ethosu_core_capabilities_req capabilities; + + if (!messageQueue.read(capabilities)) { + queueErrorAndResetQueue(ETHOSU_CORE_MSG_ERR_INVALID_PAYLOAD, "CapabilitiesReq. Failed to read payload"); + break; + } + + printf("Msg: Capabilities request.user_arg=0x%" PRIx64 "\n", capabilities.user_arg); + + OutputMessage message(ETHOSU_CORE_MSG_CAPABILITIES_RSP); + message.data.userArg = capabilities.user_arg; + xQueueSend(outputQueue, &message, portMAX_DELAY); + break; + } + case ETHOSU_CORE_MSG_INFERENCE_REQ: { + ethosu_core_inference_req inference; + + if (!messageQueue.read(inference)) { + queueErrorAndResetQueue(ETHOSU_CORE_MSG_ERR_INVALID_PAYLOAD, "InferenceReq. Failed to read payload"); + break; + } + + printf("Msg: InferenceReq. user_arg=0x%" PRIx64 ", network={0x%" PRIx32 ", %" PRIu32 "}\n", + inference.user_arg, + inference.network.ptr, + inference.network.size); + + printf(", ifm_count=%" PRIu32 ", ifm=[", inference.ifm_count); + for (uint32_t i = 0; i < inference.ifm_count; ++i) { + if (i > 0) { + printf(", "); + } + + printf("{0x%" PRIx32 ", %" PRIu32 "}", inference.ifm[i].ptr, inference.ifm[i].size); + } + printf("]"); + + printf(", ofm_count=%" PRIu32 ", ofm=[", inference.ofm_count); + for (uint32_t i = 0; i < inference.ofm_count; ++i) { + if (i > 0) { + printf(", "); + } + + printf("{0x%" PRIx32 ", %" PRIu32 "}", inference.ofm[i].ptr, inference.ofm[i].size); + } + printf("]\n"); + + xQueueSend(inferenceQueue, &inference, portMAX_DELAY); + break; + } + default: { + char errMsg[128]; + + snprintf(&errMsg[0], + sizeof(errMsg), + "Msg: Unknown type: %" PRIu32 " with payload length %" PRIu32 " bytes\n", + msg.type, + msg.length); + + queueErrorAndResetQueue(ETHOSU_CORE_MSG_ERR_UNSUPPORTED_TYPE, errMsg); + + return false; + } + } + + return true; +} + +/**************************************************************************** + * InferenceHandler + ****************************************************************************/ + +InferenceHandler::InferenceHandler(uint8_t *tensorArena, + size_t arenaSize, + QueueHandle_t _inferenceQueue, + QueueHandle_t _outputQueue) : + inferenceQueue(_inferenceQueue), + outputQueue(_outputQueue), inference(tensorArena, arenaSize) {} + +void InferenceHandler::run() { + while (true) { + ethosu_core_inference_req req; + + if (pdTRUE != xQueueReceive(inferenceQueue, &req, portMAX_DELAY)) { + continue; + } + + OutputMessage msg(ETHOSU_CORE_MSG_INFERENCE_RSP); + runInference(req, msg.data.inference); + + xQueueSend(outputQueue, &msg, portMAX_DELAY); + } +} + +void InferenceHandler::runInference(ethosu_core_inference_req &req, ethosu_core_inference_rsp &rsp) { + /* + * Setup inference job + */ + + InferenceProcess::DataPtr networkModel(reinterpret_cast<void *>(req.network.ptr), req.network.size); + + std::vector<InferenceProcess::DataPtr> ifm; + for (uint32_t i = 0; i < req.ifm_count; ++i) { + ifm.push_back(InferenceProcess::DataPtr(reinterpret_cast<void *>(req.ifm[i].ptr), req.ifm[i].size)); + } + + std::vector<InferenceProcess::DataPtr> ofm; + for (uint32_t i = 0; i < req.ofm_count; ++i) { + ofm.push_back(InferenceProcess::DataPtr(reinterpret_cast<void *>(req.ofm[i].ptr), req.ofm[i].size)); + } + + std::vector<InferenceProcess::DataPtr> expectedOutput; + + std::vector<uint8_t> pmuEventConfig(ETHOSU_CORE_PMU_MAX); + for (uint32_t i = 0; i < ETHOSU_CORE_PMU_MAX; i++) { + pmuEventConfig[i] = req.pmu_event_config[i]; + } + + InferenceProcess::InferenceJob job( + "job", networkModel, ifm, ofm, expectedOutput, -1, pmuEventConfig, req.pmu_cycle_counter_enable); + + /* + * Run inference + */ + + job.invalidate(); + bool failed = inference.runJob(job); + job.clean(); + + /* + * Send inference response + */ + + rsp.user_arg = req.user_arg; + rsp.ofm_count = job.output.size(); + rsp.status = failed ? ETHOSU_CORE_STATUS_ERROR : ETHOSU_CORE_STATUS_OK; + + for (size_t i = 0; i < job.output.size(); ++i) { + rsp.ofm_size[i] = job.output[i].size; + } + + for (size_t i = 0; i < job.pmuEventConfig.size(); i++) { + rsp.pmu_event_config[i] = job.pmuEventConfig[i]; + } + + for (size_t i = 0; i < job.pmuEventCount.size(); i++) { + rsp.pmu_event_count[i] = job.pmuEventCount[i]; + } + + rsp.pmu_cycle_counter_enable = job.pmuCycleCounterEnable; + rsp.pmu_cycle_counter_count = job.pmuCycleCounterCount; +} + +/**************************************************************************** + * OutgoingMessageHandler + ****************************************************************************/ + +OutgoingMessageHandler::OutgoingMessageHandler(ethosu_core_queue &_messageQueue, + Mailbox::Mailbox &_mailbox, + QueueHandle_t _outputQueue) : + messageQueue(_messageQueue), + mailbox(_mailbox), outputQueue(_outputQueue) { + readCapabilties(capabilities); +} + +void OutgoingMessageHandler::run() { + while (true) { + OutputMessage message; + if (pdTRUE != xQueueReceive(outputQueue, &message, portMAX_DELAY)) { + continue; + } + + switch (message.type) { + case ETHOSU_CORE_MSG_INFERENCE_RSP: + sendInferenceRsp(message.data.inference); + break; + case ETHOSU_CORE_MSG_CAPABILITIES_RSP: + sendCapabilitiesRsp(message.data.userArg); + break; + case ETHOSU_CORE_MSG_VERSION_RSP: + sendVersionRsp(); + break; + case ETHOSU_CORE_MSG_PONG: + sendPong(); + break; + case ETHOSU_CORE_MSG_ERR: + sendErrorRsp(message.data.error); + break; + default: + printf("Dropping unknown outcome of type %d\n", message.type); + break; + } + } +} + +void OutgoingMessageHandler::sendPong() { + if (!messageQueue.write(ETHOSU_CORE_MSG_PONG)) { + printf("ERROR: Msg: Failed to write pong response. No mailbox message sent\n"); + } else { + mailbox.sendMessage(); + } +} + +void OutgoingMessageHandler::sendVersionRsp() { + ethosu_core_msg_version version = { + ETHOSU_CORE_MSG_VERSION_MAJOR, + ETHOSU_CORE_MSG_VERSION_MINOR, + ETHOSU_CORE_MSG_VERSION_PATCH, + 0, + }; + + if (!messageQueue.write(ETHOSU_CORE_MSG_VERSION_RSP, version)) { + printf("ERROR: Failed to write version response. No mailbox message sent\n"); + } else { + mailbox.sendMessage(); + } +} + +void OutgoingMessageHandler::sendCapabilitiesRsp(uint64_t userArg) { + capabilities.user_arg = userArg; + + if (!messageQueue.write(ETHOSU_CORE_MSG_CAPABILITIES_RSP, capabilities)) { + printf("ERROR: Failed to write capabilities response. No mailbox message sent\n"); + } else { + mailbox.sendMessage(); + } +} + +void OutgoingMessageHandler::sendInferenceRsp(ethosu_core_inference_rsp &inference) { + if (!messageQueue.write(ETHOSU_CORE_MSG_INFERENCE_RSP, inference)) { + printf("ERROR: Msg: Failed to write inference response. No mailbox message sent\n"); + } else { + mailbox.sendMessage(); + } +} + +void OutgoingMessageHandler::sendErrorRsp(ethosu_core_msg_err &error) { + printf("ERROR: Msg: \"%s\"\n", error.msg); + + if (!messageQueue.write(ETHOSU_CORE_MSG_ERR, error)) { + printf("ERROR: Msg: Failed to write error response. No mailbox message sent\n"); + } else { + mailbox.sendMessage(); + } +} + +void OutgoingMessageHandler::readCapabilties(ethosu_core_msg_capabilities_rsp &rsp) { + rsp = {0}; + +#ifdef ETHOSU + struct ethosu_driver_version version; + ethosu_get_driver_version(&version); + + struct ethosu_hw_info info; + struct ethosu_driver *drv = ethosu_reserve_driver(); + ethosu_get_hw_info(drv, &info); + ethosu_release_driver(drv); + + rsp.user_arg = 0; + rsp.version_status = info.version.version_status; + rsp.version_minor = info.version.version_minor; + rsp.version_major = info.version.version_major; + rsp.product_major = info.version.product_major; + rsp.arch_patch_rev = info.version.arch_patch_rev; + rsp.arch_minor_rev = info.version.arch_minor_rev; + rsp.arch_major_rev = info.version.arch_major_rev; + rsp.driver_patch_rev = version.patch; + rsp.driver_minor_rev = version.minor; + rsp.driver_major_rev = version.major; + rsp.macs_per_cc = info.cfg.macs_per_cc; + rsp.cmd_stream_version = info.cfg.cmd_stream_version; + rsp.custom_dma = info.cfg.custom_dma; +#endif +} + +} // namespace MessageHandler |