diff options
Diffstat (limited to 'driver_library/src/ethosu.cpp')
-rw-r--r-- | driver_library/src/ethosu.cpp | 71 |
1 files changed, 55 insertions, 16 deletions
diff --git a/driver_library/src/ethosu.cpp b/driver_library/src/ethosu.cpp index 32d179a..16d2db0 100644 --- a/driver_library/src/ethosu.cpp +++ b/driver_library/src/ethosu.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021 Arm Limited. All rights reserved. + * Copyright (c) 2020-2022 Arm Limited. All rights reserved. * * SPDX-License-Identifier: Apache-2.0 * @@ -23,6 +23,7 @@ #include <algorithm> #include <exception> +#include <fstream> #include <iostream> #include <fcntl.h> @@ -292,25 +293,46 @@ int Buffer::getFd() const { Network::Network(const Device &device, shared_ptr<Buffer> &buffer) : fd(-1), buffer(buffer) { // Create buffer handle ethosu_uapi_network_create uapi; - uapi.fd = buffer->getFd(); - fd = device.ioctl(ETHOSU_IOCTL_NETWORK_CREATE, static_cast<void *>(&uapi)); + uapi.type = ETHOSU_UAPI_NETWORK_BUFFER; + uapi.fd = buffer->getFd(); + fd = device.ioctl(ETHOSU_IOCTL_NETWORK_CREATE, static_cast<void *>(&uapi)); - // Create model handle - const tflite::Model *model = tflite::GetModel(reinterpret_cast<void *>(buffer->data())); - - if (model->subgraphs() == nullptr) { - try { - eclose(fd); - } catch (...) { std::throw_with_nested(EthosU::Exception("Failed to get subgraphs: nullptr")); } + try { + parseModel(buffer->data()); + } catch (...) { + eclose(fd); + throw; } +} - // Get input dimensions for first subgraph - auto *subgraph = *model->subgraphs()->begin(); - ifmDims = getSubGraphDims(subgraph, subgraph->inputs()); +Network::Network(const Device &device, const string &model, const unsigned index) : fd(-1) { + // Create buffer handle + ethosu_uapi_network_create uapi; + uapi.type = ETHOSU_UAPI_NETWORK_INDEX; + uapi.index = index; + fd = device.ioctl(ETHOSU_IOCTL_NETWORK_CREATE, static_cast<void *>(&uapi)); - // Get output dimensions for last subgraph - subgraph = *model->subgraphs()->rbegin(); - ofmDims = getSubGraphDims(subgraph, subgraph->outputs()); + try { + // Open file + ifstream ifs(model, std::ios::binary); + if (!ifs.is_open()) { + throw Exception("Failed to open model file."); + } + + // Get file size + ifs.seekg(0, ios_base::end); + size_t size = ifs.tellg(); + ifs.seekg(0, ios_base::beg); + + // Read data into buffer + vector<char> buffer(size); + ifs.read(buffer.data(), size); + + parseModel(buffer.data()); + } catch (...) { + eclose(fd); + throw; + } } Network::~Network() { @@ -353,6 +375,23 @@ size_t Network::getOfmSize() const { return size; } +void Network::parseModel(const char *data) { + // Create model handle + const tflite::Model *model = tflite::GetModel(reinterpret_cast<const void *>(data)); + + if (model->subgraphs() == nullptr) { + EthosU::Exception("Failed to get subgraphs: nullptr"); + } + + // Get input dimensions for first subgraph + auto *subgraph = *model->subgraphs()->begin(); + ifmDims = getSubGraphDims(subgraph, subgraph->inputs()); + + // Get output dimensions for last subgraph + subgraph = *model->subgraphs()->rbegin(); + ofmDims = getSubGraphDims(subgraph, subgraph->outputs()); +} + /**************************************************************************** * Inference ****************************************************************************/ |