aboutsummaryrefslogtreecommitdiff
path: root/driver_library/include/ethosu.hpp
diff options
context:
space:
mode:
authorKristofer Jonsson <kristofer.jonsson@arm.com>2020-09-10 13:26:01 +0200
committerKristofer Jonsson <kristofer.jonsson@arm.com>2020-09-17 13:23:27 +0200
commitb74492c5aee3786b886951e87f4e5ea8d6032733 (patch)
tree76ef44dfdb68d68964877b0adba21cbce2416fe5 /driver_library/include/ethosu.hpp
parent116a635581f292cb4882ea1a086f842904f85c3c (diff)
downloadethos-u-linux-driver-stack-b74492c5aee3786b886951e87f4e5ea8d6032733.tar.gz
Support inferences with multiple inputs and outputs
Build flatbuffers library. Update network class to extract IFM and OFM dimensions from the tflite file. Update the uapi and core apis to support up to 16 IFM and OFM buffers per inference. Change-Id: I2f2f177aa4c2d5f9f50f23eb33c44e01ec2cbe09
Diffstat (limited to 'driver_library/include/ethosu.hpp')
-rw-r--r--driver_library/include/ethosu.hpp27
1 files changed, 22 insertions, 5 deletions
diff --git a/driver_library/include/ethosu.hpp b/driver_library/include/ethosu.hpp
index 3c8f814..c3a310c 100644
--- a/driver_library/include/ethosu.hpp
+++ b/driver_library/include/ethosu.hpp
@@ -20,8 +20,10 @@
#include <uapi/ethosu.h>
+#include <algorithm>
#include <memory>
#include <string>
+#include <vector>
namespace EthosU
{
@@ -79,30 +81,45 @@ public:
int ioctl(unsigned long cmd, void *data = nullptr);
std::shared_ptr<Buffer> getBuffer();
+ const std::vector<size_t> &getIfmDims() const;
+ size_t getIfmSize() const;
+ const std::vector<size_t> &getOfmDims() const;
+ size_t getOfmSize() const;
private:
int fd;
std::shared_ptr<Buffer> buffer;
+ std::vector<size_t> ifmDims;
+ std::vector<size_t> ofmDims;
};
class Inference
{
public:
- Inference(std::shared_ptr<Network> &network, std::shared_ptr<Buffer> &ifm, std::shared_ptr<Buffer> &ofm);
+ template <typename T>
+ Inference(std::shared_ptr<Network> &network, const T &ifmBegin, const T &ifmEnd, const T &ofmBegin, const T &ofmEnd) :
+ network(network)
+ {
+ std::copy(ifmBegin, ifmEnd, std::back_inserter(ifmBuffers));
+ std::copy(ofmBegin, ofmEnd, std::back_inserter(ofmBuffers));
+ create();
+ }
virtual ~Inference();
void wait(int timeoutSec = -1);
bool failed();
int getFd();
std::shared_ptr<Network> getNetwork();
- std::shared_ptr<Buffer> getIfmBuffer();
- std::shared_ptr<Buffer> getOfmBuffer();
+ std::vector<std::shared_ptr<Buffer>> &getIfmBuffers();
+ std::vector<std::shared_ptr<Buffer>> &getOfmBuffers();
private:
+ void create();
+
int fd;
std::shared_ptr<Network> network;
- std::shared_ptr<Buffer> ifmBuffer;
- std::shared_ptr<Buffer> ofmBuffer;
+ std::vector<std::shared_ptr<Buffer>> ifmBuffers;
+ std::vector<std::shared_ptr<Buffer>> ofmBuffers;
};
}