aboutsummaryrefslogtreecommitdiff
path: root/driver_library/src/ethosu.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'driver_library/src/ethosu.cpp')
-rw-r--r--driver_library/src/ethosu.cpp151
1 files changed, 133 insertions, 18 deletions
diff --git a/driver_library/src/ethosu.cpp b/driver_library/src/ethosu.cpp
index 6b2b3b1..39e39f0 100644
--- a/driver_library/src/ethosu.cpp
+++ b/driver_library/src/ethosu.cpp
@@ -16,6 +16,8 @@
* limitations under the License.
*/
+#include "autogen/tflite_schema.hpp"
+
#include <ethosu.hpp>
#include <uapi/ethosu.h>
@@ -43,11 +45,50 @@ int eioctl(int fd, unsigned long cmd, void *data = nullptr)
return ret;
}
+
+/****************************************************************************
+ * TFL micro helpers
+ ****************************************************************************/
+
+size_t getShapeSize(const flatbuffers::Vector<int32_t> *shape)
+{
+ size_t size = 1;
+
+ for (auto it = shape->begin(); it != shape->end(); ++it)
+ {
+ size *= *it;
+ }
+
+ return size;
+}
+
+vector<size_t> getSubGraphDims(const tflite::SubGraph *subgraph, const flatbuffers::Vector<int32_t> *tensorMap)
+{
+ vector<size_t> dims;
+
+ for (auto index = tensorMap->begin(); index != tensorMap->end(); ++index)
+ {
+ auto tensor = subgraph->tensors()->Get(*index);
+ size_t size = getShapeSize(tensor->shape());
+
+ if (size > 0)
+ {
+ dims.push_back(size);
+ }
+ }
+
+ return dims;
+}
+
}
namespace EthosU
{
+/****************************************************************************
+ * Exception
+ ****************************************************************************/
+
Exception::Exception(const char *msg) :
msg(msg)
{}
@@ -60,6 +101,10 @@ const char *Exception::what() const throw()
return msg.c_str();
}
+/****************************************************************************
+ * Device
+ ****************************************************************************/
+
Device::Device(const char *device)
{
fd = open(device, O_RDWR | O_NONBLOCK);
@@ -79,6 +124,10 @@ int Device::ioctl(unsigned long cmd, void *data)
return eioctl(fd, cmd, data);
}
+/****************************************************************************
+ * Buffer
+ ****************************************************************************/
+
Buffer::Buffer(Device &device, const size_t capacity) :
fd(-1),
dataPtr(nullptr),
@@ -121,7 +170,6 @@ void Buffer::resize(size_t size, size_t offset)
ethosu_uapi_buffer uapi;
uapi.offset = offset;
uapi.size = size;
-
eioctl(fd, ETHOSU_IOCTL_BUFFER_SET, static_cast<void *>(&uapi));
}
@@ -144,15 +192,29 @@ int Buffer::getFd() const
return fd;
}
+/****************************************************************************
+ * Network
+ ****************************************************************************/
+
Network::Network(Device &device, shared_ptr<Buffer> &buffer) :
fd(-1),
buffer(buffer)
{
+ // Create buffer handle
ethosu_uapi_network_create uapi;
-
uapi.fd = buffer->getFd();
-
fd = device.ioctl(ETHOSU_IOCTL_NETWORK_CREATE, static_cast<void *>(&uapi));
+
+ // Create model handle
+ const tflite::Model *model = tflite::GetModel(reinterpret_cast<void *>(buffer->data()));
+
+ // Get input dimensions for first subgraph
+ auto *subgraph = *model->subgraphs()->begin();
+ ifmDims = getSubGraphDims(subgraph, subgraph->inputs());
+
+ // Get output dimensions for last subgraph
+ subgraph = *model->subgraphs()->rbegin();
+ ofmDims = getSubGraphDims(subgraph, subgraph->outputs());
}
Network::~Network()
@@ -165,30 +227,83 @@ int Network::ioctl(unsigned long cmd, void *data)
return eioctl(fd, cmd, data);
}
-std::shared_ptr<Buffer> Network::getBuffer()
+shared_ptr<Buffer> Network::getBuffer()
{
return buffer;
}
-Inference::Inference(std::shared_ptr<Network> &network, std::shared_ptr<Buffer> &ifmBuffer, std::shared_ptr<Buffer> &ofmBuffer) :
- fd(-1),
- network(network),
- ifmBuffer(ifmBuffer),
- ofmBuffer(ofmBuffer)
+const std::vector<size_t> &Network::getIfmDims() const
{
- ethosu_uapi_inference_create uapi;
+ return ifmDims;
+}
- uapi.ifm_fd = ifmBuffer->getFd();
- uapi.ofm_fd = ofmBuffer->getFd();
+size_t Network::getIfmSize() const
+{
+ size_t size = 0;
- fd = network->ioctl(ETHOSU_IOCTL_INFERENCE_CREATE, static_cast<void *>(&uapi));
+ for (auto s: ifmDims)
+ {
+ size += s;
+ }
+
+ return size;
+}
+
+const std::vector<size_t> &Network::getOfmDims() const
+{
+ return ofmDims;
+}
+
+size_t Network::getOfmSize() const
+{
+ size_t size = 0;
+
+ for (auto s: ofmDims)
+ {
+ size += s;
+ }
+
+ return size;
}
+/****************************************************************************
+ * Inference
+ ****************************************************************************/
+
Inference::~Inference()
{
close(fd);
}
+void Inference::create()
+{
+ ethosu_uapi_inference_create uapi;
+
+ if (ifmBuffers.size() > ETHOSU_FD_MAX)
+ {
+ throw Exception("IFM buffer overflow");
+ }
+
+ if (ofmBuffers.size() > ETHOSU_FD_MAX)
+ {
+ throw Exception("OFM buffer overflow");
+ }
+
+ uapi.ifm_count = 0;
+ for (auto it: ifmBuffers)
+ {
+ uapi.ifm_fd[uapi.ifm_count++] = it->getFd();
+ }
+
+ uapi.ofm_count = 0;
+ for (auto it: ofmBuffers)
+ {
+ uapi.ofm_fd[uapi.ofm_count++] = it->getFd();
+ }
+
+ fd = network->ioctl(ETHOSU_IOCTL_INFERENCE_CREATE, static_cast<void *>(&uapi));
+}
+
void Inference::wait(int timeoutSec)
{
pollfd pfd;
@@ -214,19 +329,19 @@ int Inference::getFd()
return fd;
}
-std::shared_ptr<Network> Inference::getNetwork()
+shared_ptr<Network> Inference::getNetwork()
{
return network;
}
-std::shared_ptr<Buffer> Inference::getIfmBuffer()
+vector<shared_ptr<Buffer>> &Inference::getIfmBuffers()
{
- return ifmBuffer;
+ return ifmBuffers;
}
-std::shared_ptr<Buffer> Inference::getOfmBuffer()
+vector<shared_ptr<Buffer>> &Inference::getOfmBuffers()
{
- return ofmBuffer;
+ return ofmBuffers;
}
}