author    Kristofer Jonsson <kristofer.jonsson@arm.com>  2020-09-10 13:26:41 +0200
committer Kristofer Jonsson <kristofer.jonsson@arm.com>  2020-09-15 13:27:50 +0200
commit    72fa50bcf362643431c39642e5af30781714b2fc (patch)
tree      4c8234b1f5a76d898991379fcdd6441eff3d18b0 /applications/message_process
parent    98e379c83dd24619752e72e7aefdc15484813652 (diff)
download  ethos-u-core-software-72fa50bcf362643431c39642e5af30781714b2fc.tar.gz
Support inferences with multiple inputs and outputs
Update the inference process APIs to support inferences with multiple
inputs and multiple outputs. Update the message process to handle the
new inference request message, which carries an array of input and
output buffers.

Change-Id: Ide0897385a1d829f58edace79140d01d8e3b85a3
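The diff below reads fields such as req.ifm_count, req.ifm[i].ptr and rsp.ofm_size[i], so the request and response structs have evidently grown per-buffer arrays plus counts. The real definitions live in the shared interface header, which is not part of this diff; the following is a minimal sketch of the assumed layout (names not used in the diff itself, such as ETHOSU_CORE_BUFFER_MAX, are hypothetical):

// Sketch of the message layout implied by the field accesses in this diff.
// The actual definitions live in the shared Ethos-U interface header;
// ETHOSU_CORE_BUFFER_MAX is a hypothetical bound, not a confirmed name.
#include <cstdint>

#define ETHOSU_CORE_BUFFER_MAX 16 // assumed maximum number of IFM/OFM buffers

struct ethosu_core_buffer {
    uint32_t ptr;  // base address of the buffer
    uint32_t size; // size of the buffer in bytes
};

struct ethosu_core_inference_req {
    uint64_t user_arg;                                      // opaque handle, echoed in the response
    struct ethosu_core_buffer network;                      // network model
    uint32_t ifm_count;                                     // number of input feature maps
    struct ethosu_core_buffer ifm[ETHOSU_CORE_BUFFER_MAX];  // input feature maps
    uint32_t ofm_count;                                     // number of output feature maps
    struct ethosu_core_buffer ofm[ETHOSU_CORE_BUFFER_MAX];  // output feature maps
};

struct ethosu_core_inference_rsp {
    uint64_t user_arg;                          // copied from the request
    uint32_t ofm_count;                         // number of outputs written
    uint32_t ofm_size[ETHOSU_CORE_BUFFER_MAX];  // bytes written per output
    uint32_t status;                            // ETHOSU_CORE_STATUS_OK or _ERROR
};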
Diffstat (limited to 'applications/message_process')
 -rw-r--r--  applications/message_process/include/message_process.hpp |  3
 -rw-r--r--  applications/message_process/src/message_process.cc      | 76
 2 files changed, 56 insertions(+), 23 deletions(-)
diff --git a/applications/message_process/include/message_process.hpp b/applications/message_process/include/message_process.hpp
index 8044f7c..51f474d 100644
--- a/applications/message_process/include/message_process.hpp
+++ b/applications/message_process/include/message_process.hpp
@@ -24,6 +24,7 @@
#include <cstddef>
#include <cstdio>
+#include <vector>
namespace MessageProcess {
@@ -77,7 +78,7 @@ public:
void handleIrq();
bool handleMessage();
void sendPong();
- void sendInferenceRsp(uint64_t userArg, size_t ofmSize, bool failed);
+ void sendInferenceRsp(uint64_t userArg, std::vector<InferenceProcess::DataPtr> &ofm, bool failed);
private:
QueueImpl queueIn;
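In the new signature the output buffers are passed by reference as std::vector<InferenceProcess::DataPtr>. DataPtr is not defined in this diff; from its use below (constructed from a void pointer and a size, read back through a .size member) its assumed shape is roughly:

// Assumed shape of InferenceProcess::DataPtr, inferred from its use in this
// commit; the real definition lives in the inference process header.
#include <cstddef>

namespace InferenceProcess {
struct DataPtr {
    void *data;  // base address of the buffer
    size_t size; // size in bytes, updated with the bytes actually written

    DataPtr(void *data = nullptr, size_t size = 0) : data(data), size(size) {}
};
} // namespace InferenceProcess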
diff --git a/applications/message_process/src/message_process.cc b/applications/message_process/src/message_process.cc
index 2820275..b201f32 100644
--- a/applications/message_process/src/message_process.cc
+++ b/applications/message_process/src/message_process.cc
@@ -22,6 +22,9 @@
#include <cstdio>
#include <cstring>
+using namespace std;
+using namespace InferenceProcess;
+
namespace MessageProcess {
QueueImpl::QueueImpl(ethosu_core_queue &queue) : queue(queue) {}
@@ -112,7 +115,7 @@ bool QueueImpl::write(const uint32_t type, const void *src, uint32_t length) {
MessageProcess::MessageProcess(ethosu_core_queue &in,
ethosu_core_queue &out,
- InferenceProcess::InferenceProcess &inferenceProcess) :
+ ::InferenceProcess::InferenceProcess &inferenceProcess) :
queueIn(in),
queueOut(out), inferenceProcess(inferenceProcess) {}
@@ -165,24 +168,47 @@ bool MessageProcess::handleMessage() {
ethosu_core_inference_req &req = data.inferenceReq;
- printf("InferenceReq. network={0x%x, %u}, ifm={0x%x, %u}, ofm={0x%x, %u}\n",
- req.network.ptr,
- req.network.size,
- req.ifm.ptr,
- req.ifm.size,
- req.ofm.ptr,
- req.ofm.size,
- req.user_arg);
-
- InferenceProcess::DataPtr networkModel(reinterpret_cast<void *>(req.network.ptr), req.network.size);
- InferenceProcess::DataPtr ifm(reinterpret_cast<void *>(req.ifm.ptr), req.ifm.size);
- InferenceProcess::DataPtr ofm(reinterpret_cast<void *>(req.ofm.ptr), req.ofm.size);
- InferenceProcess::DataPtr expectedOutput;
- InferenceProcess::InferenceJob job("job", networkModel, ifm, ofm, expectedOutput, -1);
+ printf("InferenceReq. user_arg=0x%x, network={0x%x, %u}", req.user_arg, req.network.ptr, req.network.size);
+
+ printf(", ifm_count=%u, ifm=[", req.ifm_count);
+ for (uint32_t i = 0; i < req.ifm_count; ++i) {
+ if (i > 0) {
+ printf(", ");
+ }
+
+ printf("{0x%x, %u}", req.ifm[i].ptr, req.ifm[i].size);
+ }
+ printf("]");
+
+ printf(", ofm_count=%u, ofm=[", req.ofm_count);
+ for (uint32_t i = 0; i < req.ofm_count; ++i) {
+ if (i > 0) {
+ printf(", ");
+ }
+
+ printf("{0x%x, %u}", req.ofm[i].ptr, req.ofm[i].size);
+ }
+ printf("]\n");
+
+ DataPtr networkModel(reinterpret_cast<void *>(req.network.ptr), req.network.size);
+
+ vector<DataPtr> ifm;
+ for (uint32_t i = 0; i < req.ifm_count; ++i) {
+ ifm.push_back(DataPtr(reinterpret_cast<void *>(req.ifm[i].ptr), req.ifm[i].size));
+ }
+
+ vector<DataPtr> ofm;
+ for (uint32_t i = 0; i < req.ofm_count; ++i) {
+ ofm.push_back(DataPtr(reinterpret_cast<void *>(req.ofm[i].ptr), req.ofm[i].size));
+ }
+
+ vector<DataPtr> expectedOutput;
+
+ InferenceJob job("job", networkModel, ifm, ofm, expectedOutput, -1);
bool failed = inferenceProcess.runJob(job);
- sendInferenceRsp(data.inferenceReq.user_arg, job.output.size, failed);
+ sendInferenceRsp(data.inferenceReq.user_arg, job.output, failed);
break;
}
default:
@@ -198,15 +224,21 @@ void MessageProcess::sendPong() {
}
}
-void MessageProcess::sendInferenceRsp(uint64_t userArg, size_t ofmSize, bool failed) {
+void MessageProcess::sendInferenceRsp(uint64_t userArg, vector<DataPtr> &ofm, bool failed) {
ethosu_core_inference_rsp rsp;
- rsp.user_arg = userArg;
- rsp.ofm_size = ofmSize;
- rsp.status = failed ? ETHOSU_CORE_STATUS_ERROR : ETHOSU_CORE_STATUS_OK;
+ rsp.user_arg = userArg;
+ rsp.ofm_count = ofm.size();
+ rsp.status = failed ? ETHOSU_CORE_STATUS_ERROR : ETHOSU_CORE_STATUS_OK;
+
+ for (size_t i = 0; i < ofm.size(); ++i) {
+ rsp.ofm_size[i] = ofm[i].size;
+ }
- printf(
- "Sending inference response. userArg=0x%llx, ofm_size=%u, status=%u\n", rsp.user_arg, rsp.ofm_size, rsp.status);
+ printf("Sending inference response. userArg=0x%llx, ofm_count=%u, status=%u\n",
+ rsp.user_arg,
+ rsp.ofm_count,
+ rsp.status);
if (!queueOut.write(ETHOSU_CORE_MSG_INFERENCE_RSP, rsp)) {
printf("Failed to write inference.\n");