From 65520055f422dac8d3c1a87314d98846a3bce3b5 Mon Sep 17 00:00:00 2001 From: Davide Grohmann Date: Thu, 7 Apr 2022 15:01:34 +0200 Subject: Extract inference parsing in a dedicated library Change-Id: I2753434badec7c5af2c19a2b32e5e808131ba519 --- applications/message_handler/message_handler.cpp | 155 ++--------------------- applications/message_handler/message_handler.hpp | 2 + 2 files changed, 10 insertions(+), 147 deletions(-) diff --git a/applications/message_handler/message_handler.cpp b/applications/message_handler/message_handler.cpp index 6630c0a..a9c7df7 100644 --- a/applications/message_handler/message_handler.cpp +++ b/applications/message_handler/message_handler.cpp @@ -19,7 +19,6 @@ #include "message_handler.hpp" #include "cmsis_compiler.h" -#include "tensorflow/lite/schema/schema_generated.h" #ifdef ETHOSU #include @@ -81,36 +80,6 @@ namespace MessageHandler { ****************************************************************************/ namespace { - -template -class Array { -public: - Array() = delete; - Array(T *const data, U &size, size_t capacity) : _data{data}, _size{size}, _capacity{capacity} {} - - auto size() const { - return _size; - } - - auto capacity() const { - return _capacity; - } - - void push_back(const T &data) { - _data[_size++] = data; - } - -private: - T *const _data; - U &_size; - const size_t _capacity{}; -}; - -template -Array makeArray(T *const data, U &size, size_t capacity) { - return Array{data, size, capacity}; -} - bool getNetwork(const ethosu_core_buffer &buffer, void *&data, size_t &size) { data = reinterpret_cast(buffer.ptr); size = buffer.size; @@ -167,118 +136,6 @@ bool getNetwork(const ethosu_core_network_buffer &buffer, void *&data, size_t &s } } -bool getShapeSize(const flatbuffers::Vector *shape, size_t &size) { - size = 1; - - if (shape == nullptr) { - printf("Warning: nullptr shape size.\n"); - return true; - } - - if (shape->Length() == 0) { - printf("Warning: shape zero length.\n"); - return true; - } - - for (auto it = shape->begin(); it != shape->end(); ++it) { - size *= *it; - } - - return false; -} - -bool getTensorTypeSize(const enum tflite::TensorType type, size_t &size) { - switch (type) { - case tflite::TensorType::TensorType_UINT8: - case tflite::TensorType::TensorType_INT8: - size = 1; - break; - case tflite::TensorType::TensorType_INT16: - size = 2; - break; - case tflite::TensorType::TensorType_INT32: - case tflite::TensorType::TensorType_FLOAT32: - size = 4; - break; - default: - printf("Warning: Unsupported tensor type\n"); - return true; - } - - return false; -} - -template -bool getSubGraphDims(const tflite::SubGraph *subgraph, const flatbuffers::Vector *tensorMap, T &dims) { - if (subgraph == nullptr || tensorMap == nullptr) { - printf("Warning: nullptr subgraph or tensormap.\n"); - return true; - } - - if ((dims.capacity() - dims.size()) < tensorMap->size()) { - printf("Warning: tensormap size is larger than dimension capacity.\n"); - return true; - } - - for (auto index = tensorMap->begin(); index != tensorMap->end(); ++index) { - auto tensor = subgraph->tensors()->Get(*index); - size_t size; - size_t tensorSize; - - bool failed = getShapeSize(tensor->shape(), size); - if (failed) { - return true; - } - - failed = getTensorTypeSize(tensor->type(), tensorSize); - if (failed) { - return true; - } - - size *= tensorSize; - - if (size > 0) { - dims.push_back(size); - } - } - - return false; -} - -template -bool parseModel(const ethosu_core_network_buffer &buffer, char (&description)[S], T &&ifmDims, U &&ofmDims) { - void *data; - size_t size; - bool failed = getNetwork(buffer, data, size); - if (failed) { - return true; - } - - // Create model handle - const tflite::Model *model = tflite::GetModel(reinterpret_cast(data)); - if (model->subgraphs() == nullptr) { - printf("Warning: nullptr subgraph\n"); - return true; - } - - strncpy(description, model->description()->c_str(), sizeof(description)); - - // Get input dimensions for first subgraph - auto *subgraph = *model->subgraphs()->begin(); - failed = getSubGraphDims(subgraph, subgraph->inputs(), ifmDims); - if (failed) { - return true; - } - - // Get output dimensions for last subgraph - subgraph = *model->subgraphs()->rbegin(); - failed = getSubGraphDims(subgraph, subgraph->outputs(), ofmDims); - if (failed) { - return true; - } - - return false; -} }; // namespace IncomingMessageHandler::IncomingMessageHandler(ethosu_core_queue &_messageQueue, @@ -442,10 +299,14 @@ bool IncomingMessageHandler::handleMessage() { rsp.ifm_count = 0; rsp.ofm_count = 0; - bool failed = parseModel(req.network, - rsp.desc, - makeArray(rsp.ifm_size, rsp.ifm_count, ETHOSU_CORE_BUFFER_MAX), - makeArray(rsp.ofm_size, rsp.ofm_count, ETHOSU_CORE_BUFFER_MAX)); + void *buffer; + size_t size; + getNetwork(req.network, buffer, size); + bool failed = + parser.parseModel(buffer, + rsp.desc, + InferenceProcess::makeArray(rsp.ifm_size, rsp.ifm_count, ETHOSU_CORE_BUFFER_MAX), + InferenceProcess::makeArray(rsp.ofm_size, rsp.ofm_count, ETHOSU_CORE_BUFFER_MAX)); rsp.status = failed ? ETHOSU_CORE_STATUS_ERROR : ETHOSU_CORE_STATUS_OK; diff --git a/applications/message_handler/message_handler.hpp b/applications/message_handler/message_handler.hpp index 13a3c60..b7152f7 100644 --- a/applications/message_handler/message_handler.hpp +++ b/applications/message_handler/message_handler.hpp @@ -27,6 +27,7 @@ #if defined(ETHOSU) #include #endif +#include #include #include @@ -51,6 +52,7 @@ private: MessageQueue::QueueImpl messageQueue; Mailbox::Mailbox &mailbox; + InferenceProcess::InferenceParser parser; QueueHandle_t inferenceQueue; QueueHandle_t outputQueue; SemaphoreHandle_t semaphore; -- cgit v1.2.1