aboutsummaryrefslogtreecommitdiff
path: root/driver_library
diff options
context:
space:
mode:
authorKristofer Jonsson <kristofer.jonsson@arm.com>2022-03-08 13:25:45 +0100
committerKristofer Jonsson <kristofer.jonsson@arm.com>2022-03-10 15:20:57 +0100
commit35de9e63d9c2fe0a557637ac104d7d73382d2d4a (patch)
tree41fa348f46f7f76b00625ad3b9768c1ddae5c83b /driver_library
parent118b05990af26026a1ac2b6d5dfae32ea342a7f4 (diff)
downloadethos-u-linux-driver-stack-35de9e63d9c2fe0a557637ac104d7d73382d2d4a.tar.gz
Firmware resident model
Support referencing a network model by index that has been built into the firmware binary. Change-Id: Idd5294376ea82503dfeafe1203dcc0694d296dfe
Diffstat (limited to 'driver_library')
-rw-r--r--driver_library/include/ethosu.hpp5
-rw-r--r--driver_library/src/ethosu.cpp71
2 files changed, 59 insertions, 17 deletions
diff --git a/driver_library/include/ethosu.hpp b/driver_library/include/ethosu.hpp
index 98e6969..0738aa2 100644
--- a/driver_library/include/ethosu.hpp
+++ b/driver_library/include/ethosu.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2020-2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2020-2022 Arm Limited. All rights reserved.
*
* SPDX-License-Identifier: Apache-2.0
*
@@ -161,6 +161,7 @@ private:
class Network {
public:
Network(const Device &device, std::shared_ptr<Buffer> &buffer);
+ Network(const Device &device, const std::string &model, const unsigned index);
virtual ~Network();
int ioctl(unsigned long cmd, void *data = nullptr);
@@ -171,6 +172,8 @@ public:
size_t getOfmSize() const;
private:
+ void parseModel(const char *data);
+
int fd;
std::shared_ptr<Buffer> buffer;
std::vector<size_t> ifmDims;
diff --git a/driver_library/src/ethosu.cpp b/driver_library/src/ethosu.cpp
index 32d179a..16d2db0 100644
--- a/driver_library/src/ethosu.cpp
+++ b/driver_library/src/ethosu.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2020-2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2020-2022 Arm Limited. All rights reserved.
*
* SPDX-License-Identifier: Apache-2.0
*
@@ -23,6 +23,7 @@
#include <algorithm>
#include <exception>
+#include <fstream>
#include <iostream>
#include <fcntl.h>
@@ -292,25 +293,46 @@ int Buffer::getFd() const {
Network::Network(const Device &device, shared_ptr<Buffer> &buffer) : fd(-1), buffer(buffer) {
// Create buffer handle
ethosu_uapi_network_create uapi;
- uapi.fd = buffer->getFd();
- fd = device.ioctl(ETHOSU_IOCTL_NETWORK_CREATE, static_cast<void *>(&uapi));
+ uapi.type = ETHOSU_UAPI_NETWORK_BUFFER;
+ uapi.fd = buffer->getFd();
+ fd = device.ioctl(ETHOSU_IOCTL_NETWORK_CREATE, static_cast<void *>(&uapi));
- // Create model handle
- const tflite::Model *model = tflite::GetModel(reinterpret_cast<void *>(buffer->data()));
-
- if (model->subgraphs() == nullptr) {
- try {
- eclose(fd);
- } catch (...) { std::throw_with_nested(EthosU::Exception("Failed to get subgraphs: nullptr")); }
+ try {
+ parseModel(buffer->data());
+ } catch (...) {
+ eclose(fd);
+ throw;
}
+}
- // Get input dimensions for first subgraph
- auto *subgraph = *model->subgraphs()->begin();
- ifmDims = getSubGraphDims(subgraph, subgraph->inputs());
+Network::Network(const Device &device, const string &model, const unsigned index) : fd(-1) {
+ // Create buffer handle
+ ethosu_uapi_network_create uapi;
+ uapi.type = ETHOSU_UAPI_NETWORK_INDEX;
+ uapi.index = index;
+ fd = device.ioctl(ETHOSU_IOCTL_NETWORK_CREATE, static_cast<void *>(&uapi));
- // Get output dimensions for last subgraph
- subgraph = *model->subgraphs()->rbegin();
- ofmDims = getSubGraphDims(subgraph, subgraph->outputs());
+ try {
+ // Open file
+ ifstream ifs(model, std::ios::binary);
+ if (!ifs.is_open()) {
+ throw Exception("Failed to open model file.");
+ }
+
+ // Get file size
+ ifs.seekg(0, ios_base::end);
+ size_t size = ifs.tellg();
+ ifs.seekg(0, ios_base::beg);
+
+ // Read data into buffer
+ vector<char> buffer(size);
+ ifs.read(buffer.data(), size);
+
+ parseModel(buffer.data());
+ } catch (...) {
+ eclose(fd);
+ throw;
+ }
}
Network::~Network() {
@@ -353,6 +375,23 @@ size_t Network::getOfmSize() const {
return size;
}
+void Network::parseModel(const char *data) {
+ // Create model handle
+ const tflite::Model *model = tflite::GetModel(reinterpret_cast<const void *>(data));
+
+ if (model->subgraphs() == nullptr) {
+ EthosU::Exception("Failed to get subgraphs: nullptr");
+ }
+
+ // Get input dimensions for first subgraph
+ auto *subgraph = *model->subgraphs()->begin();
+ ifmDims = getSubGraphDims(subgraph, subgraph->inputs());
+
+ // Get output dimensions for last subgraph
+ subgraph = *model->subgraphs()->rbegin();
+ ofmDims = getSubGraphDims(subgraph, subgraph->outputs());
+}
+
/****************************************************************************
* Inference
****************************************************************************/