aboutsummaryrefslogtreecommitdiff
path: root/utils/inference_runner/inference_runner.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'utils/inference_runner/inference_runner.cpp')
-rw-r--r--utils/inference_runner/inference_runner.cpp27
1 files changed, 18 insertions, 9 deletions
diff --git a/utils/inference_runner/inference_runner.cpp b/utils/inference_runner/inference_runner.cpp
index 21e133c..9df67ed 100644
--- a/utils/inference_runner/inference_runner.cpp
+++ b/utils/inference_runner/inference_runner.cpp
@@ -259,22 +259,29 @@ int main(int argc, char *argv[]) {
/* make sure the wait completes ok */
try {
cout << "Wait for inference" << endl;
- inference->wait(timeout);
+ bool timedout = inference->wait(timeout);
+ if (timedout) {
+ cout << "Inference timed out, cancelling it" << endl;
+ bool aborted = inference->cancel();
+ if (!aborted || inference->status() != InferenceStatus::ABORTED) {
+ cout << "Inference cancellation failed" << endl;
+ }
+ }
} catch (std::exception &e) {
- cout << "Failed to wait for inference completion: " << e.what() << endl;
+ cout << "Failed to wait for or to cancel inference: " << e.what() << endl;
exit(1);
}
cout << "Inference status: " << inference->status() << endl;
- string ofmFilename = ofmArg + "." + to_string(ofmIndex);
- ofstream ofmStream(ofmFilename, ios::binary);
- if (!ofmStream.is_open()) {
- cerr << "Error: Failed to open '" << ofmFilename << "'" << endl;
- exit(1);
- }
-
if (inference->status() == InferenceStatus::OK) {
+ string ofmFilename = ofmArg + "." + to_string(ofmIndex);
+ ofstream ofmStream(ofmFilename, ios::binary);
+ if (!ofmStream.is_open()) {
+ cerr << "Error: Failed to open '" << ofmFilename << "'" << endl;
+ exit(1);
+ }
+
/* The inference completed and has ok status */
for (auto &ofmBuffer : inference->getOfmBuffers()) {
cout << "OFM size: " << ofmBuffer->size() << endl;
@@ -286,6 +293,8 @@ int main(int argc, char *argv[]) {
ofmStream.write(ofmBuffer->data(), ofmBuffer->size());
}
+ ofmStream.flush();
+
/* Read out PMU counters if configured */
if (std::count(enabledCounters.begin(), enabledCounters.end(), 0) <
Inference::getMaxPmuEventCounters()) {