diff options
author | Kristofer Jonsson <kristofer.jonsson@arm.com> | 2022-03-10 11:08:39 +0100 |
---|---|---|
committer | Kristofer Jonsson <kristofer.jonsson@arm.com> | 2022-03-29 15:12:35 +0200 |
commit | ac535f0387647b114826e921d23e68787f8a572b (patch) | |
tree | 38ae30cadc45fffe81bee654635c24b7b8e25cf6 /applications/message_handler/message_handler.cpp | |
parent | 585ce694dbbebfe5ba737fe94888343cb8976ac3 (diff) | |
download | ethos-u-core-platform-ac535f0387647b114826e921d23e68787f8a572b.tar.gz |
Network info
Add message for fetching meta data about built in network models.
Change-Id: I757094c20848d4cb018db68b0455297bb03be463
Diffstat (limited to 'applications/message_handler/message_handler.cpp')
-rw-r--r-- | applications/message_handler/message_handler.cpp | 184 |
1 files changed, 183 insertions, 1 deletions
diff --git a/applications/message_handler/message_handler.cpp b/applications/message_handler/message_handler.cpp index e530712..4b77389 100644 --- a/applications/message_handler/message_handler.cpp +++ b/applications/message_handler/message_handler.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022 Arm Limited. All rights reserved. + * Copyright (c) 2020-2022 Arm Limited. * * SPDX-License-Identifier: Apache-2.0 * @@ -19,6 +19,7 @@ #include "message_handler.hpp" #include "cmsis_compiler.h" +#include "tensorflow/lite/schema/schema_generated.h" #ifdef ETHOSU #include <ethosu_driver.h> @@ -31,6 +32,7 @@ #include <cstring> #include <inttypes.h> +#include <vector> #define XSTRINGIFY(src) #src #define STRINGIFY(src) XSTRINGIFY(src) @@ -79,6 +81,36 @@ namespace MessageHandler { ****************************************************************************/ namespace { + +template <typename T, typename U> +class Array { +public: + Array() = delete; + Array(T *const data, U &size, size_t capacity) : _data{data}, _size{size}, _capacity{capacity} {} + + auto size() const { + return _size; + } + + auto capacity() const { + return _capacity; + } + + void push_back(const T &data) { + _data[_size++] = data; + } + +private: + T *const _data; + U &_size; + const size_t _capacity{}; +}; + +template <typename T, typename U> +Array<T, U> makeArray(T *const data, U &size, size_t capacity) { + return Array<T, U>{data, size, capacity}; +} + bool getNetwork(const ethosu_core_buffer &buffer, void *&data, size_t &size) { data = reinterpret_cast<void *>(buffer.ptr); size = buffer.size; @@ -134,6 +166,119 @@ bool getNetwork(const ethosu_core_network_buffer &buffer, void *&data, size_t &s return true; } } + +bool getShapeSize(const flatbuffers::Vector<int32_t> *shape, size_t &size) { + size = 1; + + if (shape == nullptr) { + printf("Warning: nullptr shape size.\n"); + return true; + } + + if (shape->Length() == 0) { + printf("Warning: shape zero length.\n"); + return true; + } + + for (auto it = shape->begin(); it != shape->end(); ++it) { + size *= *it; + } + + return false; +} + +bool getTensorTypeSize(const enum tflite::TensorType type, size_t &size) { + switch (type) { + case tflite::TensorType::TensorType_UINT8: + case tflite::TensorType::TensorType_INT8: + size = 1; + break; + case tflite::TensorType::TensorType_INT16: + size = 2; + break; + case tflite::TensorType::TensorType_INT32: + case tflite::TensorType::TensorType_FLOAT32: + size = 4; + break; + default: + printf("Warning: Unsupported tensor type\n"); + return true; + } + + return false; +} + +template <typename T> +bool getSubGraphDims(const tflite::SubGraph *subgraph, const flatbuffers::Vector<int32_t> *tensorMap, T &dims) { + if (subgraph == nullptr || tensorMap == nullptr) { + printf("Warning: nullptr subgraph or tensormap.\n"); + return true; + } + + if ((dims.capacity() - dims.size()) < tensorMap->size()) { + printf("Warning: tensormap size is larger than dimension capacity.\n"); + return true; + } + + for (auto index = tensorMap->begin(); index != tensorMap->end(); ++index) { + auto tensor = subgraph->tensors()->Get(*index); + size_t size; + size_t tensorSize; + + bool failed = getShapeSize(tensor->shape(), size); + if (failed) { + return true; + } + + failed = getTensorTypeSize(tensor->type(), tensorSize); + if (failed) { + return true; + } + + size *= tensorSize; + + if (size > 0) { + dims.push_back(size); + } + } + + return false; +} + +template <typename T, typename U, size_t S> +bool parseModel(const ethosu_core_network_buffer &buffer, char (&description)[S], T &&ifmDims, U &&ofmDims) { + void *data; + size_t size; + bool failed = getNetwork(buffer, data, size); + if (failed) { + return true; + } + + // Create model handle + const tflite::Model *model = tflite::GetModel(reinterpret_cast<const void *>(data)); + if (model->subgraphs() == nullptr) { + printf("Warning: nullptr subgraph\n"); + return true; + } + + strncpy(description, model->description()->c_str(), sizeof(description)); + + // Get input dimensions for first subgraph + auto *subgraph = *model->subgraphs()->begin(); + failed = getSubGraphDims(subgraph, subgraph->inputs(), ifmDims); + if (failed) { + return true; + } + + // Get output dimensions for last subgraph + subgraph = *model->subgraphs()->rbegin(); + failed = getSubGraphDims(subgraph, subgraph->outputs(), ofmDims); + if (failed) { + return true; + } + + return false; +} }; // namespace IncomingMessageHandler::IncomingMessageHandler(ethosu_core_queue &_messageQueue, @@ -281,6 +426,32 @@ bool IncomingMessageHandler::handleMessage() { xQueueSend(inferenceQueue, &inference, portMAX_DELAY); break; } + case ETHOSU_CORE_MSG_NETWORK_INFO_REQ: { + ethosu_core_network_info_req req; + + if (!messageQueue.read(req)) { + queueErrorAndResetQueue(ETHOSU_CORE_MSG_ERR_INVALID_PAYLOAD, "NetworkInfoReq. Failed to read payload"); + break; + } + + printf("Msg: NetworkInfoReq. user_arg=0x%" PRIx64 "\n", req.user_arg); + + OutputMessage message(ETHOSU_CORE_MSG_NETWORK_INFO_RSP); + ethosu_core_network_info_rsp &rsp = message.data.networkInfo; + rsp.user_arg = req.user_arg; + rsp.ifm_count = 0; + rsp.ofm_count = 0; + + bool failed = parseModel(req.network, + rsp.desc, + makeArray(rsp.ifm_size, rsp.ifm_count, ETHOSU_CORE_BUFFER_MAX), + makeArray(rsp.ofm_size, rsp.ofm_count, ETHOSU_CORE_BUFFER_MAX)); + + rsp.status = failed ? ETHOSU_CORE_STATUS_ERROR : ETHOSU_CORE_STATUS_OK; + + xQueueSend(outputQueue, &message, portMAX_DELAY); + break; + } default: { char errMsg[128]; @@ -428,6 +599,9 @@ void OutgoingMessageHandler::run() { case ETHOSU_CORE_MSG_ERR: sendErrorRsp(message.data.error); break; + case ETHOSU_CORE_MSG_NETWORK_INFO_RSP: + sendNetworkInfoRsp(message.data.networkInfo); + break; default: printf("Dropping unknown outcome of type %d\n", message.type); break; @@ -476,6 +650,14 @@ void OutgoingMessageHandler::sendInferenceRsp(ethosu_core_inference_rsp &inferen } } +void OutgoingMessageHandler::sendNetworkInfoRsp(EthosU::ethosu_core_network_info_rsp &networkInfo) { + if (!messageQueue.write(ETHOSU_CORE_MSG_NETWORK_INFO_RSP, networkInfo)) { + printf("ERROR: Msg: Failed to write network info response. No mailbox message sent\n"); + } else { + mailbox.sendMessage(); + } +} + void OutgoingMessageHandler::sendErrorRsp(ethosu_core_msg_err &error) { printf("ERROR: Msg: \"%s\"\n", error.msg); |