aboutsummaryrefslogtreecommitdiff
path: root/utils/inference_runner/inference_runner.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'utils/inference_runner/inference_runner.cpp')
-rw-r--r--utils/inference_runner/inference_runner.cpp13
1 files changed, 7 insertions, 6 deletions
diff --git a/utils/inference_runner/inference_runner.cpp b/utils/inference_runner/inference_runner.cpp
index c7a18b0..721cd57 100644
--- a/utils/inference_runner/inference_runner.cpp
+++ b/utils/inference_runner/inference_runner.cpp
@@ -25,6 +25,7 @@
#include <stdio.h>
#include <string>
#include <unistd.h>
+#include <utility>
using namespace std;
using namespace EthosU;
@@ -56,7 +57,7 @@ void rangeCheck(const int i, const int argc, const string arg) {
}
}
-shared_ptr<Buffer> allocAndFill(Device &device, const string filename) {
+pair<unique_ptr<unsigned char[]>, size_t> getNetworkData(const string filename) {
ifstream stream(filename, ios::binary);
if (!stream.is_open()) {
cerr << "Error: Failed to open '" << filename << "'" << endl;
@@ -67,10 +68,10 @@ shared_ptr<Buffer> allocAndFill(Device &device, const string filename) {
size_t size = stream.tellg();
stream.seekg(0, ios_base::beg);
- shared_ptr<Buffer> buffer = make_shared<Buffer>(device, size);
- stream.read(buffer->data(), size);
+ unique_ptr<unsigned char[]> data = std::make_unique<unsigned char[]>(size);
+ stream.read(reinterpret_cast<char *>(data.get()), size);
- return buffer;
+ return make_pair(std::move(data), size);
}
shared_ptr<Inference> createInference(Device &device,
@@ -234,8 +235,8 @@ int main(int argc, char *argv[]) {
shared_ptr<Network> network;
if (networkIndex < 0) {
- shared_ptr<Buffer> networkBuffer = allocAndFill(device, networkArg);
- network = make_shared<Network>(device, networkBuffer);
+ auto networkData = getNetworkData(networkArg);
+ network = make_shared<Network>(device, networkData.first.get(), networkData.second);
} else {
network = make_shared<Network>(device, networkIndex);
}