diff options
author | Kristofer Jonsson <kristofer.jonsson@arm.com> | 2022-03-08 13:28:05 +0100 |
---|---|---|
committer | Kristofer Jonsson <kristofer.jonsson@arm.com> | 2022-03-11 15:07:50 +0100 |
commit | 585ce694dbbebfe5ba737fe94888343cb8976ac3 (patch) | |
tree | 14dde4fcf5873bbdc5453c6b121140b3c6a2ac8f /applications/message_handler/message_handler.cpp | |
parent | d188e902d7002ec748a3ac963db8b3b195fce499 (diff) | |
download | ethos-u-core-platform-585ce694dbbebfe5ba737fe94888343cb8976ac3.tar.gz |
Firmware resident model
Support building a network model into the firmware binary. The model
can be placed by a scatter file or linker script in for example a
low lateny high bandwidth memory like SRAM.
Change-Id: Ic742abed65e20f0da4ded7adefb039389b68b767
Diffstat (limited to 'applications/message_handler/message_handler.cpp')
-rw-r--r-- | applications/message_handler/message_handler.cpp | 168 |
1 files changed, 139 insertions, 29 deletions
diff --git a/applications/message_handler/message_handler.cpp b/applications/message_handler/message_handler.cpp index f95d5f6..e530712 100644 --- a/applications/message_handler/message_handler.cpp +++ b/applications/message_handler/message_handler.cpp @@ -32,15 +32,110 @@ #include <cstring> #include <inttypes.h> +#define XSTRINGIFY(src) #src +#define STRINGIFY(src) XSTRINGIFY(src) + using namespace EthosU; using namespace MessageQueue; +/**************************************************************************** + * Models + ****************************************************************************/ + +namespace { +#if defined(__has_include) + +#if __has_include(STRINGIFY(MODEL_0)) +namespace Model0 { +#include STRINGIFY(MODEL_0) +} +#endif + +#if __has_include(STRINGIFY(MODEL_1)) +namespace Model1 { +#include STRINGIFY(MODEL_1) +} +#endif + +#if __has_include(STRINGIFY(MODEL_2)) +namespace Model2 { +#include STRINGIFY(MODEL_2) +} +#endif + +#if __has_include(STRINGIFY(MODEL_3)) +namespace Model3 { +#include STRINGIFY(MODEL_3) +} +#endif + +#endif +} // namespace + namespace MessageHandler { /**************************************************************************** * IncomingMessageHandler ****************************************************************************/ +namespace { +bool getNetwork(const ethosu_core_buffer &buffer, void *&data, size_t &size) { + data = reinterpret_cast<void *>(buffer.ptr); + size = buffer.size; + return false; +} + +bool getNetwork(const uint32_t index, void *&data, size_t &size) { + switch (index) { +#if __has_include(STRINGIFY(MODEL_0)) + case 0: + data = reinterpret_cast<void *>(Model0::networkModel); + size = sizeof(Model0::networkModel); + break; +#endif + +#if __has_include(STRINGIFY(MODEL_1)) + case 1: + data = reinterpret_cast<void *>(Model1::networkModel); + size = sizeof(Model1::networkModel); + break; +#endif + +#if __has_include(STRINGIFY(MODEL_2)) + case 2: + data = reinterpret_cast<void *>(Model2::networkModel); + size = sizeof(Model2::networkModel); + break; +#endif + +#if __has_include(STRINGIFY(MODEL_3)) + case 3: + data = reinterpret_cast<void *>(Model3::networkModel); + size = sizeof(Model3::networkModel); + break; +#endif + + default: + printf("Error: Network model index out of range. index=%u\n", index); + return true; + } + + return false; +} + +bool getNetwork(const ethosu_core_network_buffer &buffer, void *&data, size_t &size) { + switch (buffer.type) { + case ETHOSU_CORE_NETWORK_BUFFER: + return getNetwork(buffer.buffer, data, size); + case ETHOSU_CORE_NETWORK_INDEX: + return getNetwork(buffer.index, data, size); + default: + printf("Error: Unsupported network model type. type=%u\n", buffer.type); + return true; + } +} +}; // namespace + IncomingMessageHandler::IncomingMessageHandler(ethosu_core_queue &_messageQueue, Mailbox::Mailbox &_mailbox, QueueHandle_t _inferenceQueue, @@ -151,10 +246,17 @@ bool IncomingMessageHandler::handleMessage() { break; } - printf("Msg: InferenceReq. user_arg=0x%" PRIx64 ", network={0x%" PRIx32 ", %" PRIu32 "}, \n", + printf("Msg: InferenceReq. user_arg=0x%" PRIx64 ", network_type=%" PRIu32 ", ", inference.user_arg, - inference.network.ptr, - inference.network.size); + inference.network.type); + + if (inference.network.type == ETHOSU_CORE_NETWORK_BUFFER) { + printf("network.buffer={0x%" PRIx32 ", %" PRIu32 "},\n", + inference.network.buffer.ptr, + inference.network.buffer.size); + } else { + printf("network.index=%" PRIu32 ",\n", inference.network.index); + } printf("ifm_count=%" PRIu32 ", ifm=[", inference.ifm_count); for (uint32_t i = 0; i < inference.ifm_count; ++i) { @@ -228,43 +330,32 @@ void InferenceHandler::runInference(ethosu_core_inference_req &req, ethosu_core_ currentRsp = &rsp; /* - * Setup inference job + * Run inference */ - InferenceProcess::DataPtr networkModel(reinterpret_cast<void *>(req.network.ptr), req.network.size); + InferenceProcess::InferenceJob job; + bool failed = getInferenceJob(req, job); - std::vector<InferenceProcess::DataPtr> ifm; - for (uint32_t i = 0; i < req.ifm_count; ++i) { - ifm.push_back(InferenceProcess::DataPtr(reinterpret_cast<void *>(req.ifm[i].ptr), req.ifm[i].size)); + if (!failed) { + job.invalidate(); + failed = inference.runJob(job); + job.clean(); } - std::vector<InferenceProcess::DataPtr> ofm; - for (uint32_t i = 0; i < req.ofm_count; ++i) { - ofm.push_back(InferenceProcess::DataPtr(reinterpret_cast<void *>(req.ofm[i].ptr), req.ofm[i].size)); - } - - InferenceProcess::InferenceJob job("job", networkModel, ifm, ofm, {}, 0, this); - - /* - * Run inference - */ - - job.invalidate(); - bool failed = inference.runJob(job); - job.clean(); - /* * Print PMU counters */ - const int numEvents = std::min(static_cast<int>(ETHOSU_PMU_Get_NumEventCounters()), ETHOSU_CORE_PMU_MAX); + if (!failed) { + const int numEvents = std::min(static_cast<int>(ETHOSU_PMU_Get_NumEventCounters()), ETHOSU_CORE_PMU_MAX); - for (int i = 0; i < numEvents; i++) { - printf("ethosu_pmu_cntr%d : %" PRIu32 "\n", i, rsp.pmu_event_count[i]); - } + for (int i = 0; i < numEvents; i++) { + printf("ethosu_pmu_cntr%d : %" PRIu32 "\n", i, rsp.pmu_event_count[i]); + } - if (rsp.pmu_cycle_counter_enable) { - printf("ethosu_pmu_cycle_cntr : %" PRIu64 " cycles\n", rsp.pmu_cycle_counter_count); + if (rsp.pmu_cycle_counter_enable) { + printf("ethosu_pmu_cycle_cntr : %" PRIu64 " cycles\n", rsp.pmu_cycle_counter_count); + } } /* @@ -283,6 +374,25 @@ void InferenceHandler::runInference(ethosu_core_inference_req &req, ethosu_core_ currentRsp = nullptr; } +bool InferenceHandler::getInferenceJob(const ethosu_core_inference_req &req, InferenceProcess::InferenceJob &job) { + bool failed = getNetwork(req.network, job.networkModel.data, job.networkModel.size); + if (failed) { + return true; + } + + for (uint32_t i = 0; i < req.ifm_count; ++i) { + job.input.push_back(InferenceProcess::DataPtr(reinterpret_cast<void *>(req.ifm[i].ptr), req.ifm[i].size)); + } + + for (uint32_t i = 0; i < req.ofm_count; ++i) { + job.output.push_back(InferenceProcess::DataPtr(reinterpret_cast<void *>(req.ofm[i].ptr), req.ofm[i].size)); + } + + job.externalContext = this; + + return false; +} + /**************************************************************************** * OutgoingMessageHandler ****************************************************************************/ |