aboutsummaryrefslogtreecommitdiff
path: root/utils/inference_runner/inference_runner.cpp
diff options
context:
space:
mode:
authorKristofer Jonsson <kristofer.jonsson@arm.com>2020-09-10 13:26:01 +0200
committerKristofer Jonsson <kristofer.jonsson@arm.com>2020-09-17 13:23:27 +0200
commitb74492c5aee3786b886951e87f4e5ea8d6032733 (patch)
tree76ef44dfdb68d68964877b0adba21cbce2416fe5 /utils/inference_runner/inference_runner.cpp
parent116a635581f292cb4882ea1a086f842904f85c3c (diff)
downloadethos-u-linux-driver-stack-b74492c5aee3786b886951e87f4e5ea8d6032733.tar.gz
Support inferences with multiple inputs and outputs
Build flatbuffers library. Update network class to extract IFM and OFM dimensions from the tflite file. Update the uapi and core apis to support up to 16 IFM and OFM buffers per inference. Change-Id: I2f2f177aa4c2d5f9f50f23eb33c44e01ec2cbe09
Diffstat (limited to 'utils/inference_runner/inference_runner.cpp')
-rw-r--r--utils/inference_runner/inference_runner.cpp78
1 files changed, 62 insertions, 16 deletions
diff --git a/utils/inference_runner/inference_runner.cpp b/utils/inference_runner/inference_runner.cpp
index 3388170..ae8d2a7 100644
--- a/utils/inference_runner/inference_runner.cpp
+++ b/utils/inference_runner/inference_runner.cpp
@@ -16,8 +16,8 @@
* limitations under the License.
*/
-#include <ethosu.hpp>
+#include <ethosu.hpp>
#include <uapi/ethosu.h>
#include <unistd.h>
@@ -77,7 +77,55 @@ shared_ptr<Buffer> allocAndFill(Device &device, const string filename)
return buffer;
}
-std::ostream &operator<<(std::ostream &os, Buffer &buf)
+shared_ptr<Inference> createInference(Device &device, shared_ptr<Network> &network, const string &filename)
+{
+ // Open IFM file
+ ifstream stream(filename, ios::binary);
+ if (!stream.is_open())
+ {
+ cerr << "Error: Failed to open '" << filename << "'" << endl;
+ exit(1);
+ }
+
+ // Get IFM file size
+ stream.seekg(0, ios_base::end);
+ size_t size = stream.tellg();
+ stream.seekg(0, ios_base::beg);
+
+ if (size != network->getIfmSize())
+ {
+ cerr << "Error: IFM size does not match network size. filename=" << filename << ", size=" << size << ", network=" << network->getIfmSize() << endl;
+ exit(1);
+ }
+
+ // Create IFM buffers
+ vector<shared_ptr<Buffer>> ifm;
+ for (auto size: network->getIfmDims())
+ {
+ shared_ptr<Buffer> buffer = make_shared<Buffer>(device, size);
+ buffer->resize(size);
+ stream.read(buffer->data(), size);
+
+ if (!stream)
+ {
+ cerr << "Error: Failed to read IFM" << endl;
+ exit(1);
+ }
+
+ ifm.push_back(buffer);
+ }
+
+ // Create OFM buffers
+ vector<shared_ptr<Buffer>> ofm;
+ for (auto size: network->getOfmDims())
+ {
+ ofm.push_back(make_shared<Buffer>(device, size));
+ }
+
+ return make_shared<Inference>(network, ifm.begin(), ifm.end(), ofm.begin(), ofm.end());
+}
+
+ostream &operator<<(ostream &os, Buffer &buf)
{
char *c = buf.data();
const char *end = c + buf.size();
@@ -128,7 +176,7 @@ int main(int argc, char *argv[])
else if (arg == "--timeout" || arg == "-t")
{
rangeCheck(++i, argc, arg);
- timeout = std::stoi(argv[i]);
+ timeout = stoi(argv[i]);
}
else if (arg == "-p")
{
@@ -167,20 +215,17 @@ int main(int argc, char *argv[])
cout << "Send ping" << endl;
device.ioctl(ETHOSU_IOCTL_PING);
+ /* Create network */
cout << "Create network" << endl;
shared_ptr<Buffer> networkBuffer = allocAndFill(device, networkArg);
shared_ptr<Network> network = make_shared<Network>(device, networkBuffer);
- cout << "Queue inferences" << endl;
+ /* Create one inference per IFM */
list<shared_ptr<Inference>> inferences;
-
for (auto &filename: ifmArg)
{
cout << "Create inference" << endl;
- shared_ptr<Buffer> ifmBuffer = allocAndFill(device, filename);
- shared_ptr<Buffer> ofmBuffer = make_shared<Buffer>(device, 128 * 1024);
- shared_ptr<Inference> inference = make_shared<Inference>(network, ifmBuffer, ofmBuffer);
- inferences.push_back(inference);
+ inferences.push_back(createInference(device, network, filename));
}
cout << "Wait for inferences" << endl;
@@ -203,16 +248,17 @@ int main(int argc, char *argv[])
if (!inference->failed())
{
- shared_ptr<Buffer> ofmBuffer = inference->getOfmBuffer();
+ for (auto &ofmBuffer: inference->getOfmBuffers())
+ {
+ cout << "OFM size: " << ofmBuffer->size() << endl;
- cout << "OFM size: " << ofmBuffer->size() << endl;
+ if (print)
+ {
+ cout << "OFM data: " << *ofmBuffer << endl;
+ }
- if (print)
- {
- cout << "OFM data: " << *ofmBuffer << endl;
+ ofmStream.write(ofmBuffer->data(), ofmBuffer->size());
}
-
- ofmStream.write(ofmBuffer->data(), ofmBuffer->size());
}
ofmIndex++;