From 72fa50bcf362643431c39642e5af30781714b2fc Mon Sep 17 00:00:00 2001 From: Kristofer Jonsson Date: Thu, 10 Sep 2020 13:26:41 +0200 Subject: Support inferences with multiple inputs and outputs Update inference process apis to support inferences with multiple inputs and multiple outputs. Update message process to handle new inference request message with an array of input- and output buffers. Change-Id: Ide0897385a1d829f58edace79140d01d8e3b85a3 --- .../message_process/src/message_process.cc | 76 +++++++++++++++------- 1 file changed, 54 insertions(+), 22 deletions(-) (limited to 'applications/message_process/src') diff --git a/applications/message_process/src/message_process.cc b/applications/message_process/src/message_process.cc index 2820275..b201f32 100644 --- a/applications/message_process/src/message_process.cc +++ b/applications/message_process/src/message_process.cc @@ -22,6 +22,9 @@ #include #include +using namespace std; +using namespace InferenceProcess; + namespace MessageProcess { QueueImpl::QueueImpl(ethosu_core_queue &queue) : queue(queue) {} @@ -112,7 +115,7 @@ bool QueueImpl::write(const uint32_t type, const void *src, uint32_t length) { MessageProcess::MessageProcess(ethosu_core_queue &in, ethosu_core_queue &out, - InferenceProcess::InferenceProcess &inferenceProcess) : + ::InferenceProcess::InferenceProcess &inferenceProcess) : queueIn(in), queueOut(out), inferenceProcess(inferenceProcess) {} @@ -165,24 +168,47 @@ bool MessageProcess::handleMessage() { ethosu_core_inference_req &req = data.inferenceReq; - printf("InferenceReq. network={0x%x, %u}, ifm={0x%x, %u}, ofm={0x%x, %u}\n", - req.network.ptr, - req.network.size, - req.ifm.ptr, - req.ifm.size, - req.ofm.ptr, - req.ofm.size, - req.user_arg); - - InferenceProcess::DataPtr networkModel(reinterpret_cast(req.network.ptr), req.network.size); - InferenceProcess::DataPtr ifm(reinterpret_cast(req.ifm.ptr), req.ifm.size); - InferenceProcess::DataPtr ofm(reinterpret_cast(req.ofm.ptr), req.ofm.size); - InferenceProcess::DataPtr expectedOutput; - InferenceProcess::InferenceJob job("job", networkModel, ifm, ofm, expectedOutput, -1); + printf("InferenceReq. user_arg=0x%x, network={0x%x, %u}", req.user_arg, req.network.ptr, req.network.size); + + printf(", ifm_count=%u, ifm=[", req.ifm_count); + for (uint32_t i = 0; i < req.ifm_count; ++i) { + if (i > 0) { + printf(", "); + } + + printf("{0x%x, %u}", req.ifm[i].ptr, req.ifm[i].size); + } + printf("]"); + + printf(", ofm_count=%u, ofm=[", req.ofm_count); + for (uint32_t i = 0; i < req.ofm_count; ++i) { + if (i > 0) { + printf(", "); + } + + printf("{0x%x, %u}", req.ofm[i].ptr, req.ofm[i].size); + } + printf("]\n"); + + DataPtr networkModel(reinterpret_cast(req.network.ptr), req.network.size); + + vector ifm; + for (uint32_t i = 0; i < req.ifm_count; ++i) { + ifm.push_back(DataPtr(reinterpret_cast(req.ifm[i].ptr), req.ifm[i].size)); + } + + vector ofm; + for (uint32_t i = 0; i < req.ofm_count; ++i) { + ofm.push_back(DataPtr(reinterpret_cast(req.ofm[i].ptr), req.ofm[i].size)); + } + + vector expectedOutput; + + InferenceJob job("job", networkModel, ifm, ofm, expectedOutput, -1); bool failed = inferenceProcess.runJob(job); - sendInferenceRsp(data.inferenceReq.user_arg, job.output.size, failed); + sendInferenceRsp(data.inferenceReq.user_arg, job.output, failed); break; } default: @@ -198,15 +224,21 @@ void MessageProcess::sendPong() { } } -void MessageProcess::sendInferenceRsp(uint64_t userArg, size_t ofmSize, bool failed) { +void MessageProcess::sendInferenceRsp(uint64_t userArg, vector &ofm, bool failed) { ethosu_core_inference_rsp rsp; - rsp.user_arg = userArg; - rsp.ofm_size = ofmSize; - rsp.status = failed ? ETHOSU_CORE_STATUS_ERROR : ETHOSU_CORE_STATUS_OK; + rsp.user_arg = userArg; + rsp.ofm_count = ofm.size(); + rsp.status = failed ? ETHOSU_CORE_STATUS_ERROR : ETHOSU_CORE_STATUS_OK; + + for (size_t i = 0; i < ofm.size(); ++i) { + rsp.ofm_size[i] = ofm[i].size; + } - printf( - "Sending inference response. userArg=0x%llx, ofm_size=%u, status=%u\n", rsp.user_arg, rsp.ofm_size, rsp.status); + printf("Sending inference response. userArg=0x%llx, ofm_count=%u, status=%u\n", + rsp.user_arg, + rsp.ofm_count, + rsp.status); if (!queueOut.write(ETHOSU_CORE_MSG_INFERENCE_RSP, rsp)) { printf("Failed to write inference.\n"); -- cgit v1.2.1