diff options
Diffstat (limited to 'driver_library/src/ethosu.cpp')
-rw-r--r-- | driver_library/src/ethosu.cpp | 151 |
1 files changed, 133 insertions, 18 deletions
diff --git a/driver_library/src/ethosu.cpp b/driver_library/src/ethosu.cpp index 6b2b3b1..39e39f0 100644 --- a/driver_library/src/ethosu.cpp +++ b/driver_library/src/ethosu.cpp @@ -16,6 +16,8 @@ * limitations under the License. */ +#include "autogen/tflite_schema.hpp" + #include <ethosu.hpp> #include <uapi/ethosu.h> @@ -43,11 +45,50 @@ int eioctl(int fd, unsigned long cmd, void *data = nullptr) return ret; } + +/**************************************************************************** + * TFL micro helpers + ****************************************************************************/ + +size_t getShapeSize(const flatbuffers::Vector<int32_t> *shape) +{ + size_t size = 1; + + for (auto it = shape->begin(); it != shape->end(); ++it) + { + size *= *it; + } + + return size; +} + +vector<size_t> getSubGraphDims(const tflite::SubGraph *subgraph, const flatbuffers::Vector<int32_t> *tensorMap) +{ + vector<size_t> dims; + + for (auto index = tensorMap->begin(); index != tensorMap->end(); ++index) + { + auto tensor = subgraph->tensors()->Get(*index); + size_t size = getShapeSize(tensor->shape()); + + if (size > 0) + { + dims.push_back(size); + } + } + + return dims; +} + } namespace EthosU { +/**************************************************************************** + * Exception + ****************************************************************************/ + Exception::Exception(const char *msg) : msg(msg) {} @@ -60,6 +101,10 @@ const char *Exception::what() const throw() return msg.c_str(); } +/**************************************************************************** + * Device + ****************************************************************************/ + Device::Device(const char *device) { fd = open(device, O_RDWR | O_NONBLOCK); @@ -79,6 +124,10 @@ int Device::ioctl(unsigned long cmd, void *data) return eioctl(fd, cmd, data); } +/**************************************************************************** + * Buffer + ****************************************************************************/ + Buffer::Buffer(Device &device, const size_t capacity) : fd(-1), dataPtr(nullptr), @@ -121,7 +170,6 @@ void Buffer::resize(size_t size, size_t offset) ethosu_uapi_buffer uapi; uapi.offset = offset; uapi.size = size; - eioctl(fd, ETHOSU_IOCTL_BUFFER_SET, static_cast<void *>(&uapi)); } @@ -144,15 +192,29 @@ int Buffer::getFd() const return fd; } +/**************************************************************************** + * Network + ****************************************************************************/ + Network::Network(Device &device, shared_ptr<Buffer> &buffer) : fd(-1), buffer(buffer) { + // Create buffer handle ethosu_uapi_network_create uapi; - uapi.fd = buffer->getFd(); - fd = device.ioctl(ETHOSU_IOCTL_NETWORK_CREATE, static_cast<void *>(&uapi)); + + // Create model handle + const tflite::Model *model = tflite::GetModel(reinterpret_cast<void *>(buffer->data())); + + // Get input dimensions for first subgraph + auto *subgraph = *model->subgraphs()->begin(); + ifmDims = getSubGraphDims(subgraph, subgraph->inputs()); + + // Get output dimensions for last subgraph + subgraph = *model->subgraphs()->rbegin(); + ofmDims = getSubGraphDims(subgraph, subgraph->outputs()); } Network::~Network() @@ -165,30 +227,83 @@ int Network::ioctl(unsigned long cmd, void *data) return eioctl(fd, cmd, data); } -std::shared_ptr<Buffer> Network::getBuffer() +shared_ptr<Buffer> Network::getBuffer() { return buffer; } -Inference::Inference(std::shared_ptr<Network> &network, std::shared_ptr<Buffer> &ifmBuffer, std::shared_ptr<Buffer> &ofmBuffer) : - fd(-1), - network(network), - ifmBuffer(ifmBuffer), - ofmBuffer(ofmBuffer) +const std::vector<size_t> &Network::getIfmDims() const { - ethosu_uapi_inference_create uapi; + return ifmDims; +} - uapi.ifm_fd = ifmBuffer->getFd(); - uapi.ofm_fd = ofmBuffer->getFd(); +size_t Network::getIfmSize() const +{ + size_t size = 0; - fd = network->ioctl(ETHOSU_IOCTL_INFERENCE_CREATE, static_cast<void *>(&uapi)); + for (auto s: ifmDims) + { + size += s; + } + + return size; +} + +const std::vector<size_t> &Network::getOfmDims() const +{ + return ofmDims; +} + +size_t Network::getOfmSize() const +{ + size_t size = 0; + + for (auto s: ofmDims) + { + size += s; + } + + return size; } +/**************************************************************************** + * Inference + ****************************************************************************/ + Inference::~Inference() { close(fd); } +void Inference::create() +{ + ethosu_uapi_inference_create uapi; + + if (ifmBuffers.size() > ETHOSU_FD_MAX) + { + throw Exception("IFM buffer overflow"); + } + + if (ofmBuffers.size() > ETHOSU_FD_MAX) + { + throw Exception("OFM buffer overflow"); + } + + uapi.ifm_count = 0; + for (auto it: ifmBuffers) + { + uapi.ifm_fd[uapi.ifm_count++] = it->getFd(); + } + + uapi.ofm_count = 0; + for (auto it: ofmBuffers) + { + uapi.ofm_fd[uapi.ofm_count++] = it->getFd(); + } + + fd = network->ioctl(ETHOSU_IOCTL_INFERENCE_CREATE, static_cast<void *>(&uapi)); +} + void Inference::wait(int timeoutSec) { pollfd pfd; @@ -214,19 +329,19 @@ int Inference::getFd() return fd; } -std::shared_ptr<Network> Inference::getNetwork() +shared_ptr<Network> Inference::getNetwork() { return network; } -std::shared_ptr<Buffer> Inference::getIfmBuffer() +vector<shared_ptr<Buffer>> &Inference::getIfmBuffers() { - return ifmBuffer; + return ifmBuffers; } -std::shared_ptr<Buffer> Inference::getOfmBuffer() +vector<shared_ptr<Buffer>> &Inference::getOfmBuffers() { - return ofmBuffer; + return ofmBuffers; } } |