aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKristofer Jonsson <kristofer.jonsson@arm.com>2022-03-10 11:08:39 +0100
committerKristofer Jonsson <kristofer.jonsson@arm.com>2022-03-29 15:12:35 +0200
commitac535f0387647b114826e921d23e68787f8a572b (patch)
tree38ae30cadc45fffe81bee654635c24b7b8e25cf6
parent585ce694dbbebfe5ba737fe94888343cb8976ac3 (diff)
downloadethos-u-core-platform-ac535f0387647b114826e921d23e68787f8a572b.tar.gz
Network info
Add message for fetching meta data about built in network models. Change-Id: I757094c20848d4cb018db68b0455297bb03be463
-rw-r--r--applications/message_handler/CMakeLists.txt2
-rw-r--r--applications/message_handler/main.cpp4
-rw-r--r--applications/message_handler/message_handler.cpp184
-rw-r--r--applications/message_handler/message_handler.hpp4
-rw-r--r--applications/message_handler/message_queue.cpp2
-rw-r--r--applications/message_handler/message_queue.hpp2
-rw-r--r--applications/message_handler/model_template.hpp2
7 files changed, 192 insertions, 8 deletions
diff --git a/applications/message_handler/CMakeLists.txt b/applications/message_handler/CMakeLists.txt
index 72d930f..27a4815 100644
--- a/applications/message_handler/CMakeLists.txt
+++ b/applications/message_handler/CMakeLists.txt
@@ -1,5 +1,5 @@
#
-# Copyright (c) 2020-2022 Arm Limited. All rights reserved.
+# Copyright (c) 2020-2022 Arm Limited.
#
# SPDX-License-Identifier: Apache-2.0
#
diff --git a/applications/message_handler/main.cpp b/applications/message_handler/main.cpp
index 9b36f84..8a36325 100644
--- a/applications/message_handler/main.cpp
+++ b/applications/message_handler/main.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2019-2022 Arm Limited.
*
* SPDX-License-Identifier: Apache-2.0
*
@@ -182,7 +182,7 @@ int main() {
outputQueue = xQueueCreate(10, sizeof(OutputMessage));
// Task for handling incoming messages from the remote host
- ret = xTaskCreate(inputMessageTask, "inputMessageTask", 512, nullptr, 2, nullptr);
+ ret = xTaskCreate(inputMessageTask, "inputMessageTask", 1024, nullptr, 2, nullptr);
if (ret != pdPASS) {
printf("Failed to create 'inputMessageTask'\n");
return ret;
diff --git a/applications/message_handler/message_handler.cpp b/applications/message_handler/message_handler.cpp
index e530712..4b77389 100644
--- a/applications/message_handler/message_handler.cpp
+++ b/applications/message_handler/message_handler.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2020-2022 Arm Limited. All rights reserved.
+ * Copyright (c) 2020-2022 Arm Limited.
*
* SPDX-License-Identifier: Apache-2.0
*
@@ -19,6 +19,7 @@
#include "message_handler.hpp"
#include "cmsis_compiler.h"
+#include "tensorflow/lite/schema/schema_generated.h"
#ifdef ETHOSU
#include <ethosu_driver.h>
@@ -31,6 +32,7 @@
#include <cstring>
#include <inttypes.h>
+#include <vector>
#define XSTRINGIFY(src) #src
#define STRINGIFY(src) XSTRINGIFY(src)
@@ -79,6 +81,36 @@ namespace MessageHandler {
****************************************************************************/
namespace {
+
+template <typename T, typename U>
+class Array {
+public:
+ Array() = delete;
+ Array(T *const data, U &size, size_t capacity) : _data{data}, _size{size}, _capacity{capacity} {}
+
+ auto size() const {
+ return _size;
+ }
+
+ auto capacity() const {
+ return _capacity;
+ }
+
+ void push_back(const T &data) {
+ _data[_size++] = data;
+ }
+
+private:
+ T *const _data;
+ U &_size;
+ const size_t _capacity{};
+};
+
+template <typename T, typename U>
+Array<T, U> makeArray(T *const data, U &size, size_t capacity) {
+ return Array<T, U>{data, size, capacity};
+}
+
bool getNetwork(const ethosu_core_buffer &buffer, void *&data, size_t &size) {
data = reinterpret_cast<void *>(buffer.ptr);
size = buffer.size;
@@ -134,6 +166,119 @@ bool getNetwork(const ethosu_core_network_buffer &buffer, void *&data, size_t &s
return true;
}
}
+
+bool getShapeSize(const flatbuffers::Vector<int32_t> *shape, size_t &size) {
+ size = 1;
+
+ if (shape == nullptr) {
+ printf("Warning: nullptr shape size.\n");
+ return true;
+ }
+
+ if (shape->Length() == 0) {
+ printf("Warning: shape zero length.\n");
+ return true;
+ }
+
+ for (auto it = shape->begin(); it != shape->end(); ++it) {
+ size *= *it;
+ }
+
+ return false;
+}
+
+bool getTensorTypeSize(const enum tflite::TensorType type, size_t &size) {
+ switch (type) {
+ case tflite::TensorType::TensorType_UINT8:
+ case tflite::TensorType::TensorType_INT8:
+ size = 1;
+ break;
+ case tflite::TensorType::TensorType_INT16:
+ size = 2;
+ break;
+ case tflite::TensorType::TensorType_INT32:
+ case tflite::TensorType::TensorType_FLOAT32:
+ size = 4;
+ break;
+ default:
+ printf("Warning: Unsupported tensor type\n");
+ return true;
+ }
+
+ return false;
+}
+
+template <typename T>
+bool getSubGraphDims(const tflite::SubGraph *subgraph, const flatbuffers::Vector<int32_t> *tensorMap, T &dims) {
+ if (subgraph == nullptr || tensorMap == nullptr) {
+ printf("Warning: nullptr subgraph or tensormap.\n");
+ return true;
+ }
+
+ if ((dims.capacity() - dims.size()) < tensorMap->size()) {
+ printf("Warning: tensormap size is larger than dimension capacity.\n");
+ return true;
+ }
+
+ for (auto index = tensorMap->begin(); index != tensorMap->end(); ++index) {
+ auto tensor = subgraph->tensors()->Get(*index);
+ size_t size;
+ size_t tensorSize;
+
+ bool failed = getShapeSize(tensor->shape(), size);
+ if (failed) {
+ return true;
+ }
+
+ failed = getTensorTypeSize(tensor->type(), tensorSize);
+ if (failed) {
+ return true;
+ }
+
+ size *= tensorSize;
+
+ if (size > 0) {
+ dims.push_back(size);
+ }
+ }
+
+ return false;
+}
+
+template <typename T, typename U, size_t S>
+bool parseModel(const ethosu_core_network_buffer &buffer, char (&description)[S], T &&ifmDims, U &&ofmDims) {
+ void *data;
+ size_t size;
+ bool failed = getNetwork(buffer, data, size);
+ if (failed) {
+ return true;
+ }
+
+ // Create model handle
+ const tflite::Model *model = tflite::GetModel(reinterpret_cast<const void *>(data));
+ if (model->subgraphs() == nullptr) {
+ printf("Warning: nullptr subgraph\n");
+ return true;
+ }
+
+ strncpy(description, model->description()->c_str(), sizeof(description));
+
+ // Get input dimensions for first subgraph
+ auto *subgraph = *model->subgraphs()->begin();
+ failed = getSubGraphDims(subgraph, subgraph->inputs(), ifmDims);
+ if (failed) {
+ return true;
+ }
+
+ // Get output dimensions for last subgraph
+ subgraph = *model->subgraphs()->rbegin();
+ failed = getSubGraphDims(subgraph, subgraph->outputs(), ofmDims);
+ if (failed) {
+ return true;
+ }
+
+ return false;
+}
}; // namespace
IncomingMessageHandler::IncomingMessageHandler(ethosu_core_queue &_messageQueue,
@@ -281,6 +426,32 @@ bool IncomingMessageHandler::handleMessage() {
xQueueSend(inferenceQueue, &inference, portMAX_DELAY);
break;
}
+ case ETHOSU_CORE_MSG_NETWORK_INFO_REQ: {
+ ethosu_core_network_info_req req;
+
+ if (!messageQueue.read(req)) {
+ queueErrorAndResetQueue(ETHOSU_CORE_MSG_ERR_INVALID_PAYLOAD, "NetworkInfoReq. Failed to read payload");
+ break;
+ }
+
+ printf("Msg: NetworkInfoReq. user_arg=0x%" PRIx64 "\n", req.user_arg);
+
+ OutputMessage message(ETHOSU_CORE_MSG_NETWORK_INFO_RSP);
+ ethosu_core_network_info_rsp &rsp = message.data.networkInfo;
+ rsp.user_arg = req.user_arg;
+ rsp.ifm_count = 0;
+ rsp.ofm_count = 0;
+
+ bool failed = parseModel(req.network,
+ rsp.desc,
+ makeArray(rsp.ifm_size, rsp.ifm_count, ETHOSU_CORE_BUFFER_MAX),
+ makeArray(rsp.ofm_size, rsp.ofm_count, ETHOSU_CORE_BUFFER_MAX));
+
+ rsp.status = failed ? ETHOSU_CORE_STATUS_ERROR : ETHOSU_CORE_STATUS_OK;
+
+ xQueueSend(outputQueue, &message, portMAX_DELAY);
+ break;
+ }
default: {
char errMsg[128];
@@ -428,6 +599,9 @@ void OutgoingMessageHandler::run() {
case ETHOSU_CORE_MSG_ERR:
sendErrorRsp(message.data.error);
break;
+ case ETHOSU_CORE_MSG_NETWORK_INFO_RSP:
+ sendNetworkInfoRsp(message.data.networkInfo);
+ break;
default:
printf("Dropping unknown outcome of type %d\n", message.type);
break;
@@ -476,6 +650,14 @@ void OutgoingMessageHandler::sendInferenceRsp(ethosu_core_inference_rsp &inferen
}
}
+void OutgoingMessageHandler::sendNetworkInfoRsp(EthosU::ethosu_core_network_info_rsp &networkInfo) {
+ if (!messageQueue.write(ETHOSU_CORE_MSG_NETWORK_INFO_RSP, networkInfo)) {
+ printf("ERROR: Msg: Failed to write network info response. No mailbox message sent\n");
+ } else {
+ mailbox.sendMessage();
+ }
+}
+
void OutgoingMessageHandler::sendErrorRsp(ethosu_core_msg_err &error) {
printf("ERROR: Msg: \"%s\"\n", error.msg);
diff --git a/applications/message_handler/message_handler.hpp b/applications/message_handler/message_handler.hpp
index ee063de..90b1cd2 100644
--- a/applications/message_handler/message_handler.hpp
+++ b/applications/message_handler/message_handler.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2020-2022 Arm Limited. All rights reserved.
+ * Copyright (c) 2020-2022 Arm Limited.
*
* SPDX-License-Identifier: Apache-2.0
*
@@ -81,6 +81,7 @@ struct OutputMessage {
EthosU::ethosu_core_msg_type type;
union {
EthosU::ethosu_core_inference_rsp inference;
+ EthosU::ethosu_core_network_info_rsp networkInfo;
EthosU::ethosu_core_msg_err error;
uint64_t userArg;
} data;
@@ -99,6 +100,7 @@ private:
void sendVersionRsp();
void sendCapabilitiesRsp(uint64_t userArg);
void sendInferenceRsp(EthosU::ethosu_core_inference_rsp &inference);
+ void sendNetworkInfoRsp(EthosU::ethosu_core_network_info_rsp &networkInfo);
void readCapabilties(EthosU::ethosu_core_msg_capabilities_rsp &rsp);
MessageQueue::QueueImpl messageQueue;
diff --git a/applications/message_handler/message_queue.cpp b/applications/message_handler/message_queue.cpp
index e896349..c3890fe 100644
--- a/applications/message_handler/message_queue.cpp
+++ b/applications/message_handler/message_queue.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2020-2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2020-2022 Arm Limited.
*
* SPDX-License-Identifier: Apache-2.0
*
diff --git a/applications/message_handler/message_queue.hpp b/applications/message_handler/message_queue.hpp
index 7c59e75..4140c62 100644
--- a/applications/message_handler/message_queue.hpp
+++ b/applications/message_handler/message_queue.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2020-2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2020-2022 Arm Limited.
*
* SPDX-License-Identifier: Apache-2.0
*
diff --git a/applications/message_handler/model_template.hpp b/applications/message_handler/model_template.hpp
index 06636b2..353d7d3 100644
--- a/applications/message_handler/model_template.hpp
+++ b/applications/message_handler/model_template.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022 Arm Limited. All rights reserved.
+ * Copyright (c) 2022 Arm Limited.
*
* SPDX-License-Identifier: Apache-2.0
*