aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavide Grohmann <davide.grohmann@arm.com>2022-03-23 12:48:45 +0100
committerDavide Grohmann <davide.grohmann@arm.com>2022-05-05 11:13:04 +0200
commit7e8f508765632c42cc44fd8ad704c9d90943ab32 (patch)
tree42dcfb929accf5470d6aa61810da20356c39eb75
parent82d225899bd3d4fd07d70cac80f50c1b288dc4a3 (diff)
downloadethos-u-linux-driver-stack-7e8f508765632c42cc44fd8ad704c9d90943ab32.tar.gz
Add support for inference cancellation
Send cancel inference messages to the ethosu subsystem to abort inference execution there. Also mark inference as aborted in the linux driver stack itself, so pending inference messages are not resent when resetting the firmware. Change-Id: I244c2b119fd7995d14e3859815abf2a00c7f0583
-rw-r--r--driver_library/include/ethosu.hpp5
-rw-r--r--driver_library/src/ethosu.cpp18
-rw-r--r--kernel/Kbuild3
-rw-r--r--kernel/ethosu_cancel_inference.c198
-rw-r--r--kernel/ethosu_cancel_inference.h72
-rw-r--r--kernel/ethosu_core_interface.h20
-rw-r--r--kernel/ethosu_device.c16
-rw-r--r--kernel/ethosu_inference.c51
-rw-r--r--kernel/ethosu_inference.h2
-rw-r--r--kernel/ethosu_mailbox.c14
-rw-r--r--kernel/ethosu_mailbox.h9
-rw-r--r--kernel/uapi/ethosu.h12
-rw-r--r--utils/inference_runner/inference_runner.cpp27
13 files changed, 424 insertions, 23 deletions
diff --git a/driver_library/include/ethosu.hpp b/driver_library/include/ethosu.hpp
index 61e2bc5..da8dbbd 100644
--- a/driver_library/include/ethosu.hpp
+++ b/driver_library/include/ethosu.hpp
@@ -185,6 +185,8 @@ enum class InferenceStatus {
ERROR,
RUNNING,
REJECTED,
+ ABORTED,
+ ABORTING,
};
std::ostream &operator<<(std::ostream &out, const InferenceStatus &v);
@@ -226,9 +228,10 @@ public:
virtual ~Inference() noexcept(false);
- int wait(int64_t timeoutNanos = -1) const;
+ bool wait(int64_t timeoutNanos = -1) const;
const std::vector<uint32_t> getPmuCounters() const;
uint64_t getCycleCounter() const;
+ bool cancel() const;
InferenceStatus status() const;
int getFd() const;
const std::shared_ptr<Network> getNetwork() const;
diff --git a/driver_library/src/ethosu.cpp b/driver_library/src/ethosu.cpp
index 0da30c3..eeb95c2 100644
--- a/driver_library/src/ethosu.cpp
+++ b/driver_library/src/ethosu.cpp
@@ -340,6 +340,10 @@ ostream &operator<<(ostream &out, const InferenceStatus &status) {
return out << "running";
case InferenceStatus::REJECTED:
return out << "rejected";
+ case InferenceStatus::ABORTED:
+ return out << "aborted";
+ case InferenceStatus::ABORTING:
+ return out << "aborting";
}
throw Exception("Unknown inference status");
}
@@ -390,7 +394,7 @@ uint32_t Inference::getMaxPmuEventCounters() {
return ETHOSU_PMU_EVENT_MAX;
}
-int Inference::wait(int64_t timeoutNanos) const {
+bool Inference::wait(int64_t timeoutNanos) const {
struct pollfd pfd;
pfd.fd = fd;
pfd.events = POLLIN | POLLERR;
@@ -406,7 +410,13 @@ int Inference::wait(int64_t timeoutNanos) const {
tmo_p.tv_sec = timeoutNanos / nanosec;
tmo_p.tv_nsec = timeoutNanos % nanosec;
- return eppoll(&pfd, 1, &tmo_p, NULL);
+ return eppoll(&pfd, 1, &tmo_p, NULL) == 0;
+}
+
+bool Inference::cancel() const {
+ ethosu_uapi_cancel_inference_status uapi;
+ eioctl(fd, ETHOSU_IOCTL_INFERENCE_CANCEL, static_cast<void *>(&uapi));
+ return uapi.status == ETHOSU_UAPI_STATUS_OK;
}
InferenceStatus Inference::status() const {
@@ -423,6 +433,10 @@ InferenceStatus Inference::status() const {
return InferenceStatus::RUNNING;
case ETHOSU_UAPI_STATUS_REJECTED:
return InferenceStatus::REJECTED;
+ case ETHOSU_UAPI_STATUS_ABORTED:
+ return InferenceStatus::ABORTED;
+ case ETHOSU_UAPI_STATUS_ABORTING:
+ return InferenceStatus::ABORTING;
}
throw Exception("Unknown inference status");
diff --git a/kernel/Kbuild b/kernel/Kbuild
index 0b92c12..d54ec2c 100644
--- a/kernel/Kbuild
+++ b/kernel/Kbuild
@@ -27,4 +27,5 @@ ethosu-objs := ethosu_driver.o \
ethosu_mailbox.o \
ethosu_network.o \
ethosu_network_info.o \
- ethosu_watchdog.o
+ ethosu_watchdog.o \
+ ethosu_cancel_inference.o
diff --git a/kernel/ethosu_cancel_inference.c b/kernel/ethosu_cancel_inference.c
new file mode 100644
index 0000000..09778ee
--- /dev/null
+++ b/kernel/ethosu_cancel_inference.c
@@ -0,0 +1,198 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * This program is free software and is provided to you under the terms of the
+ * GNU General Public License version 2 as published by the Free Software
+ * Foundation, and any use by you of this program is subject to the terms
+ * of such GNU licence.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, you can access it online at
+ * http://www.gnu.org/licenses/gpl-2.0.html.
+ *
+ * SPDX-License-Identifier: GPL-2.0-only
+ */
+
+/****************************************************************************
+ * Includes
+ ****************************************************************************/
+
+#include "ethosu_cancel_inference.h"
+
+#include "ethosu_core_interface.h"
+#include "ethosu_device.h"
+#include "ethosu_inference.h"
+
+#include <linux/wait.h>
+
+/****************************************************************************
+ * Defines
+ ****************************************************************************/
+
+#define CANCEL_INFERENCE_RESP_TIMEOUT_MS 2000
+
+/****************************************************************************
+ * Functions
+ ****************************************************************************/
+
+static void ethosu_cancel_inference_destroy(struct kref *kref)
+{
+ struct ethosu_cancel_inference *cancellation =
+ container_of(kref, struct ethosu_cancel_inference, kref);
+
+ dev_info(cancellation->edev->dev,
+ "Cancel inference destroy. handle=0x%p\n", cancellation);
+ list_del(&cancellation->msg.list);
+ /* decrease the reference on the inference we are refering to */
+ ethosu_inference_put(cancellation->inf);
+ devm_kfree(cancellation->edev->dev, cancellation);
+}
+
+static int ethosu_cancel_inference_send(
+ struct ethosu_cancel_inference *cancellation)
+{
+ return ethosu_mailbox_cancel_inference(&cancellation->edev->mailbox,
+ cancellation, cancellation->inf);
+}
+
+static void ethosu_cancel_inference_fail(struct ethosu_mailbox_msg *msg)
+{
+ struct ethosu_cancel_inference *cancellation =
+ container_of(msg, typeof(*cancellation), msg);
+
+ if (completion_done(&cancellation->done))
+ return;
+
+ cancellation->errno = -EFAULT;
+ cancellation->uapi->status = ETHOSU_UAPI_STATUS_ERROR;
+ complete(&cancellation->done);
+}
+
+static int ethosu_cancel_inference_complete(struct ethosu_mailbox_msg *msg)
+{
+ struct ethosu_cancel_inference *cancellation =
+ container_of(msg, typeof(*cancellation), msg);
+
+ if (completion_done(&cancellation->done))
+ return 0;
+
+ cancellation->errno = 0;
+ cancellation->uapi->status =
+ cancellation->inf->done &&
+ cancellation->inf->status != ETHOSU_UAPI_STATUS_OK ?
+ ETHOSU_UAPI_STATUS_OK :
+ ETHOSU_UAPI_STATUS_ERROR;
+ complete(&cancellation->done);
+
+ return 0;
+}
+
+int ethosu_cancel_inference_request(struct ethosu_inference *inf,
+ struct ethosu_uapi_cancel_inference_status *uapi)
+{
+ struct ethosu_cancel_inference *cancellation;
+ int ret;
+ int timeout;
+
+ if (inf->done) {
+ uapi->status = ETHOSU_UAPI_STATUS_ERROR;
+
+ return 0;
+ }
+
+ cancellation =
+ devm_kzalloc(inf->edev->dev,
+ sizeof(struct ethosu_cancel_inference),
+ GFP_KERNEL);
+ if (!cancellation)
+ return -ENOMEM;
+
+ /* increase ref count on the inference we are refering to */
+ ethosu_inference_get(inf);
+ /* mark inference ABORTING to avoid resending the inference message */
+ inf->status = ETHOSU_CORE_STATUS_ABORTING;
+
+ cancellation->edev = inf->edev;
+ cancellation->inf = inf;
+ cancellation->uapi = uapi;
+ kref_init(&cancellation->kref);
+ init_completion(&cancellation->done);
+ cancellation->msg.fail = ethosu_cancel_inference_fail;
+
+ /* Never resend messages but always complete, since we have restart the
+ * whole firmware and marked the inference as aborted */
+ cancellation->msg.resend = ethosu_cancel_inference_complete;
+
+ /* Add cancel inference to pending list */
+ list_add(&cancellation->msg.list,
+ &cancellation->edev->mailbox.pending_list);
+
+ ret = ethosu_cancel_inference_send(cancellation);
+ if (0 != ret)
+ goto put_kref;
+
+ /* Unlock the mutex before going to block on the condition */
+ mutex_unlock(&cancellation->edev->mutex);
+ /* wait for response to arrive back */
+ timeout = wait_for_completion_timeout(&cancellation->done,
+ msecs_to_jiffies(
+ CANCEL_INFERENCE_RESP_TIMEOUT_MS));
+ /* take back the mutex before resuming to do anything */
+ ret = mutex_lock_interruptible(&cancellation->edev->mutex);
+ if (0 != ret)
+ goto put_kref;
+
+ if (0 == timeout /* timed out*/) {
+ dev_warn(inf->edev->dev,
+ "Msg: Cancel Inference response lost - timeout\n");
+ ret = -EIO;
+ goto put_kref;
+ }
+
+ if (cancellation->errno) {
+ ret = cancellation->errno;
+ goto put_kref;
+ }
+
+put_kref:
+ kref_put(&cancellation->kref, &ethosu_cancel_inference_destroy);
+
+ return ret;
+}
+
+void ethosu_cancel_inference_rsp(struct ethosu_device *edev,
+ struct ethosu_core_cancel_inference_rsp *rsp)
+{
+ struct ethosu_cancel_inference *cancellation =
+ (struct ethosu_cancel_inference *)rsp->user_arg;
+ int ret;
+
+ ret = ethosu_mailbox_find(&edev->mailbox, &cancellation->msg);
+ if (ret) {
+ dev_warn(edev->dev,
+ "Handle not found in cancel inference list. handle=0x%p\n",
+ rsp);
+
+ return;
+ }
+
+ if (completion_done(&cancellation->done))
+ return;
+
+ cancellation->errno = 0;
+ switch (rsp->status) {
+ case ETHOSU_CORE_STATUS_OK:
+ cancellation->uapi->status = ETHOSU_UAPI_STATUS_OK;
+ break;
+ case ETHOSU_CORE_STATUS_ERROR:
+ cancellation->uapi->status = ETHOSU_UAPI_STATUS_ERROR;
+ break;
+ }
+
+ complete(&cancellation->done);
+}
diff --git a/kernel/ethosu_cancel_inference.h b/kernel/ethosu_cancel_inference.h
new file mode 100644
index 0000000..94d9fe1
--- /dev/null
+++ b/kernel/ethosu_cancel_inference.h
@@ -0,0 +1,72 @@
+/*
+ * Copyright (c) 2022 ARM Limited.
+ *
+ * This program is free software and is provided to you under the terms of the
+ * GNU General Public License version 2 as published by the Free Software
+ * Foundation, and any use by you of this program is subject to the terms
+ * of such GNU licence.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, you can access it online at
+ * http://www.gnu.org/licenses/gpl-2.0.html.
+ *
+ * SPDX-License-Identifier: GPL-2.0-only
+ */
+
+#ifndef ETHOSU_CANCEL_INFERENCE_H
+#define ETHOSU_CANCEL_INFERENCE_H
+
+/****************************************************************************
+ * Includes
+ ****************************************************************************/
+
+#include "ethosu_mailbox.h"
+#include "uapi/ethosu.h"
+
+#include <linux/kref.h>
+#include <linux/types.h>
+#include <linux/completion.h>
+
+/****************************************************************************
+ * Types
+ ****************************************************************************/
+
+struct ethosu_core_cancel_inference_rsp;
+struct ethosu_device;
+struct ethosu_uapi_cancel_inference_status;
+struct ethosu_inference;
+
+struct ethosu_cancel_inference {
+ struct ethosu_device *edev;
+ struct ethosu_inference *inf;
+ struct ethosu_uapi_cancel_inference_status *uapi;
+ struct kref kref;
+ struct completion done;
+ struct ethosu_mailbox_msg msg;
+ int errno;
+};
+
+/****************************************************************************
+ * Functions
+ ****************************************************************************/
+
+/**
+ * ethosu_cancel_inference_request() - Send cancel inference request
+ *
+ * Return: 0 on success, error code otherwise.
+ */
+int ethosu_cancel_inference_request(struct ethosu_inference *inf,
+ struct ethosu_uapi_cancel_inference_status *uapi);
+
+/**
+ * ethosu_cancel_inference_rsp() - Handle cancel inference response
+ */
+void ethosu_cancel_inference_rsp(struct ethosu_device *edev,
+ struct ethosu_core_cancel_inference_rsp *rsp);
+
+#endif /* ETHOSU_CANCEL_INFERENCE_H */
diff --git a/kernel/ethosu_core_interface.h b/kernel/ethosu_core_interface.h
index 76fe35e..057e3c2 100644
--- a/kernel/ethosu_core_interface.h
+++ b/kernel/ethosu_core_interface.h
@@ -59,6 +59,8 @@ enum ethosu_core_msg_type {
ETHOSU_CORE_MSG_CAPABILITIES_RSP,
ETHOSU_CORE_MSG_NETWORK_INFO_REQ,
ETHOSU_CORE_MSG_NETWORK_INFO_RSP,
+ ETHOSU_CORE_MSG_CANCEL_INFERENCE_REQ,
+ ETHOSU_CORE_MSG_CANCEL_INFERENCE_RSP,
ETHOSU_CORE_MSG_MAX
};
@@ -98,6 +100,8 @@ enum ethosu_core_status {
ETHOSU_CORE_STATUS_ERROR,
ETHOSU_CORE_STATUS_RUNNING,
ETHOSU_CORE_STATUS_REJECTED,
+ ETHOSU_CORE_STATUS_ABORTED,
+ ETHOSU_CORE_STATUS_ABORTING,
};
/**
@@ -217,6 +221,22 @@ struct ethosu_core_msg_capabilities_rsp {
};
/**
+ * struct ethosu_core_cancel_inference_req - Message cancel inference request
+ */
+struct ethosu_core_cancel_inference_req {
+ uint64_t user_arg;
+ uint64_t inference_handle;
+};
+
+/**
+ * struct ethosu_core_cancel_inference_rsp - Message cancel inference response
+ */
+struct ethosu_core_cancel_inference_rsp {
+ uint64_t user_arg;
+ uint32_t status;
+};
+
+/**
* enum ethosu_core_msg_err_type - Error types
*/
enum ethosu_core_msg_err_type {
diff --git a/kernel/ethosu_device.c b/kernel/ethosu_device.c
index 316df6f..f66c2ac 100644
--- a/kernel/ethosu_device.c
+++ b/kernel/ethosu_device.c
@@ -27,6 +27,7 @@
#include "ethosu_buffer.h"
#include "ethosu_core_interface.h"
#include "ethosu_inference.h"
+#include "ethosu_cancel_inference.h"
#include "ethosu_network.h"
#include "ethosu_network_info.h"
#include "uapi/ethosu.h"
@@ -209,6 +210,7 @@ static int ethosu_handle_msg(struct ethosu_device *edev)
struct ethosu_core_msg_version version;
struct ethosu_core_msg_capabilities_rsp capabilities;
struct ethosu_core_network_info_rsp network_info;
+ struct ethosu_core_cancel_inference_rsp cancellation;
} data;
/* Read message */
@@ -254,6 +256,20 @@ static int ethosu_handle_msg(struct ethosu_device *edev)
ethosu_inference_rsp(edev, &data.inf);
break;
+ case ETHOSU_CORE_MSG_CANCEL_INFERENCE_RSP:
+ if (header.length != sizeof(data.cancellation)) {
+ dev_warn(edev->dev,
+ "Msg: Cancel Inference response of incorrect size. size=%u, expected=%zu\n", header.length,
+ sizeof(data.cancellation));
+ ret = -EBADMSG;
+ break;
+ }
+
+ dev_info(edev->dev,
+ "Msg: Cancel Inference response. user_arg=0x%llx, status=%u\n",
+ data.cancellation.user_arg, data.cancellation.status);
+ ethosu_cancel_inference_rsp(edev, &data.cancellation);
+ break;
case ETHOSU_CORE_MSG_VERSION_RSP:
if (header.length != sizeof(data.version)) {
dev_warn(edev->dev,
diff --git a/kernel/ethosu_inference.c b/kernel/ethosu_inference.c
index 73a8c06..0599b53 100644
--- a/kernel/ethosu_inference.c
+++ b/kernel/ethosu_inference.c
@@ -28,6 +28,7 @@
#include "ethosu_core_interface.h"
#include "ethosu_device.h"
#include "ethosu_network.h"
+#include "ethosu_cancel_inference.h"
#include <linux/anon_inodes.h>
#include <linux/file.h>
@@ -76,6 +77,12 @@ static const char *status_to_string(const enum ethosu_uapi_status status)
case ETHOSU_UAPI_STATUS_REJECTED: {
return "Rejected";
}
+ case ETHOSU_UAPI_STATUS_ABORTED: {
+ return "Aborted";
+ }
+ case ETHOSU_UAPI_STATUS_ABORTING: {
+ return "Aborting";
+ }
default: {
return "Unknown";
}
@@ -112,15 +119,19 @@ static void ethosu_inference_fail(struct ethosu_mailbox_msg *msg)
container_of(msg, typeof(*inf), msg);
int ret;
+ if (inf->done)
+ return;
+
/* Decrement reference count if inference was pending reponse */
- if (!inf->done) {
- ret = ethosu_inference_put(inf);
- if (ret)
- return;
- }
+ ret = ethosu_inference_put(inf);
+ if (ret)
+ return;
- /* Fail inference and wake up any waiting process */
- inf->status = ETHOSU_UAPI_STATUS_ERROR;
+ /* Set status accordingly to the inference state */
+ inf->status = inf->status == ETHOSU_UAPI_STATUS_ABORTING ?
+ ETHOSU_UAPI_STATUS_ABORTED :
+ ETHOSU_UAPI_STATUS_ERROR;
+ /* Mark it done and wake up the waiting process */
inf->done = true;
wake_up_interruptible(&inf->waitq);
}
@@ -135,6 +146,13 @@ static int ethosu_inference_resend(struct ethosu_mailbox_msg *msg)
if (inf->done)
return 0;
+ /* If marked as ABORTING simply fail it and return */
+ if (inf->status == ETHOSU_UAPI_STATUS_ABORTING) {
+ ethosu_inference_fail(msg);
+
+ return 0;
+ }
+
/* Decrement reference count for pending request */
ret = ethosu_inference_put(inf);
if (ret)
@@ -241,8 +259,22 @@ static long ethosu_inference_ioctl(struct file *file,
break;
}
+ case ETHOSU_IOCTL_INFERENCE_CANCEL: {
+ struct ethosu_uapi_cancel_inference_status uapi;
+
+ dev_info(inf->edev->dev, "Ioctl: Cancel Inference. Handle=%p\n",
+ inf);
+
+ ret = ethosu_cancel_inference_request(inf, &uapi);
+ if (ret)
+ break;
+
+ ret = copy_to_user(udata, &uapi, sizeof(uapi)) ? -EFAULT : 0;
+
+ break;
+ }
default: {
- dev_err(inf->edev->dev, "Invalid ioctl. cmd=%u, arg=%lu",
+ dev_err(inf->edev->dev, "Invalid ioctl. cmd=%u, arg=%lu\n",
cmd, arg);
break;
}
@@ -422,6 +454,8 @@ void ethosu_inference_rsp(struct ethosu_device *edev,
}
} else if (rsp->status == ETHOSU_CORE_STATUS_REJECTED) {
inf->status = ETHOSU_UAPI_STATUS_REJECTED;
+ } else if (rsp->status == ETHOSU_CORE_STATUS_ABORTED) {
+ inf->status = ETHOSU_UAPI_STATUS_ABORTED;
} else {
inf->status = ETHOSU_UAPI_STATUS_ERROR;
}
@@ -450,6 +484,5 @@ void ethosu_inference_rsp(struct ethosu_device *edev,
inf->done = true;
wake_up_interruptible(&inf->waitq);
-
ethosu_inference_put(inf);
}
diff --git a/kernel/ethosu_inference.h b/kernel/ethosu_inference.h
index 66d4ff9..2f188b6 100644
--- a/kernel/ethosu_inference.h
+++ b/kernel/ethosu_inference.h
@@ -30,7 +30,6 @@
#include <linux/kref.h>
#include <linux/types.h>
-#include <linux/wait.h>
/****************************************************************************
* Types
@@ -58,6 +57,7 @@ struct file;
* @pmu_event_count: PMU event count after inference
* @pmu_cycle_counter_enable: PMU cycle counter config
* @pmu_cycle_counter_count: PMU cycle counter count after inference
+ * @msg: Mailbox message
*/
struct ethosu_inference {
struct ethosu_device *edev;
diff --git a/kernel/ethosu_mailbox.c b/kernel/ethosu_mailbox.c
index 7355361..5343e56 100644
--- a/kernel/ethosu_mailbox.c
+++ b/kernel/ethosu_mailbox.c
@@ -418,6 +418,20 @@ int ethosu_mailbox_network_info_request(struct ethosu_mailbox *mbox,
&info, sizeof(info));
}
+int ethosu_mailbox_cancel_inference(struct ethosu_mailbox *mbox,
+ void *user_arg,
+ void *inference_handle)
+{
+ struct ethosu_core_cancel_inference_req req;
+
+ req.user_arg = (ptrdiff_t)user_arg;
+ req.inference_handle = (ptrdiff_t)inference_handle;
+
+ return ethosu_queue_write_msg(mbox,
+ ETHOSU_CORE_MSG_CANCEL_INFERENCE_REQ,
+ &req, sizeof(req));
+}
+
static void ethosu_mailbox_rx_work(struct work_struct *work)
{
struct ethosu_mailbox *mbox = container_of(work, typeof(*mbox), work);
diff --git a/kernel/ethosu_mailbox.h b/kernel/ethosu_mailbox.h
index 55d4436..7af0c47 100644
--- a/kernel/ethosu_mailbox.h
+++ b/kernel/ethosu_mailbox.h
@@ -203,4 +203,13 @@ int ethosu_mailbox_network_info_request(struct ethosu_mailbox *mbox,
struct ethosu_buffer *network,
uint32_t network_index);
+/**
+ * ethosu_mailbox_cancel_inference() - Send inference cancellation
+ *
+ * Return: 0 on success, else error code.
+ */
+int ethosu_mailbox_cancel_inference(struct ethosu_mailbox *mbox,
+ void *user_arg,
+ void *inference_handle);
+
#endif /* ETHOSU_MAILBOX_H */
diff --git a/kernel/uapi/ethosu.h b/kernel/uapi/ethosu.h
index 4627cb9..033afe6 100644
--- a/kernel/uapi/ethosu.h
+++ b/kernel/uapi/ethosu.h
@@ -59,6 +59,8 @@ namespace EthosU {
struct ethosu_uapi_inference_create)
#define ETHOSU_IOCTL_INFERENCE_STATUS ETHOSU_IOR(0x31, \
struct ethosu_uapi_result_status)
+#define ETHOSU_IOCTL_INFERENCE_CANCEL ETHOSU_IOR(0x32, \
+ struct ethosu_uapi_cancel_inference_status)
/* Maximum number of IFM/OFM file descriptors per network */
#define ETHOSU_FD_MAX 16
@@ -78,6 +80,8 @@ enum ethosu_uapi_status {
ETHOSU_UAPI_STATUS_ERROR,
ETHOSU_UAPI_STATUS_RUNNING,
ETHOSU_UAPI_STATUS_REJECTED,
+ ETHOSU_UAPI_STATUS_ABORTED,
+ ETHOSU_UAPI_STATUS_ABORTING,
};
/**
@@ -238,6 +242,14 @@ struct ethosu_uapi_result_status {
struct ethosu_uapi_pmu_counts pmu_count;
};
+/**
+ * struct ethosu_uapi_cancel_status - Status of inference cancellation.
+ * @status OK if inference cancellation was performed, ERROR otherwise.
+ */
+struct ethosu_uapi_cancel_inference_status {
+ enum ethosu_uapi_status status;
+};
+
#ifdef __cplusplus
} /* namespace EthosU */
#endif
diff --git a/utils/inference_runner/inference_runner.cpp b/utils/inference_runner/inference_runner.cpp
index 21e133c..9df67ed 100644
--- a/utils/inference_runner/inference_runner.cpp
+++ b/utils/inference_runner/inference_runner.cpp
@@ -259,22 +259,29 @@ int main(int argc, char *argv[]) {
/* make sure the wait completes ok */
try {
cout << "Wait for inference" << endl;
- inference->wait(timeout);
+ bool timedout = inference->wait(timeout);
+ if (timedout) {
+ cout << "Inference timed out, cancelling it" << endl;
+ bool aborted = inference->cancel();
+ if (!aborted || inference->status() != InferenceStatus::ABORTED) {
+ cout << "Inference cancellation failed" << endl;
+ }
+ }
} catch (std::exception &e) {
- cout << "Failed to wait for inference completion: " << e.what() << endl;
+ cout << "Failed to wait for or to cancel inference: " << e.what() << endl;
exit(1);
}
cout << "Inference status: " << inference->status() << endl;
- string ofmFilename = ofmArg + "." + to_string(ofmIndex);
- ofstream ofmStream(ofmFilename, ios::binary);
- if (!ofmStream.is_open()) {
- cerr << "Error: Failed to open '" << ofmFilename << "'" << endl;
- exit(1);
- }
-
if (inference->status() == InferenceStatus::OK) {
+ string ofmFilename = ofmArg + "." + to_string(ofmIndex);
+ ofstream ofmStream(ofmFilename, ios::binary);
+ if (!ofmStream.is_open()) {
+ cerr << "Error: Failed to open '" << ofmFilename << "'" << endl;
+ exit(1);
+ }
+
/* The inference completed and has ok status */
for (auto &ofmBuffer : inference->getOfmBuffers()) {
cout << "OFM size: " << ofmBuffer->size() << endl;
@@ -286,6 +293,8 @@ int main(int argc, char *argv[]) {
ofmStream.write(ofmBuffer->data(), ofmBuffer->size());
}
+ ofmStream.flush();
+
/* Read out PMU counters if configured */
if (std::count(enabledCounters.begin(), enabledCounters.end(), 0) <
Inference::getMaxPmuEventCounters()) {