aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavide Grohmann <davide.grohmann@arm.com>2022-04-07 15:01:34 +0200
committerDavide Grohmann <davide.grohmann@arm.com>2022-04-12 12:54:29 +0200
commit65520055f422dac8d3c1a87314d98846a3bce3b5 (patch)
tree69f4092c798e38ec0d21a18fabc1a347f8b83dd9
parent06ebcbcde891635b3f9c5b208b373b8d973e8ac6 (diff)
downloadethos-u-core-platform-65520055f422dac8d3c1a87314d98846a3bce3b5.tar.gz
Extract inference parsing in a dedicated library
Change-Id: I2753434badec7c5af2c19a2b32e5e808131ba519
-rw-r--r--applications/message_handler/message_handler.cpp155
-rw-r--r--applications/message_handler/message_handler.hpp2
2 files changed, 10 insertions, 147 deletions
diff --git a/applications/message_handler/message_handler.cpp b/applications/message_handler/message_handler.cpp
index 6630c0a..a9c7df7 100644
--- a/applications/message_handler/message_handler.cpp
+++ b/applications/message_handler/message_handler.cpp
@@ -19,7 +19,6 @@
#include "message_handler.hpp"
#include "cmsis_compiler.h"
-#include "tensorflow/lite/schema/schema_generated.h"
#ifdef ETHOSU
#include <ethosu_driver.h>
@@ -81,36 +80,6 @@ namespace MessageHandler {
****************************************************************************/
namespace {
-
-template <typename T, typename U>
-class Array {
-public:
- Array() = delete;
- Array(T *const data, U &size, size_t capacity) : _data{data}, _size{size}, _capacity{capacity} {}
-
- auto size() const {
- return _size;
- }
-
- auto capacity() const {
- return _capacity;
- }
-
- void push_back(const T &data) {
- _data[_size++] = data;
- }
-
-private:
- T *const _data;
- U &_size;
- const size_t _capacity{};
-};
-
-template <typename T, typename U>
-Array<T, U> makeArray(T *const data, U &size, size_t capacity) {
- return Array<T, U>{data, size, capacity};
-}
-
bool getNetwork(const ethosu_core_buffer &buffer, void *&data, size_t &size) {
data = reinterpret_cast<void *>(buffer.ptr);
size = buffer.size;
@@ -167,118 +136,6 @@ bool getNetwork(const ethosu_core_network_buffer &buffer, void *&data, size_t &s
}
}
-bool getShapeSize(const flatbuffers::Vector<int32_t> *shape, size_t &size) {
- size = 1;
-
- if (shape == nullptr) {
- printf("Warning: nullptr shape size.\n");
- return true;
- }
-
- if (shape->Length() == 0) {
- printf("Warning: shape zero length.\n");
- return true;
- }
-
- for (auto it = shape->begin(); it != shape->end(); ++it) {
- size *= *it;
- }
-
- return false;
-}
-
-bool getTensorTypeSize(const enum tflite::TensorType type, size_t &size) {
- switch (type) {
- case tflite::TensorType::TensorType_UINT8:
- case tflite::TensorType::TensorType_INT8:
- size = 1;
- break;
- case tflite::TensorType::TensorType_INT16:
- size = 2;
- break;
- case tflite::TensorType::TensorType_INT32:
- case tflite::TensorType::TensorType_FLOAT32:
- size = 4;
- break;
- default:
- printf("Warning: Unsupported tensor type\n");
- return true;
- }
-
- return false;
-}
-
-template <typename T>
-bool getSubGraphDims(const tflite::SubGraph *subgraph, const flatbuffers::Vector<int32_t> *tensorMap, T &dims) {
- if (subgraph == nullptr || tensorMap == nullptr) {
- printf("Warning: nullptr subgraph or tensormap.\n");
- return true;
- }
-
- if ((dims.capacity() - dims.size()) < tensorMap->size()) {
- printf("Warning: tensormap size is larger than dimension capacity.\n");
- return true;
- }
-
- for (auto index = tensorMap->begin(); index != tensorMap->end(); ++index) {
- auto tensor = subgraph->tensors()->Get(*index);
- size_t size;
- size_t tensorSize;
-
- bool failed = getShapeSize(tensor->shape(), size);
- if (failed) {
- return true;
- }
-
- failed = getTensorTypeSize(tensor->type(), tensorSize);
- if (failed) {
- return true;
- }
-
- size *= tensorSize;
-
- if (size > 0) {
- dims.push_back(size);
- }
- }
-
- return false;
-}
-
-template <typename T, typename U, size_t S>
-bool parseModel(const ethosu_core_network_buffer &buffer, char (&description)[S], T &&ifmDims, U &&ofmDims) {
- void *data;
- size_t size;
- bool failed = getNetwork(buffer, data, size);
- if (failed) {
- return true;
- }
-
- // Create model handle
- const tflite::Model *model = tflite::GetModel(reinterpret_cast<const void *>(data));
- if (model->subgraphs() == nullptr) {
- printf("Warning: nullptr subgraph\n");
- return true;
- }
-
- strncpy(description, model->description()->c_str(), sizeof(description));
-
- // Get input dimensions for first subgraph
- auto *subgraph = *model->subgraphs()->begin();
- failed = getSubGraphDims(subgraph, subgraph->inputs(), ifmDims);
- if (failed) {
- return true;
- }
-
- // Get output dimensions for last subgraph
- subgraph = *model->subgraphs()->rbegin();
- failed = getSubGraphDims(subgraph, subgraph->outputs(), ofmDims);
- if (failed) {
- return true;
- }
-
- return false;
-}
}; // namespace
IncomingMessageHandler::IncomingMessageHandler(ethosu_core_queue &_messageQueue,
@@ -442,10 +299,14 @@ bool IncomingMessageHandler::handleMessage() {
rsp.ifm_count = 0;
rsp.ofm_count = 0;
- bool failed = parseModel(req.network,
- rsp.desc,
- makeArray(rsp.ifm_size, rsp.ifm_count, ETHOSU_CORE_BUFFER_MAX),
- makeArray(rsp.ofm_size, rsp.ofm_count, ETHOSU_CORE_BUFFER_MAX));
+ void *buffer;
+ size_t size;
+ getNetwork(req.network, buffer, size);
+ bool failed =
+ parser.parseModel(buffer,
+ rsp.desc,
+ InferenceProcess::makeArray(rsp.ifm_size, rsp.ifm_count, ETHOSU_CORE_BUFFER_MAX),
+ InferenceProcess::makeArray(rsp.ofm_size, rsp.ofm_count, ETHOSU_CORE_BUFFER_MAX));
rsp.status = failed ? ETHOSU_CORE_STATUS_ERROR : ETHOSU_CORE_STATUS_OK;
diff --git a/applications/message_handler/message_handler.hpp b/applications/message_handler/message_handler.hpp
index 13a3c60..b7152f7 100644
--- a/applications/message_handler/message_handler.hpp
+++ b/applications/message_handler/message_handler.hpp
@@ -27,6 +27,7 @@
#if defined(ETHOSU)
#include <ethosu_driver.h>
#endif
+#include <inference_parser.hpp>
#include <inference_process.hpp>
#include <mailbox.hpp>
@@ -51,6 +52,7 @@ private:
MessageQueue::QueueImpl messageQueue;
Mailbox::Mailbox &mailbox;
+ InferenceProcess::InferenceParser parser;
QueueHandle_t inferenceQueue;
QueueHandle_t outputQueue;
SemaphoreHandle_t semaphore;