aboutsummaryrefslogtreecommitdiff
path: root/utils/inference_runner/inference_runner.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'utils/inference_runner/inference_runner.cpp')
-rw-r--r--utils/inference_runner/inference_runner.cpp78
1 files changed, 62 insertions, 16 deletions
diff --git a/utils/inference_runner/inference_runner.cpp b/utils/inference_runner/inference_runner.cpp
index 3388170..ae8d2a7 100644
--- a/utils/inference_runner/inference_runner.cpp
+++ b/utils/inference_runner/inference_runner.cpp
@@ -16,8 +16,8 @@
* limitations under the License.
*/
-#include <ethosu.hpp>
+#include <ethosu.hpp>
#include <uapi/ethosu.h>
#include <unistd.h>
@@ -77,7 +77,55 @@ shared_ptr<Buffer> allocAndFill(Device &device, const string filename)
return buffer;
}
-std::ostream &operator<<(std::ostream &os, Buffer &buf)
+shared_ptr<Inference> createInference(Device &device, shared_ptr<Network> &network, const string &filename)
+{
+ // Open IFM file
+ ifstream stream(filename, ios::binary);
+ if (!stream.is_open())
+ {
+ cerr << "Error: Failed to open '" << filename << "'" << endl;
+ exit(1);
+ }
+
+ // Get IFM file size
+ stream.seekg(0, ios_base::end);
+ size_t size = stream.tellg();
+ stream.seekg(0, ios_base::beg);
+
+ if (size != network->getIfmSize())
+ {
+ cerr << "Error: IFM size does not match network size. filename=" << filename << ", size=" << size << ", network=" << network->getIfmSize() << endl;
+ exit(1);
+ }
+
+ // Create IFM buffers
+ vector<shared_ptr<Buffer>> ifm;
+ for (auto size: network->getIfmDims())
+ {
+ shared_ptr<Buffer> buffer = make_shared<Buffer>(device, size);
+ buffer->resize(size);
+ stream.read(buffer->data(), size);
+
+ if (!stream)
+ {
+ cerr << "Error: Failed to read IFM" << endl;
+ exit(1);
+ }
+
+ ifm.push_back(buffer);
+ }
+
+ // Create OFM buffers
+ vector<shared_ptr<Buffer>> ofm;
+ for (auto size: network->getOfmDims())
+ {
+ ofm.push_back(make_shared<Buffer>(device, size));
+ }
+
+ return make_shared<Inference>(network, ifm.begin(), ifm.end(), ofm.begin(), ofm.end());
+}
+
+ostream &operator<<(ostream &os, Buffer &buf)
{
char *c = buf.data();
const char *end = c + buf.size();
@@ -128,7 +176,7 @@ int main(int argc, char *argv[])
else if (arg == "--timeout" || arg == "-t")
{
rangeCheck(++i, argc, arg);
- timeout = std::stoi(argv[i]);
+ timeout = stoi(argv[i]);
}
else if (arg == "-p")
{
@@ -167,20 +215,17 @@ int main(int argc, char *argv[])
cout << "Send ping" << endl;
device.ioctl(ETHOSU_IOCTL_PING);
+ /* Create network */
cout << "Create network" << endl;
shared_ptr<Buffer> networkBuffer = allocAndFill(device, networkArg);
shared_ptr<Network> network = make_shared<Network>(device, networkBuffer);
- cout << "Queue inferences" << endl;
+ /* Create one inference per IFM */
list<shared_ptr<Inference>> inferences;
-
for (auto &filename: ifmArg)
{
cout << "Create inference" << endl;
- shared_ptr<Buffer> ifmBuffer = allocAndFill(device, filename);
- shared_ptr<Buffer> ofmBuffer = make_shared<Buffer>(device, 128 * 1024);
- shared_ptr<Inference> inference = make_shared<Inference>(network, ifmBuffer, ofmBuffer);
- inferences.push_back(inference);
+ inferences.push_back(createInference(device, network, filename));
}
cout << "Wait for inferences" << endl;
@@ -203,16 +248,17 @@ int main(int argc, char *argv[])
if (!inference->failed())
{
- shared_ptr<Buffer> ofmBuffer = inference->getOfmBuffer();
+ for (auto &ofmBuffer: inference->getOfmBuffers())
+ {
+ cout << "OFM size: " << ofmBuffer->size() << endl;
- cout << "OFM size: " << ofmBuffer->size() << endl;
+ if (print)
+ {
+ cout << "OFM data: " << *ofmBuffer << endl;
+ }
- if (print)
- {
- cout << "OFM data: " << *ofmBuffer << endl;
+ ofmStream.write(ofmBuffer->data(), ofmBuffer->size());
}
-
- ofmStream.write(ofmBuffer->data(), ofmBuffer->size());
}
ofmIndex++;