aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--driver_library/include/ethosu.hpp5
-rw-r--r--driver_library/src/ethosu.cpp18
-rw-r--r--kernel/Kbuild3
-rw-r--r--kernel/ethosu_cancel_inference.c198
-rw-r--r--kernel/ethosu_cancel_inference.h72
-rw-r--r--kernel/ethosu_core_interface.h20
-rw-r--r--kernel/ethosu_device.c16
-rw-r--r--kernel/ethosu_inference.c51
-rw-r--r--kernel/ethosu_inference.h2
-rw-r--r--kernel/ethosu_mailbox.c14
-rw-r--r--kernel/ethosu_mailbox.h9
-rw-r--r--kernel/uapi/ethosu.h12
-rw-r--r--utils/inference_runner/inference_runner.cpp27
13 files changed, 424 insertions, 23 deletions
diff --git a/driver_library/include/ethosu.hpp b/driver_library/include/ethosu.hpp
index 61e2bc5..da8dbbd 100644
--- a/driver_library/include/ethosu.hpp
+++ b/driver_library/include/ethosu.hpp
@@ -185,6 +185,8 @@ enum class InferenceStatus {
ERROR,
RUNNING,
REJECTED,
+ ABORTED,
+ ABORTING,
};
std::ostream &operator<<(std::ostream &out, const InferenceStatus &v);
@@ -226,9 +228,10 @@ public:
virtual ~Inference() noexcept(false);
- int wait(int64_t timeoutNanos = -1) const;
+ bool wait(int64_t timeoutNanos = -1) const;
const std::vector<uint32_t> getPmuCounters() const;
uint64_t getCycleCounter() const;
+ bool cancel() const;
InferenceStatus status() const;
int getFd() const;
const std::shared_ptr<Network> getNetwork() const;
diff --git a/driver_library/src/ethosu.cpp b/driver_library/src/ethosu.cpp
index 0da30c3..eeb95c2 100644
--- a/driver_library/src/ethosu.cpp
+++ b/driver_library/src/ethosu.cpp
@@ -340,6 +340,10 @@ ostream &operator<<(ostream &out, const InferenceStatus &status) {
return out << "running";
case InferenceStatus::REJECTED:
return out << "rejected";
+ case InferenceStatus::ABORTED:
+ return out << "aborted";
+ case InferenceStatus::ABORTING:
+ return out << "aborting";
}
throw Exception("Unknown inference status");
}
@@ -390,7 +394,7 @@ uint32_t Inference::getMaxPmuEventCounters() {
return ETHOSU_PMU_EVENT_MAX;
}
-int Inference::wait(int64_t timeoutNanos) const {
+bool Inference::wait(int64_t timeoutNanos) const {
struct pollfd pfd;
pfd.fd = fd;
pfd.events = POLLIN | POLLERR;
@@ -406,7 +410,13 @@ int Inference::wait(int64_t timeoutNanos) const {
tmo_p.tv_sec = timeoutNanos / nanosec;
tmo_p.tv_nsec = timeoutNanos % nanosec;
- return eppoll(&pfd, 1, &tmo_p, NULL);
+ return eppoll(&pfd, 1, &tmo_p, NULL) == 0;
+}
+
+bool Inference::cancel() const {
+ ethosu_uapi_cancel_inference_status uapi;
+ eioctl(fd, ETHOSU_IOCTL_INFERENCE_CANCEL, static_cast<void *>(&uapi));
+ return uapi.status == ETHOSU_UAPI_STATUS_OK;
}
InferenceStatus Inference::status() const {
@@ -423,6 +433,10 @@ InferenceStatus Inference::status() const {
return InferenceStatus::RUNNING;
case ETHOSU_UAPI_STATUS_REJECTED:
return InferenceStatus::REJECTED;
+ case ETHOSU_UAPI_STATUS_ABORTED:
+ return InferenceStatus::ABORTED;
+ case ETHOSU_UAPI_STATUS_ABORTING:
+ return InferenceStatus::ABORTING;
}
throw Exception("Unknown inference status");
diff --git a/kernel/Kbuild b/kernel/Kbuild
index 0b92c12..d54ec2c 100644
--- a/kernel/Kbuild
+++ b/kernel/Kbuild
@@ -27,4 +27,5 @@ ethosu-objs := ethosu_driver.o \
ethosu_mailbox.o \
ethosu_network.o \
ethosu_network_info.o \
- ethosu_watchdog.o
+ ethosu_watchdog.o \
+ ethosu_cancel_inference.o
diff --git a/kernel/ethosu_cancel_inference.c b/kernel/ethosu_cancel_inference.c
new file mode 100644
index 0000000..09778ee
--- /dev/null
+++ b/kernel/ethosu_cancel_inference.c
@@ -0,0 +1,198 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * This program is free software and is provided to you under the terms of the
+ * GNU General Public License version 2 as published by the Free Software
+ * Foundation, and any use by you of this program is subject to the terms
+ * of such GNU licence.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, you can access it online at
+ * http://www.gnu.org/licenses/gpl-2.0.html.
+ *
+ * SPDX-License-Identifier: GPL-2.0-only
+ */
+
+/****************************************************************************
+ * Includes
+ ****************************************************************************/
+
+#include "ethosu_cancel_inference.h"
+
+#include "ethosu_core_interface.h"
+#include "ethosu_device.h"
+#include "ethosu_inference.h"
+
+#include <linux/wait.h>
+
+/****************************************************************************
+ * Defines
+ ****************************************************************************/
+
+#define CANCEL_INFERENCE_RESP_TIMEOUT_MS 2000
+
+/****************************************************************************
+ * Functions
+ ****************************************************************************/
+
+static void ethosu_cancel_inference_destroy(struct kref *kref)
+{
+ struct ethosu_cancel_inference *cancellation =
+ container_of(kref, struct ethosu_cancel_inference, kref);
+
+ dev_info(cancellation->edev->dev,
+ "Cancel inference destroy. handle=0x%p\n", cancellation);
+ list_del(&cancellation->msg.list);
+ /* decrease the reference on the inference we are refering to */
+ ethosu_inference_put(cancellation->inf);
+ devm_kfree(cancellation->edev->dev, cancellation);
+}
+
+static int ethosu_cancel_inference_send(
+ struct ethosu_cancel_inference *cancellation)
+{
+ return ethosu_mailbox_cancel_inference(&cancellation->edev->mailbox,
+ cancellation, cancellation->inf);
+}
+
+static void ethosu_cancel_inference_fail(struct ethosu_mailbox_msg *msg)
+{
+ struct ethosu_cancel_inference *cancellation =
+ container_of(msg, typeof(*cancellation), msg);
+
+ if (completion_done(&cancellation->done))
+ return;
+
+ cancellation->errno = -EFAULT;
+ cancellation->uapi->status = ETHOSU_UAPI_STATUS_ERROR;
+ complete(&cancellation->done);
+}
+
+static int ethosu_cancel_inference_complete(struct ethosu_mailbox_msg *msg)
+{
+ struct ethosu_cancel_inference *cancellation =
+ container_of(msg, typeof(*cancellation), msg);
+
+ if (completion_done(&cancellation->done))
+ return 0;
+
+ cancellation->errno = 0;
+ cancellation->uapi->status =
+ cancellation->inf->done &&
+ cancellation->inf->status != ETHOSU_UAPI_STATUS_OK ?
+ ETHOSU_UAPI_STATUS_OK :
+ ETHOSU_UAPI_STATUS_ERROR;
+ complete(&cancellation->done);
+
+ return 0;
+}
+
+int ethosu_cancel_inference_request(struct ethosu_inference *inf,
+ struct ethosu_uapi_cancel_inference_status *uapi)
+{
+ struct ethosu_cancel_inference *cancellation;
+ int ret;
+ int timeout;
+
+ if (inf->done) {
+ uapi->status = ETHOSU_UAPI_STATUS_ERROR;
+
+ return 0;
+ }
+
+ cancellation =
+ devm_kzalloc(inf->edev->dev,
+ sizeof(struct ethosu_cancel_inference),
+ GFP_KERNEL);
+ if (!cancellation)
+ return -ENOMEM;
+
+ /* increase ref count on the inference we are refering to */
+ ethosu_inference_get(inf);
+ /* mark inference ABORTING to avoid resending the inference message */
+ inf->status = ETHOSU_CORE_STATUS_ABORTING;
+
+ cancellation->edev = inf->edev;
+ cancellation->inf = inf;
+ cancellation->uapi = uapi;
+ kref_init(&cancellation->kref);
+ init_completion(&cancellation->done);
+ cancellation->msg.fail = ethosu_cancel_inference_fail;
+
+ /* Never resend messages but always complete, since we have restart the
+ * whole firmware and marked the inference as aborted */
+ cancellation->msg.resend = ethosu_cancel_inference_complete;
+
+ /* Add cancel inference to pending list */
+ list_add(&cancellation->msg.list,
+ &cancellation->edev->mailbox.pending_list);
+
+ ret = ethosu_cancel_inference_send(cancellation);
+ if (0 != ret)
+ goto put_kref;
+
+ /* Unlock the mutex before going to block on the condition */
+ mutex_unlock(&cancellation->edev->mutex);
+ /* wait for response to arrive back */
+ timeout = wait_for_completion_timeout(&cancellation->done,
+ msecs_to_jiffies(
+ CANCEL_INFERENCE_RESP_TIMEOUT_MS));
+ /* take back the mutex before resuming to do anything */
+ ret = mutex_lock_interruptible(&cancellation->edev->mutex);
+ if (0 != ret)
+ goto put_kref;
+
+ if (0 == timeout /* timed out*/) {
+ dev_warn(inf->edev->dev,
+ "Msg: Cancel Inference response lost - timeout\n");
+ ret = -EIO;
+ goto put_kref;
+ }
+
+ if (cancellation->errno) {
+ ret = cancellation->errno;
+ goto put_kref;
+ }
+
+put_kref:
+ kref_put(&cancellation->kref, &ethosu_cancel_inference_destroy);
+
+ return ret;
+}
+
+void ethosu_cancel_inference_rsp(struct ethosu_device *edev,
+ struct ethosu_core_cancel_inference_rsp *rsp)
+{
+ struct ethosu_cancel_inference *cancellation =
+ (struct ethosu_cancel_inference *)rsp->user_arg;
+ int ret;
+
+ ret = ethosu_mailbox_find(&edev->mailbox, &cancellation->msg);
+ if (ret) {
+ dev_warn(edev->dev,
+ "Handle not found in cancel inference list. handle=0x%p\n",
+ rsp);
+
+ return;
+ }
+
+ if (completion_done(&cancellation->done))
+ return;
+
+ cancellation->errno = 0;
+ switch (rsp->status) {
+ case ETHOSU_CORE_STATUS_OK:
+ cancellation->uapi->status = ETHOSU_UAPI_STATUS_OK;
+ break;
+ case ETHOSU_CORE_STATUS_ERROR:
+ cancellation->uapi->status = ETHOSU_UAPI_STATUS_ERROR;
+ break;
+ }
+
+ complete(&cancellation->done);
+}
diff --git a/kernel/ethosu_cancel_inference.h b/kernel/ethosu_cancel_inference.h
new file mode 100644
index 0000000..94d9fe1
--- /dev/null
+++ b/kernel/ethosu_cancel_inference.h
@@ -0,0 +1,72 @@
+/*
+ * Copyright (c) 2022 ARM Limited.
+ *
+ * This program is free software and is provided to you under the terms of the
+ * GNU General Public License version 2 as published by the Free Software
+ * Foundation, and any use by you of this program is subject to the terms
+ * of such GNU licence.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, you can access it online at
+ * http://www.gnu.org/licenses/gpl-2.0.html.
+ *
+ * SPDX-License-Identifier: GPL-2.0-only
+ */
+
+#ifndef ETHOSU_CANCEL_INFERENCE_H
+#define ETHOSU_CANCEL_INFERENCE_H
+
+/****************************************************************************
+ * Includes
+ ****************************************************************************/
+
+#include "ethosu_mailbox.h"
+#include "uapi/ethosu.h"
+
+#include <linux/kref.h>
+#include <linux/types.h>
+#include <linux/completion.h>
+
+/****************************************************************************
+ * Types
+ ****************************************************************************/
+
+struct ethosu_core_cancel_inference_rsp;
+struct ethosu_device;
+struct ethosu_uapi_cancel_inference_status;
+struct ethosu_inference;
+
+struct ethosu_cancel_inference {
+ struct ethosu_device *edev;
+ struct ethosu_inference *inf;
+ struct ethosu_uapi_cancel_inference_status *uapi;
+ struct kref kref;
+ struct completion done;
+ struct ethosu_mailbox_msg msg;
+ int errno;
+};
+
+/****************************************************************************
+ * Functions
+ ****************************************************************************/
+
+/**
+ * ethosu_cancel_inference_request() - Send cancel inference request
+ *
+ * Return: 0 on success, error code otherwise.
+ */
+int ethosu_cancel_inference_request(struct ethosu_inference *inf,
+ struct ethosu_uapi_cancel_inference_status *uapi);
+
+/**
+ * ethosu_cancel_inference_rsp() - Handle cancel inference response
+ */
+void ethosu_cancel_inference_rsp(struct ethosu_device *edev,
+ struct ethosu_core_cancel_inference_rsp *rsp);
+
+#endif /* ETHOSU_CANCEL_INFERENCE_H */
diff --git a/kernel/ethosu_core_interface.h b/kernel/ethosu_core_interface.h
index 76fe35e..057e3c2 100644
--- a/kernel/ethosu_core_interface.h
+++ b/kernel/ethosu_core_interface.h
@@ -59,6 +59,8 @@ enum ethosu_core_msg_type {
ETHOSU_CORE_MSG_CAPABILITIES_RSP,
ETHOSU_CORE_MSG_NETWORK_INFO_REQ,
ETHOSU_CORE_MSG_NETWORK_INFO_RSP,
+ ETHOSU_CORE_MSG_CANCEL_INFERENCE_REQ,
+ ETHOSU_CORE_MSG_CANCEL_INFERENCE_RSP,
ETHOSU_CORE_MSG_MAX
};
@@ -98,6 +100,8 @@ enum ethosu_core_status {
ETHOSU_CORE_STATUS_ERROR,
ETHOSU_CORE_STATUS_RUNNING,
ETHOSU_CORE_STATUS_REJECTED,
+ ETHOSU_CORE_STATUS_ABORTED,
+ ETHOSU_CORE_STATUS_ABORTING,
};
/**
@@ -217,6 +221,22 @@ struct ethosu_core_msg_capabilities_rsp {
};
/**
+ * struct ethosu_core_cancel_inference_req - Message cancel inference request
+ */
+struct ethosu_core_cancel_inference_req {
+ uint64_t user_arg;
+ uint64_t inference_handle;
+};
+
+/**
+ * struct ethosu_core_cancel_inference_rsp - Message cancel inference response
+ */
+struct ethosu_core_cancel_inference_rsp {
+ uint64_t user_arg;
+ uint32_t status;
+};
+
+/**
* enum ethosu_core_msg_err_type - Error types
*/
enum ethosu_core_msg_err_type {
diff --git a/kernel/ethosu_device.c b/kernel/ethosu_device.c
index 316df6f..f66c2ac 100644
--- a/kernel/ethosu_device.c
+++ b/kernel/ethosu_device.c
@@ -27,6 +27,7 @@
#include "ethosu_buffer.h"
#include "ethosu_core_interface.h"
#include "ethosu_inference.h"
+#include "ethosu_cancel_inference.h"
#include "ethosu_network.h"
#include "ethosu_network_info.h"
#include "uapi/ethosu.h"
@@ -209,6 +210,7 @@ static int ethosu_handle_msg(struct ethosu_device *edev)
struct ethosu_core_msg_version version;
struct ethosu_core_msg_capabilities_rsp capabilities;
struct ethosu_core_network_info_rsp network_info;
+ struct ethosu_core_cancel_inference_rsp cancellation;
} data;
/* Read message */
@@ -254,6 +256,20 @@ static int ethosu_handle_msg(struct ethosu_device *edev)
ethosu_inference_rsp(edev, &data.inf);
break;
+ case ETHOSU_CORE_MSG_CANCEL_INFERENCE_RSP:
+ if (header.length != sizeof(data.cancellation)) {
+ dev_warn(edev->dev,
+ "Msg: Cancel Inference response of incorrect size. size=%u, expected=%zu\n", header.length,
+ sizeof(data.cancellation));
+ ret = -EBADMSG;
+ break;
+ }
+
+ dev_info(edev->dev,
+ "Msg: Cancel Inference response. user_arg=0x%llx, status=%u\n",
+ data.cancellation.user_arg, data.cancellation.status);
+ ethosu_cancel_inference_rsp(edev, &data.cancellation);
+ break;
case ETHOSU_CORE_MSG_VERSION_RSP:
if (header.length != sizeof(data.version)) {
dev_warn(edev->dev,
diff --git a/kernel/ethosu_inference.c b/kernel/ethosu_inference.c
index 73a8c06..0599b53 100644
--- a/kernel/ethosu_inference.c
+++ b/kernel/ethosu_inference.c
@@ -28,6 +28,7 @@
#include "ethosu_core_interface.h"
#include "ethosu_device.h"
#include "ethosu_network.h"
+#include "ethosu_cancel_inference.h"
#include <linux/anon_inodes.h>
#include <linux/file.h>
@@ -76,6 +77,12 @@ static const char *status_to_string(const enum ethosu_uapi_status status)
case ETHOSU_UAPI_STATUS_REJECTED: {
return "Rejected";
}
+ case ETHOSU_UAPI_STATUS_ABORTED: {
+ return "Aborted";
+ }
+ case ETHOSU_UAPI_STATUS_ABORTING: {
+ return "Aborting";
+ }
default: {
return "Unknown";
}
@@ -112,15 +119,19 @@ static void ethosu_inference_fail(struct ethosu_mailbox_msg *msg)
container_of(msg, typeof(*inf), msg);
int ret;
+ if (inf->done)
+ return;
+
/* Decrement reference count if inference was pending reponse */
- if (!inf->done) {
- ret = ethosu_inference_put(inf);
- if (ret)
- return;
- }
+ ret = ethosu_inference_put(inf);
+ if (ret)
+ return;
- /* Fail inference and wake up any waiting process */
- inf->status = ETHOSU_UAPI_STATUS_ERROR;
+ /* Set status accordingly to the inference state */
+ inf->status = inf->status == ETHOSU_UAPI_STATUS_ABORTING ?
+ ETHOSU_UAPI_STATUS_ABORTED :
+ ETHOSU_UAPI_STATUS_ERROR;
+ /* Mark it done and wake up the waiting process */
inf->done = true;
wake_up_interruptible(&inf->waitq);
}
@@ -135,6 +146,13 @@ static int ethosu_inference_resend(struct ethosu_mailbox_msg *msg)
if (inf->done)
return 0;
+ /* If marked as ABORTING simply fail it and return */
+ if (inf->status == ETHOSU_UAPI_STATUS_ABORTING) {
+ ethosu_inference_fail(msg);
+
+ return 0;
+ }
+
/* Decrement reference count for pending request */
ret = ethosu_inference_put(inf);
if (ret)
@@ -241,8 +259,22 @@ static long ethosu_inference_ioctl(struct file *file,
break;
}
+ case ETHOSU_IOCTL_INFERENCE_CANCEL: {
+ struct ethosu_uapi_cancel_inference_status uapi;
+
+ dev_info(inf->edev->dev, "Ioctl: Cancel Inference. Handle=%p\n",
+ inf);
+
+ ret = ethosu_cancel_inference_request(inf, &uapi);
+ if (ret)
+ break;
+
+ ret = copy_to_user(udata, &uapi, sizeof(uapi)) ? -EFAULT : 0;
+
+ break;
+ }
default: {
- dev_err(inf->edev->dev, "Invalid ioctl. cmd=%u, arg=%lu",
+ dev_err(inf->edev->dev, "Invalid ioctl. cmd=%u, arg=%lu\n",
cmd, arg);
break;
}
@@ -422,6 +454,8 @@ void ethosu_inference_rsp(struct ethosu_device *edev,
}
} else if (rsp->status == ETHOSU_CORE_STATUS_REJECTED) {
inf->status = ETHOSU_UAPI_STATUS_REJECTED;
+ } else if (rsp->status == ETHOSU_CORE_STATUS_ABORTED) {
+ inf->status = ETHOSU_UAPI_STATUS_ABORTED;
} else {
inf->status = ETHOSU_UAPI_STATUS_ERROR;
}
@@ -450,6 +484,5 @@ void ethosu_inference_rsp(struct ethosu_device *edev,
inf->done = true;
wake_up_interruptible(&inf->waitq);
-
ethosu_inference_put(inf);
}
diff --git a/kernel/ethosu_inference.h b/kernel/ethosu_inference.h
index 66d4ff9..2f188b6 100644
--- a/kernel/ethosu_inference.h
+++ b/kernel/ethosu_inference.h
@@ -30,7 +30,6 @@
#include <linux/kref.h>
#include <linux/types.h>
-#include <linux/wait.h>
/****************************************************************************
* Types
@@ -58,6 +57,7 @@ struct file;
* @pmu_event_count: PMU event count after inference
* @pmu_cycle_counter_enable: PMU cycle counter config
* @pmu_cycle_counter_count: PMU cycle counter count after inference
+ * @msg: Mailbox message
*/
struct ethosu_inference {
struct ethosu_device *edev;
diff --git a/kernel/ethosu_mailbox.c b/kernel/ethosu_mailbox.c
index 7355361..5343e56 100644
--- a/kernel/ethosu_mailbox.c
+++ b/kernel/ethosu_mailbox.c
@@ -418,6 +418,20 @@ int ethosu_mailbox_network_info_request(struct ethosu_mailbox *mbox,
&info, sizeof(info));
}
+int ethosu_mailbox_cancel_inference(struct ethosu_mailbox *mbox,
+ void *user_arg,
+ void *inference_handle)
+{
+ struct ethosu_core_cancel_inference_req req;
+
+ req.user_arg = (ptrdiff_t)user_arg;
+ req.inference_handle = (ptrdiff_t)inference_handle;
+
+ return ethosu_queue_write_msg(mbox,
+ ETHOSU_CORE_MSG_CANCEL_INFERENCE_REQ,
+ &req, sizeof(req));
+}
+
static void ethosu_mailbox_rx_work(struct work_struct *work)
{
struct ethosu_mailbox *mbox = container_of(work, typeof(*mbox), work);
diff --git a/kernel/ethosu_mailbox.h b/kernel/ethosu_mailbox.h
index 55d4436..7af0c47 100644
--- a/kernel/ethosu_mailbox.h
+++ b/kernel/ethosu_mailbox.h
@@ -203,4 +203,13 @@ int ethosu_mailbox_network_info_request(struct ethosu_mailbox *mbox,
struct ethosu_buffer *network,
uint32_t network_index);
+/**
+ * ethosu_mailbox_cancel_inference() - Send inference cancellation
+ *
+ * Return: 0 on success, else error code.
+ */
+int ethosu_mailbox_cancel_inference(struct ethosu_mailbox *mbox,
+ void *user_arg,
+ void *inference_handle);
+
#endif /* ETHOSU_MAILBOX_H */
diff --git a/kernel/uapi/ethosu.h b/kernel/uapi/ethosu.h
index 4627cb9..033afe6 100644
--- a/kernel/uapi/ethosu.h
+++ b/kernel/uapi/ethosu.h
@@ -59,6 +59,8 @@ namespace EthosU {
struct ethosu_uapi_inference_create)
#define ETHOSU_IOCTL_INFERENCE_STATUS ETHOSU_IOR(0x31, \
struct ethosu_uapi_result_status)
+#define ETHOSU_IOCTL_INFERENCE_CANCEL ETHOSU_IOR(0x32, \
+ struct ethosu_uapi_cancel_inference_status)
/* Maximum number of IFM/OFM file descriptors per network */
#define ETHOSU_FD_MAX 16
@@ -78,6 +80,8 @@ enum ethosu_uapi_status {
ETHOSU_UAPI_STATUS_ERROR,
ETHOSU_UAPI_STATUS_RUNNING,
ETHOSU_UAPI_STATUS_REJECTED,
+ ETHOSU_UAPI_STATUS_ABORTED,
+ ETHOSU_UAPI_STATUS_ABORTING,
};
/**
@@ -238,6 +242,14 @@ struct ethosu_uapi_result_status {
struct ethosu_uapi_pmu_counts pmu_count;
};
+/**
+ * struct ethosu_uapi_cancel_status - Status of inference cancellation.
+ * @status OK if inference cancellation was performed, ERROR otherwise.
+ */
+struct ethosu_uapi_cancel_inference_status {
+ enum ethosu_uapi_status status;
+};
+
#ifdef __cplusplus
} /* namespace EthosU */
#endif
diff --git a/utils/inference_runner/inference_runner.cpp b/utils/inference_runner/inference_runner.cpp
index 21e133c..9df67ed 100644
--- a/utils/inference_runner/inference_runner.cpp
+++ b/utils/inference_runner/inference_runner.cpp
@@ -259,22 +259,29 @@ int main(int argc, char *argv[]) {
/* make sure the wait completes ok */
try {
cout << "Wait for inference" << endl;
- inference->wait(timeout);
+ bool timedout = inference->wait(timeout);
+ if (timedout) {
+ cout << "Inference timed out, cancelling it" << endl;
+ bool aborted = inference->cancel();
+ if (!aborted || inference->status() != InferenceStatus::ABORTED) {
+ cout << "Inference cancellation failed" << endl;
+ }
+ }
} catch (std::exception &e) {
- cout << "Failed to wait for inference completion: " << e.what() << endl;
+ cout << "Failed to wait for or to cancel inference: " << e.what() << endl;
exit(1);
}
cout << "Inference status: " << inference->status() << endl;
- string ofmFilename = ofmArg + "." + to_string(ofmIndex);
- ofstream ofmStream(ofmFilename, ios::binary);
- if (!ofmStream.is_open()) {
- cerr << "Error: Failed to open '" << ofmFilename << "'" << endl;
- exit(1);
- }
-
if (inference->status() == InferenceStatus::OK) {
+ string ofmFilename = ofmArg + "." + to_string(ofmIndex);
+ ofstream ofmStream(ofmFilename, ios::binary);
+ if (!ofmStream.is_open()) {
+ cerr << "Error: Failed to open '" << ofmFilename << "'" << endl;
+ exit(1);
+ }
+
/* The inference completed and has ok status */
for (auto &ofmBuffer : inference->getOfmBuffers()) {
cout << "OFM size: " << ofmBuffer->size() << endl;
@@ -286,6 +293,8 @@ int main(int argc, char *argv[]) {
ofmStream.write(ofmBuffer->data(), ofmBuffer->size());
}
+ ofmStream.flush();
+
/* Read out PMU counters if configured */
if (std::count(enabledCounters.begin(), enabledCounters.end(), 0) <
Inference::getMaxPmuEventCounters()) {