aboutsummaryrefslogtreecommitdiff
path: root/driver_library/include/ethosu.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'driver_library/include/ethosu.hpp')
-rw-r--r--driver_library/include/ethosu.hpp27
1 files changed, 22 insertions, 5 deletions
diff --git a/driver_library/include/ethosu.hpp b/driver_library/include/ethosu.hpp
index 3c8f814..c3a310c 100644
--- a/driver_library/include/ethosu.hpp
+++ b/driver_library/include/ethosu.hpp
@@ -20,8 +20,10 @@
#include <uapi/ethosu.h>
+#include <algorithm>
#include <memory>
#include <string>
+#include <vector>
namespace EthosU
{
@@ -79,30 +81,45 @@ public:
int ioctl(unsigned long cmd, void *data = nullptr);
std::shared_ptr<Buffer> getBuffer();
+ const std::vector<size_t> &getIfmDims() const;
+ size_t getIfmSize() const;
+ const std::vector<size_t> &getOfmDims() const;
+ size_t getOfmSize() const;
private:
int fd;
std::shared_ptr<Buffer> buffer;
+ std::vector<size_t> ifmDims;
+ std::vector<size_t> ofmDims;
};
class Inference
{
public:
- Inference(std::shared_ptr<Network> &network, std::shared_ptr<Buffer> &ifm, std::shared_ptr<Buffer> &ofm);
+ template <typename T>
+ Inference(std::shared_ptr<Network> &network, const T &ifmBegin, const T &ifmEnd, const T &ofmBegin, const T &ofmEnd) :
+ network(network)
+ {
+ std::copy(ifmBegin, ifmEnd, std::back_inserter(ifmBuffers));
+ std::copy(ofmBegin, ofmEnd, std::back_inserter(ofmBuffers));
+ create();
+ }
virtual ~Inference();
void wait(int timeoutSec = -1);
bool failed();
int getFd();
std::shared_ptr<Network> getNetwork();
- std::shared_ptr<Buffer> getIfmBuffer();
- std::shared_ptr<Buffer> getOfmBuffer();
+ std::vector<std::shared_ptr<Buffer>> &getIfmBuffers();
+ std::vector<std::shared_ptr<Buffer>> &getOfmBuffers();
private:
+ void create();
+
int fd;
std::shared_ptr<Network> network;
- std::shared_ptr<Buffer> ifmBuffer;
- std::shared_ptr<Buffer> ofmBuffer;
+ std::vector<std::shared_ptr<Buffer>> ifmBuffers;
+ std::vector<std::shared_ptr<Buffer>> ofmBuffers;
};
}