From 80f8ddf050d594ec663b01cb800e9547c9f919d0 Mon Sep 17 00:00:00 2001 From: Davide Grohmann Date: Thu, 7 Apr 2022 14:50:50 +0200 Subject: Drop tflite and flatbuffer dependencies There is not real need for the linux_driver_stack to know about tflite and flatbuffers. A better approach is to just pass the buffer to the firmware to be processes, i.e., either parsed or executed. This solves issues when linux_driver_stack and firmware are not in sync with the same library versions. Change-Id: I9b2a12e69f37f61b1ac594433a15883fb1c67b9c --- driver_library/src/ethosu.cpp | 121 ++++++++++-------------------------------- 1 file changed, 28 insertions(+), 93 deletions(-) (limited to 'driver_library/src/ethosu.cpp') diff --git a/driver_library/src/ethosu.cpp b/driver_library/src/ethosu.cpp index f792399..01631b3 100644 --- a/driver_library/src/ethosu.cpp +++ b/driver_library/src/ethosu.cpp @@ -16,8 +16,6 @@ * limitations under the License. */ -#include "autogen/tflite_schema.hpp" - #include #include @@ -92,61 +90,6 @@ __attribute__((weak)) int emunmap(void *addr, size_t length) { } // namespace EthosU -/**************************************************************************** - * TFL micro helpers - ****************************************************************************/ -namespace { -size_t getShapeSize(const flatbuffers::Vector *shape) { - size_t size = 1; - - if (shape == nullptr) { - throw EthosU::Exception("getShapeSize(): nullptr arg"); - } - - for (auto it = shape->begin(); it != shape->end(); ++it) { - size *= *it; - } - - return size; -} - -size_t getTensorTypeSize(const enum tflite::TensorType type) { - switch (type) { - case tflite::TensorType::TensorType_UINT8: - case tflite::TensorType::TensorType_INT8: - return 1; - case tflite::TensorType::TensorType_INT16: - return 2; - case tflite::TensorType::TensorType_INT32: - case tflite::TensorType::TensorType_FLOAT32: - return 4; - default: - throw EthosU::Exception("Unsupported tensor type"); - } -} - -vector getSubGraphDims(const tflite::SubGraph *subgraph, const flatbuffers::Vector *tensorMap) { - vector dims; - - if (subgraph == nullptr || tensorMap == nullptr) { - throw EthosU::Exception("getSubGraphDims(): nullptr arg(s)"); - } - - for (auto index = tensorMap->begin(); index != tensorMap->end(); ++index) { - auto tensor = subgraph->tensors()->Get(*index); - size_t size = getShapeSize(tensor->shape()); - size *= getTensorTypeSize(tensor->type()); - - if (size > 0) { - dims.push_back(size); - } - } - - return dims; -} - -} // namespace - namespace EthosU { /**************************************************************************** @@ -247,8 +190,13 @@ Buffer::Buffer(const Device &device, const size_t capacity) : fd(-1), dataPtr(nu } Buffer::~Buffer() { - emunmap(dataPtr, dataCapacity); - eclose(fd); + try { + emunmap(dataPtr, dataCapacity); + } catch (std::exception &e) { + try { + eclose(fd); + } catch (...) { std::throw_with_nested(e); } + } } size_t Buffer::capacity() const { @@ -296,12 +244,12 @@ Network::Network(const Device &device, shared_ptr &buffer) : fd(-1), buf uapi.type = ETHOSU_UAPI_NETWORK_BUFFER; uapi.fd = buffer->getFd(); fd = device.ioctl(ETHOSU_IOCTL_NETWORK_CREATE, static_cast(&uapi)); - try { - parseModel(buffer->data()); - } catch (...) { - eclose(fd); - throw; + collectNetworkInfo(); + } catch (std::exception &e) { + try { + eclose(fd); + } catch (...) { std::throw_with_nested(e); } } } @@ -311,21 +259,25 @@ Network::Network(const Device &device, const unsigned index) : fd(-1) { uapi.type = ETHOSU_UAPI_NETWORK_INDEX; uapi.index = index; fd = device.ioctl(ETHOSU_IOCTL_NETWORK_CREATE, static_cast(&uapi)); - try { - ethosu_uapi_network_info info; - ioctl(ETHOSU_IOCTL_NETWORK_INFO, static_cast(&info)); + collectNetworkInfo(); + } catch (std::exception &e) { + try { + eclose(fd); + } catch (...) { std::throw_with_nested(e); } + } +} - for (uint32_t i = 0; i < info.ifm_count; i++) { - ifmDims.push_back(info.ifm_size[i]); - } +void Network::collectNetworkInfo() { + ethosu_uapi_network_info info; + ioctl(ETHOSU_IOCTL_NETWORK_INFO, static_cast(&info)); - for (uint32_t i = 0; i < info.ofm_count; i++) { - ofmDims.push_back(info.ofm_size[i]); - } - } catch (...) { - eclose(fd); - throw; + for (uint32_t i = 0; i < info.ifm_count; i++) { + ifmDims.push_back(info.ifm_size[i]); + } + + for (uint32_t i = 0; i < info.ofm_count; i++) { + ofmDims.push_back(info.ofm_size[i]); } } @@ -369,23 +321,6 @@ size_t Network::getOfmSize() const { return size; } -void Network::parseModel(const char *data) { - // Create model handle - const tflite::Model *model = tflite::GetModel(reinterpret_cast(data)); - - if (model->subgraphs() == nullptr) { - EthosU::Exception("Failed to get subgraphs: nullptr"); - } - - // Get input dimensions for first subgraph - auto *subgraph = *model->subgraphs()->begin(); - ifmDims = getSubGraphDims(subgraph, subgraph->inputs()); - - // Get output dimensions for last subgraph - subgraph = *model->subgraphs()->rbegin(); - ofmDims = getSubGraphDims(subgraph, subgraph->outputs()); -} - /**************************************************************************** * Inference ****************************************************************************/ -- cgit v1.2.1