From 35de9e63d9c2fe0a557637ac104d7d73382d2d4a Mon Sep 17 00:00:00 2001 From: Kristofer Jonsson Date: Tue, 8 Mar 2022 13:25:45 +0100 Subject: Firmware resident model Support referencing a network model by index that has been built into the firmware binary. Change-Id: Idd5294376ea82503dfeafe1203dcc0694d296dfe --- driver_library/include/ethosu.hpp | 5 ++- driver_library/src/ethosu.cpp | 71 ++++++++++++++++++++++++++++++--------- 2 files changed, 59 insertions(+), 17 deletions(-) (limited to 'driver_library') diff --git a/driver_library/include/ethosu.hpp b/driver_library/include/ethosu.hpp index 98e6969..0738aa2 100644 --- a/driver_library/include/ethosu.hpp +++ b/driver_library/include/ethosu.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021 Arm Limited. All rights reserved. + * Copyright (c) 2020-2022 Arm Limited. All rights reserved. * * SPDX-License-Identifier: Apache-2.0 * @@ -161,6 +161,7 @@ private: class Network { public: Network(const Device &device, std::shared_ptr &buffer); + Network(const Device &device, const std::string &model, const unsigned index); virtual ~Network(); int ioctl(unsigned long cmd, void *data = nullptr); @@ -171,6 +172,8 @@ public: size_t getOfmSize() const; private: + void parseModel(const char *data); + int fd; std::shared_ptr buffer; std::vector ifmDims; diff --git a/driver_library/src/ethosu.cpp b/driver_library/src/ethosu.cpp index 32d179a..16d2db0 100644 --- a/driver_library/src/ethosu.cpp +++ b/driver_library/src/ethosu.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021 Arm Limited. All rights reserved. + * Copyright (c) 2020-2022 Arm Limited. All rights reserved. * * SPDX-License-Identifier: Apache-2.0 * @@ -23,6 +23,7 @@ #include #include +#include #include #include @@ -292,25 +293,46 @@ int Buffer::getFd() const { Network::Network(const Device &device, shared_ptr &buffer) : fd(-1), buffer(buffer) { // Create buffer handle ethosu_uapi_network_create uapi; - uapi.fd = buffer->getFd(); - fd = device.ioctl(ETHOSU_IOCTL_NETWORK_CREATE, static_cast(&uapi)); + uapi.type = ETHOSU_UAPI_NETWORK_BUFFER; + uapi.fd = buffer->getFd(); + fd = device.ioctl(ETHOSU_IOCTL_NETWORK_CREATE, static_cast(&uapi)); - // Create model handle - const tflite::Model *model = tflite::GetModel(reinterpret_cast(buffer->data())); - - if (model->subgraphs() == nullptr) { - try { - eclose(fd); - } catch (...) { std::throw_with_nested(EthosU::Exception("Failed to get subgraphs: nullptr")); } + try { + parseModel(buffer->data()); + } catch (...) { + eclose(fd); + throw; } +} - // Get input dimensions for first subgraph - auto *subgraph = *model->subgraphs()->begin(); - ifmDims = getSubGraphDims(subgraph, subgraph->inputs()); +Network::Network(const Device &device, const string &model, const unsigned index) : fd(-1) { + // Create buffer handle + ethosu_uapi_network_create uapi; + uapi.type = ETHOSU_UAPI_NETWORK_INDEX; + uapi.index = index; + fd = device.ioctl(ETHOSU_IOCTL_NETWORK_CREATE, static_cast(&uapi)); - // Get output dimensions for last subgraph - subgraph = *model->subgraphs()->rbegin(); - ofmDims = getSubGraphDims(subgraph, subgraph->outputs()); + try { + // Open file + ifstream ifs(model, std::ios::binary); + if (!ifs.is_open()) { + throw Exception("Failed to open model file."); + } + + // Get file size + ifs.seekg(0, ios_base::end); + size_t size = ifs.tellg(); + ifs.seekg(0, ios_base::beg); + + // Read data into buffer + vector buffer(size); + ifs.read(buffer.data(), size); + + parseModel(buffer.data()); + } catch (...) { + eclose(fd); + throw; + } } Network::~Network() { @@ -353,6 +375,23 @@ size_t Network::getOfmSize() const { return size; } +void Network::parseModel(const char *data) { + // Create model handle + const tflite::Model *model = tflite::GetModel(reinterpret_cast(data)); + + if (model->subgraphs() == nullptr) { + EthosU::Exception("Failed to get subgraphs: nullptr"); + } + + // Get input dimensions for first subgraph + auto *subgraph = *model->subgraphs()->begin(); + ifmDims = getSubGraphDims(subgraph, subgraph->inputs()); + + // Get output dimensions for last subgraph + subgraph = *model->subgraphs()->rbegin(); + ofmDims = getSubGraphDims(subgraph, subgraph->outputs()); +} + /**************************************************************************** * Inference ****************************************************************************/ -- cgit v1.2.1