aboutsummaryrefslogtreecommitdiff
path: root/applications/message_handler/message_handler.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'applications/message_handler/message_handler.cpp')
-rw-r--r--applications/message_handler/message_handler.cpp168
1 files changed, 139 insertions, 29 deletions
diff --git a/applications/message_handler/message_handler.cpp b/applications/message_handler/message_handler.cpp
index f95d5f6..e530712 100644
--- a/applications/message_handler/message_handler.cpp
+++ b/applications/message_handler/message_handler.cpp
@@ -32,15 +32,110 @@
#include <cstring>
#include <inttypes.h>
+#define XSTRINGIFY(src) #src
+#define STRINGIFY(src) XSTRINGIFY(src)
+
using namespace EthosU;
using namespace MessageQueue;
+/****************************************************************************
+ * Models
+ ****************************************************************************/
+
+namespace {
+#if defined(__has_include)
+
+#if __has_include(STRINGIFY(MODEL_0))
+namespace Model0 {
+#include STRINGIFY(MODEL_0)
+}
+#endif
+
+#if __has_include(STRINGIFY(MODEL_1))
+namespace Model1 {
+#include STRINGIFY(MODEL_1)
+}
+#endif
+
+#if __has_include(STRINGIFY(MODEL_2))
+namespace Model2 {
+#include STRINGIFY(MODEL_2)
+}
+#endif
+
+#if __has_include(STRINGIFY(MODEL_3))
+namespace Model3 {
+#include STRINGIFY(MODEL_3)
+}
+#endif
+
+#endif
+} // namespace
+
namespace MessageHandler {
/****************************************************************************
* IncomingMessageHandler
****************************************************************************/
+namespace {
+bool getNetwork(const ethosu_core_buffer &buffer, void *&data, size_t &size) {
+ data = reinterpret_cast<void *>(buffer.ptr);
+ size = buffer.size;
+ return false;
+}
+
+bool getNetwork(const uint32_t index, void *&data, size_t &size) {
+ switch (index) {
+#if __has_include(STRINGIFY(MODEL_0))
+ case 0:
+ data = reinterpret_cast<void *>(Model0::networkModel);
+ size = sizeof(Model0::networkModel);
+ break;
+#endif
+
+#if __has_include(STRINGIFY(MODEL_1))
+ case 1:
+ data = reinterpret_cast<void *>(Model1::networkModel);
+ size = sizeof(Model1::networkModel);
+ break;
+#endif
+
+#if __has_include(STRINGIFY(MODEL_2))
+ case 2:
+ data = reinterpret_cast<void *>(Model2::networkModel);
+ size = sizeof(Model2::networkModel);
+ break;
+#endif
+
+#if __has_include(STRINGIFY(MODEL_3))
+ case 3:
+ data = reinterpret_cast<void *>(Model3::networkModel);
+ size = sizeof(Model3::networkModel);
+ break;
+#endif
+
+ default:
+ printf("Error: Network model index out of range. index=%u\n", index);
+ return true;
+ }
+
+ return false;
+}
+
+bool getNetwork(const ethosu_core_network_buffer &buffer, void *&data, size_t &size) {
+ switch (buffer.type) {
+ case ETHOSU_CORE_NETWORK_BUFFER:
+ return getNetwork(buffer.buffer, data, size);
+ case ETHOSU_CORE_NETWORK_INDEX:
+ return getNetwork(buffer.index, data, size);
+ default:
+ printf("Error: Unsupported network model type. type=%u\n", buffer.type);
+ return true;
+ }
+}
+}; // namespace
+
IncomingMessageHandler::IncomingMessageHandler(ethosu_core_queue &_messageQueue,
Mailbox::Mailbox &_mailbox,
QueueHandle_t _inferenceQueue,
@@ -151,10 +246,17 @@ bool IncomingMessageHandler::handleMessage() {
break;
}
- printf("Msg: InferenceReq. user_arg=0x%" PRIx64 ", network={0x%" PRIx32 ", %" PRIu32 "}, \n",
+ printf("Msg: InferenceReq. user_arg=0x%" PRIx64 ", network_type=%" PRIu32 ", ",
inference.user_arg,
- inference.network.ptr,
- inference.network.size);
+ inference.network.type);
+
+ if (inference.network.type == ETHOSU_CORE_NETWORK_BUFFER) {
+ printf("network.buffer={0x%" PRIx32 ", %" PRIu32 "},\n",
+ inference.network.buffer.ptr,
+ inference.network.buffer.size);
+ } else {
+ printf("network.index=%" PRIu32 ",\n", inference.network.index);
+ }
printf("ifm_count=%" PRIu32 ", ifm=[", inference.ifm_count);
for (uint32_t i = 0; i < inference.ifm_count; ++i) {
@@ -228,43 +330,32 @@ void InferenceHandler::runInference(ethosu_core_inference_req &req, ethosu_core_
currentRsp = &rsp;
/*
- * Setup inference job
+ * Run inference
*/
- InferenceProcess::DataPtr networkModel(reinterpret_cast<void *>(req.network.ptr), req.network.size);
+ InferenceProcess::InferenceJob job;
+ bool failed = getInferenceJob(req, job);
- std::vector<InferenceProcess::DataPtr> ifm;
- for (uint32_t i = 0; i < req.ifm_count; ++i) {
- ifm.push_back(InferenceProcess::DataPtr(reinterpret_cast<void *>(req.ifm[i].ptr), req.ifm[i].size));
+ if (!failed) {
+ job.invalidate();
+ failed = inference.runJob(job);
+ job.clean();
}
- std::vector<InferenceProcess::DataPtr> ofm;
- for (uint32_t i = 0; i < req.ofm_count; ++i) {
- ofm.push_back(InferenceProcess::DataPtr(reinterpret_cast<void *>(req.ofm[i].ptr), req.ofm[i].size));
- }
-
- InferenceProcess::InferenceJob job("job", networkModel, ifm, ofm, {}, 0, this);
-
- /*
- * Run inference
- */
-
- job.invalidate();
- bool failed = inference.runJob(job);
- job.clean();
-
/*
* Print PMU counters
*/
- const int numEvents = std::min(static_cast<int>(ETHOSU_PMU_Get_NumEventCounters()), ETHOSU_CORE_PMU_MAX);
+ if (!failed) {
+ const int numEvents = std::min(static_cast<int>(ETHOSU_PMU_Get_NumEventCounters()), ETHOSU_CORE_PMU_MAX);
- for (int i = 0; i < numEvents; i++) {
- printf("ethosu_pmu_cntr%d : %" PRIu32 "\n", i, rsp.pmu_event_count[i]);
- }
+ for (int i = 0; i < numEvents; i++) {
+ printf("ethosu_pmu_cntr%d : %" PRIu32 "\n", i, rsp.pmu_event_count[i]);
+ }
- if (rsp.pmu_cycle_counter_enable) {
- printf("ethosu_pmu_cycle_cntr : %" PRIu64 " cycles\n", rsp.pmu_cycle_counter_count);
+ if (rsp.pmu_cycle_counter_enable) {
+ printf("ethosu_pmu_cycle_cntr : %" PRIu64 " cycles\n", rsp.pmu_cycle_counter_count);
+ }
}
/*
@@ -283,6 +374,25 @@ void InferenceHandler::runInference(ethosu_core_inference_req &req, ethosu_core_
currentRsp = nullptr;
}
+bool InferenceHandler::getInferenceJob(const ethosu_core_inference_req &req, InferenceProcess::InferenceJob &job) {
+ bool failed = getNetwork(req.network, job.networkModel.data, job.networkModel.size);
+ if (failed) {
+ return true;
+ }
+
+ for (uint32_t i = 0; i < req.ifm_count; ++i) {
+ job.input.push_back(InferenceProcess::DataPtr(reinterpret_cast<void *>(req.ifm[i].ptr), req.ifm[i].size));
+ }
+
+ for (uint32_t i = 0; i < req.ofm_count; ++i) {
+ job.output.push_back(InferenceProcess::DataPtr(reinterpret_cast<void *>(req.ofm[i].ptr), req.ofm[i].size));
+ }
+
+ job.externalContext = this;
+
+ return false;
+}
+
/****************************************************************************
* OutgoingMessageHandler
****************************************************************************/