diff options
author | Kristofer Jonsson <kristofer.jonsson@arm.com> | 2020-09-10 13:26:41 +0200 |
---|---|---|
committer | Kristofer Jonsson <kristofer.jonsson@arm.com> | 2020-09-15 13:27:50 +0200 |
commit | 72fa50bcf362643431c39642e5af30781714b2fc (patch) | |
tree | 4c8234b1f5a76d898991379fcdd6441eff3d18b0 /applications/message_process | |
parent | 98e379c83dd24619752e72e7aefdc15484813652 (diff) | |
download | ethos-u-core-software-72fa50bcf362643431c39642e5af30781714b2fc.tar.gz |
Support inferences with multiple inputs and outputs
Update inference process apis to support inferences with multiple inputs
and multiple outputs.
Update message process to handle new inference request message with an
array of input- and output buffers.
Change-Id: Ide0897385a1d829f58edace79140d01d8e3b85a3
Diffstat (limited to 'applications/message_process')
-rw-r--r-- | applications/message_process/include/message_process.hpp | 3 | ||||
-rw-r--r-- | applications/message_process/src/message_process.cc | 76 |
2 files changed, 56 insertions, 23 deletions
diff --git a/applications/message_process/include/message_process.hpp b/applications/message_process/include/message_process.hpp index 8044f7c..51f474d 100644 --- a/applications/message_process/include/message_process.hpp +++ b/applications/message_process/include/message_process.hpp @@ -24,6 +24,7 @@ #include <cstddef> #include <cstdio> +#include <vector> namespace MessageProcess { @@ -77,7 +78,7 @@ public: void handleIrq(); bool handleMessage(); void sendPong(); - void sendInferenceRsp(uint64_t userArg, size_t ofmSize, bool failed); + void sendInferenceRsp(uint64_t userArg, std::vector<InferenceProcess::DataPtr> &ofm, bool failed); private: QueueImpl queueIn; diff --git a/applications/message_process/src/message_process.cc b/applications/message_process/src/message_process.cc index 2820275..b201f32 100644 --- a/applications/message_process/src/message_process.cc +++ b/applications/message_process/src/message_process.cc @@ -22,6 +22,9 @@ #include <cstdio> #include <cstring> +using namespace std; +using namespace InferenceProcess; + namespace MessageProcess { QueueImpl::QueueImpl(ethosu_core_queue &queue) : queue(queue) {} @@ -112,7 +115,7 @@ bool QueueImpl::write(const uint32_t type, const void *src, uint32_t length) { MessageProcess::MessageProcess(ethosu_core_queue &in, ethosu_core_queue &out, - InferenceProcess::InferenceProcess &inferenceProcess) : + ::InferenceProcess::InferenceProcess &inferenceProcess) : queueIn(in), queueOut(out), inferenceProcess(inferenceProcess) {} @@ -165,24 +168,47 @@ bool MessageProcess::handleMessage() { ethosu_core_inference_req &req = data.inferenceReq; - printf("InferenceReq. network={0x%x, %u}, ifm={0x%x, %u}, ofm={0x%x, %u}\n", - req.network.ptr, - req.network.size, - req.ifm.ptr, - req.ifm.size, - req.ofm.ptr, - req.ofm.size, - req.user_arg); - - InferenceProcess::DataPtr networkModel(reinterpret_cast<void *>(req.network.ptr), req.network.size); - InferenceProcess::DataPtr ifm(reinterpret_cast<void *>(req.ifm.ptr), req.ifm.size); - InferenceProcess::DataPtr ofm(reinterpret_cast<void *>(req.ofm.ptr), req.ofm.size); - InferenceProcess::DataPtr expectedOutput; - InferenceProcess::InferenceJob job("job", networkModel, ifm, ofm, expectedOutput, -1); + printf("InferenceReq. user_arg=0x%x, network={0x%x, %u}", req.user_arg, req.network.ptr, req.network.size); + + printf(", ifm_count=%u, ifm=[", req.ifm_count); + for (uint32_t i = 0; i < req.ifm_count; ++i) { + if (i > 0) { + printf(", "); + } + + printf("{0x%x, %u}", req.ifm[i].ptr, req.ifm[i].size); + } + printf("]"); + + printf(", ofm_count=%u, ofm=[", req.ofm_count); + for (uint32_t i = 0; i < req.ofm_count; ++i) { + if (i > 0) { + printf(", "); + } + + printf("{0x%x, %u}", req.ofm[i].ptr, req.ofm[i].size); + } + printf("]\n"); + + DataPtr networkModel(reinterpret_cast<void *>(req.network.ptr), req.network.size); + + vector<DataPtr> ifm; + for (uint32_t i = 0; i < req.ifm_count; ++i) { + ifm.push_back(DataPtr(reinterpret_cast<void *>(req.ifm[i].ptr), req.ifm[i].size)); + } + + vector<DataPtr> ofm; + for (uint32_t i = 0; i < req.ofm_count; ++i) { + ofm.push_back(DataPtr(reinterpret_cast<void *>(req.ofm[i].ptr), req.ofm[i].size)); + } + + vector<DataPtr> expectedOutput; + + InferenceJob job("job", networkModel, ifm, ofm, expectedOutput, -1); bool failed = inferenceProcess.runJob(job); - sendInferenceRsp(data.inferenceReq.user_arg, job.output.size, failed); + sendInferenceRsp(data.inferenceReq.user_arg, job.output, failed); break; } default: @@ -198,15 +224,21 @@ void MessageProcess::sendPong() { } } -void MessageProcess::sendInferenceRsp(uint64_t userArg, size_t ofmSize, bool failed) { +void MessageProcess::sendInferenceRsp(uint64_t userArg, vector<DataPtr> &ofm, bool failed) { ethosu_core_inference_rsp rsp; - rsp.user_arg = userArg; - rsp.ofm_size = ofmSize; - rsp.status = failed ? ETHOSU_CORE_STATUS_ERROR : ETHOSU_CORE_STATUS_OK; + rsp.user_arg = userArg; + rsp.ofm_count = ofm.size(); + rsp.status = failed ? ETHOSU_CORE_STATUS_ERROR : ETHOSU_CORE_STATUS_OK; + + for (size_t i = 0; i < ofm.size(); ++i) { + rsp.ofm_size[i] = ofm[i].size; + } - printf( - "Sending inference response. userArg=0x%llx, ofm_size=%u, status=%u\n", rsp.user_arg, rsp.ofm_size, rsp.status); + printf("Sending inference response. userArg=0x%llx, ofm_count=%u, status=%u\n", + rsp.user_arg, + rsp.ofm_count, + rsp.status); if (!queueOut.write(ETHOSU_CORE_MSG_INFERENCE_RSP, rsp)) { printf("Failed to write inference.\n"); |