From ac535f0387647b114826e921d23e68787f8a572b Mon Sep 17 00:00:00 2001 From: Kristofer Jonsson Date: Thu, 10 Mar 2022 11:08:39 +0100 Subject: Network info Add message for fetching meta data about built in network models. Change-Id: I757094c20848d4cb018db68b0455297bb03be463 --- applications/message_handler/CMakeLists.txt | 2 +- applications/message_handler/main.cpp | 4 +- applications/message_handler/message_handler.cpp | 184 ++++++++++++++++++++++- applications/message_handler/message_handler.hpp | 4 +- applications/message_handler/message_queue.cpp | 2 +- applications/message_handler/message_queue.hpp | 2 +- applications/message_handler/model_template.hpp | 2 +- 7 files changed, 192 insertions(+), 8 deletions(-) (limited to 'applications/message_handler') diff --git a/applications/message_handler/CMakeLists.txt b/applications/message_handler/CMakeLists.txt index 72d930f..27a4815 100644 --- a/applications/message_handler/CMakeLists.txt +++ b/applications/message_handler/CMakeLists.txt @@ -1,5 +1,5 @@ # -# Copyright (c) 2020-2022 Arm Limited. All rights reserved. +# Copyright (c) 2020-2022 Arm Limited. # # SPDX-License-Identifier: Apache-2.0 # diff --git a/applications/message_handler/main.cpp b/applications/message_handler/main.cpp index 9b36f84..8a36325 100644 --- a/applications/message_handler/main.cpp +++ b/applications/message_handler/main.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021 Arm Limited. All rights reserved. + * Copyright (c) 2019-2022 Arm Limited. * * SPDX-License-Identifier: Apache-2.0 * @@ -182,7 +182,7 @@ int main() { outputQueue = xQueueCreate(10, sizeof(OutputMessage)); // Task for handling incoming messages from the remote host - ret = xTaskCreate(inputMessageTask, "inputMessageTask", 512, nullptr, 2, nullptr); + ret = xTaskCreate(inputMessageTask, "inputMessageTask", 1024, nullptr, 2, nullptr); if (ret != pdPASS) { printf("Failed to create 'inputMessageTask'\n"); return ret; diff --git a/applications/message_handler/message_handler.cpp b/applications/message_handler/message_handler.cpp index e530712..4b77389 100644 --- a/applications/message_handler/message_handler.cpp +++ b/applications/message_handler/message_handler.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022 Arm Limited. All rights reserved. + * Copyright (c) 2020-2022 Arm Limited. * * SPDX-License-Identifier: Apache-2.0 * @@ -19,6 +19,7 @@ #include "message_handler.hpp" #include "cmsis_compiler.h" +#include "tensorflow/lite/schema/schema_generated.h" #ifdef ETHOSU #include @@ -31,6 +32,7 @@ #include #include +#include #define XSTRINGIFY(src) #src #define STRINGIFY(src) XSTRINGIFY(src) @@ -79,6 +81,36 @@ namespace MessageHandler { ****************************************************************************/ namespace { + +template +class Array { +public: + Array() = delete; + Array(T *const data, U &size, size_t capacity) : _data{data}, _size{size}, _capacity{capacity} {} + + auto size() const { + return _size; + } + + auto capacity() const { + return _capacity; + } + + void push_back(const T &data) { + _data[_size++] = data; + } + +private: + T *const _data; + U &_size; + const size_t _capacity{}; +}; + +template +Array makeArray(T *const data, U &size, size_t capacity) { + return Array{data, size, capacity}; +} + bool getNetwork(const ethosu_core_buffer &buffer, void *&data, size_t &size) { data = reinterpret_cast(buffer.ptr); size = buffer.size; @@ -134,6 +166,119 @@ bool getNetwork(const ethosu_core_network_buffer &buffer, void *&data, size_t &s return true; } } + +bool getShapeSize(const flatbuffers::Vector *shape, size_t &size) { + size = 1; + + if (shape == nullptr) { + printf("Warning: nullptr shape size.\n"); + return true; + } + + if (shape->Length() == 0) { + printf("Warning: shape zero length.\n"); + return true; + } + + for (auto it = shape->begin(); it != shape->end(); ++it) { + size *= *it; + } + + return false; +} + +bool getTensorTypeSize(const enum tflite::TensorType type, size_t &size) { + switch (type) { + case tflite::TensorType::TensorType_UINT8: + case tflite::TensorType::TensorType_INT8: + size = 1; + break; + case tflite::TensorType::TensorType_INT16: + size = 2; + break; + case tflite::TensorType::TensorType_INT32: + case tflite::TensorType::TensorType_FLOAT32: + size = 4; + break; + default: + printf("Warning: Unsupported tensor type\n"); + return true; + } + + return false; +} + +template +bool getSubGraphDims(const tflite::SubGraph *subgraph, const flatbuffers::Vector *tensorMap, T &dims) { + if (subgraph == nullptr || tensorMap == nullptr) { + printf("Warning: nullptr subgraph or tensormap.\n"); + return true; + } + + if ((dims.capacity() - dims.size()) < tensorMap->size()) { + printf("Warning: tensormap size is larger than dimension capacity.\n"); + return true; + } + + for (auto index = tensorMap->begin(); index != tensorMap->end(); ++index) { + auto tensor = subgraph->tensors()->Get(*index); + size_t size; + size_t tensorSize; + + bool failed = getShapeSize(tensor->shape(), size); + if (failed) { + return true; + } + + failed = getTensorTypeSize(tensor->type(), tensorSize); + if (failed) { + return true; + } + + size *= tensorSize; + + if (size > 0) { + dims.push_back(size); + } + } + + return false; +} + +template +bool parseModel(const ethosu_core_network_buffer &buffer, char (&description)[S], T &&ifmDims, U &&ofmDims) { + void *data; + size_t size; + bool failed = getNetwork(buffer, data, size); + if (failed) { + return true; + } + + // Create model handle + const tflite::Model *model = tflite::GetModel(reinterpret_cast(data)); + if (model->subgraphs() == nullptr) { + printf("Warning: nullptr subgraph\n"); + return true; + } + + strncpy(description, model->description()->c_str(), sizeof(description)); + + // Get input dimensions for first subgraph + auto *subgraph = *model->subgraphs()->begin(); + failed = getSubGraphDims(subgraph, subgraph->inputs(), ifmDims); + if (failed) { + return true; + } + + // Get output dimensions for last subgraph + subgraph = *model->subgraphs()->rbegin(); + failed = getSubGraphDims(subgraph, subgraph->outputs(), ofmDims); + if (failed) { + return true; + } + + return false; +} }; // namespace IncomingMessageHandler::IncomingMessageHandler(ethosu_core_queue &_messageQueue, @@ -281,6 +426,32 @@ bool IncomingMessageHandler::handleMessage() { xQueueSend(inferenceQueue, &inference, portMAX_DELAY); break; } + case ETHOSU_CORE_MSG_NETWORK_INFO_REQ: { + ethosu_core_network_info_req req; + + if (!messageQueue.read(req)) { + queueErrorAndResetQueue(ETHOSU_CORE_MSG_ERR_INVALID_PAYLOAD, "NetworkInfoReq. Failed to read payload"); + break; + } + + printf("Msg: NetworkInfoReq. user_arg=0x%" PRIx64 "\n", req.user_arg); + + OutputMessage message(ETHOSU_CORE_MSG_NETWORK_INFO_RSP); + ethosu_core_network_info_rsp &rsp = message.data.networkInfo; + rsp.user_arg = req.user_arg; + rsp.ifm_count = 0; + rsp.ofm_count = 0; + + bool failed = parseModel(req.network, + rsp.desc, + makeArray(rsp.ifm_size, rsp.ifm_count, ETHOSU_CORE_BUFFER_MAX), + makeArray(rsp.ofm_size, rsp.ofm_count, ETHOSU_CORE_BUFFER_MAX)); + + rsp.status = failed ? ETHOSU_CORE_STATUS_ERROR : ETHOSU_CORE_STATUS_OK; + + xQueueSend(outputQueue, &message, portMAX_DELAY); + break; + } default: { char errMsg[128]; @@ -428,6 +599,9 @@ void OutgoingMessageHandler::run() { case ETHOSU_CORE_MSG_ERR: sendErrorRsp(message.data.error); break; + case ETHOSU_CORE_MSG_NETWORK_INFO_RSP: + sendNetworkInfoRsp(message.data.networkInfo); + break; default: printf("Dropping unknown outcome of type %d\n", message.type); break; @@ -476,6 +650,14 @@ void OutgoingMessageHandler::sendInferenceRsp(ethosu_core_inference_rsp &inferen } } +void OutgoingMessageHandler::sendNetworkInfoRsp(EthosU::ethosu_core_network_info_rsp &networkInfo) { + if (!messageQueue.write(ETHOSU_CORE_MSG_NETWORK_INFO_RSP, networkInfo)) { + printf("ERROR: Msg: Failed to write network info response. No mailbox message sent\n"); + } else { + mailbox.sendMessage(); + } +} + void OutgoingMessageHandler::sendErrorRsp(ethosu_core_msg_err &error) { printf("ERROR: Msg: \"%s\"\n", error.msg); diff --git a/applications/message_handler/message_handler.hpp b/applications/message_handler/message_handler.hpp index ee063de..90b1cd2 100644 --- a/applications/message_handler/message_handler.hpp +++ b/applications/message_handler/message_handler.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022 Arm Limited. All rights reserved. + * Copyright (c) 2020-2022 Arm Limited. * * SPDX-License-Identifier: Apache-2.0 * @@ -81,6 +81,7 @@ struct OutputMessage { EthosU::ethosu_core_msg_type type; union { EthosU::ethosu_core_inference_rsp inference; + EthosU::ethosu_core_network_info_rsp networkInfo; EthosU::ethosu_core_msg_err error; uint64_t userArg; } data; @@ -99,6 +100,7 @@ private: void sendVersionRsp(); void sendCapabilitiesRsp(uint64_t userArg); void sendInferenceRsp(EthosU::ethosu_core_inference_rsp &inference); + void sendNetworkInfoRsp(EthosU::ethosu_core_network_info_rsp &networkInfo); void readCapabilties(EthosU::ethosu_core_msg_capabilities_rsp &rsp); MessageQueue::QueueImpl messageQueue; diff --git a/applications/message_handler/message_queue.cpp b/applications/message_handler/message_queue.cpp index e896349..c3890fe 100644 --- a/applications/message_handler/message_queue.cpp +++ b/applications/message_handler/message_queue.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021 Arm Limited. All rights reserved. + * Copyright (c) 2020-2022 Arm Limited. * * SPDX-License-Identifier: Apache-2.0 * diff --git a/applications/message_handler/message_queue.hpp b/applications/message_handler/message_queue.hpp index 7c59e75..4140c62 100644 --- a/applications/message_handler/message_queue.hpp +++ b/applications/message_handler/message_queue.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021 Arm Limited. All rights reserved. + * Copyright (c) 2020-2022 Arm Limited. * * SPDX-License-Identifier: Apache-2.0 * diff --git a/applications/message_handler/model_template.hpp b/applications/message_handler/model_template.hpp index 06636b2..353d7d3 100644 --- a/applications/message_handler/model_template.hpp +++ b/applications/message_handler/model_template.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022 Arm Limited. All rights reserved. + * Copyright (c) 2022 Arm Limited. * * SPDX-License-Identifier: Apache-2.0 * -- cgit v1.2.1