aboutsummaryrefslogtreecommitdiff
path: root/kernel/ethosu_inference.c
diff options
context:
space:
mode:
authorKristofer Jonsson <kristofer.jonsson@arm.com>2020-09-10 13:26:01 +0200
committerKristofer Jonsson <kristofer.jonsson@arm.com>2020-09-17 13:23:27 +0200
commitb74492c5aee3786b886951e87f4e5ea8d6032733 (patch)
tree76ef44dfdb68d68964877b0adba21cbce2416fe5 /kernel/ethosu_inference.c
parent116a635581f292cb4882ea1a086f842904f85c3c (diff)
downloadethos-u-linux-driver-stack-b74492c5aee3786b886951e87f4e5ea8d6032733.tar.gz
Support inferences with multiple inputs and outputs
Build flatbuffers library. Update network class to extract IFM and OFM dimensions from the tflite file. Update the uapi and core apis to support up to 16 IFM and OFM buffers per inference. Change-Id: I2f2f177aa4c2d5f9f50f23eb33c44e01ec2cbe09
Diffstat (limited to 'kernel/ethosu_inference.c')
-rw-r--r--kernel/ethosu_inference.c70
1 files changed, 48 insertions, 22 deletions
diff --git a/kernel/ethosu_inference.c b/kernel/ethosu_inference.c
index 8efc22d..e9530cf 100644
--- a/kernel/ethosu_inference.c
+++ b/kernel/ethosu_inference.c
@@ -86,8 +86,10 @@ static int ethosu_inference_send(struct ethosu_inference *inf)
inf->status = ETHOSU_UAPI_STATUS_ERROR;
- ret = ethosu_mailbox_inference(&inf->edev->mailbox, inf, inf->ifm,
- inf->ofm, inf->net->buf);
+ ret = ethosu_mailbox_inference(&inf->edev->mailbox, inf,
+ inf->ifm_count, inf->ifm,
+ inf->ofm_count, inf->ofm,
+ inf->net->buf);
if (ret)
return ret;
@@ -126,8 +128,13 @@ static void ethosu_inference_kref_destroy(struct kref *kref)
inf, inf->status);
list_del(&inf->list);
- ethosu_buffer_put(inf->ifm);
- ethosu_buffer_put(inf->ofm);
+
+ while (inf->ifm_count-- > 0)
+ ethosu_buffer_put(inf->ifm[inf->ifm_count]);
+
+ while (inf->ofm_count-- > 0)
+ ethosu_buffer_put(inf->ofm[inf->ofm_count]);
+
ethosu_network_put(inf->net);
devm_kfree(inf->edev->dev, inf);
}
@@ -199,6 +206,7 @@ int ethosu_inference_create(struct ethosu_device *edev,
struct ethosu_uapi_inference_create *uapi)
{
struct ethosu_inference *inf;
+ uint32_t i;
int fd;
int ret = -ENOMEM;
@@ -213,18 +221,26 @@ int ethosu_inference_create(struct ethosu_device *edev,
kref_init(&inf->kref);
init_waitqueue_head(&inf->waitq);
- /* Get pointer to IFM buffer */
- inf->ifm = ethosu_buffer_get_from_fd(uapi->ifm_fd);
- if (IS_ERR(inf->ifm)) {
- ret = PTR_ERR(inf->ifm);
- goto free_inf;
+ /* Get pointer to IFM buffers */
+ for (i = 0; i < uapi->ifm_count; i++) {
+ inf->ifm[i] = ethosu_buffer_get_from_fd(uapi->ifm_fd[i]);
+ if (IS_ERR(inf->ifm[i])) {
+ ret = PTR_ERR(inf->ifm[i]);
+ goto put_ifm;
+ }
+
+ inf->ifm_count++;
}
/* Get pointer to OFM buffer */
- inf->ofm = ethosu_buffer_get_from_fd(uapi->ofm_fd);
- if (IS_ERR(inf->ofm)) {
- ret = PTR_ERR(inf->ofm);
- goto put_ifm;
+ for (i = 0; i < uapi->ofm_count; i++) {
+ inf->ofm[i] = ethosu_buffer_get_from_fd(uapi->ofm_fd[i]);
+ if (IS_ERR(inf->ofm[i])) {
+ ret = PTR_ERR(inf->ofm[i]);
+ goto put_ofm;
+ }
+
+ inf->ofm_count++;
}
/* Increment network reference count */
@@ -253,12 +269,15 @@ int ethosu_inference_create(struct ethosu_device *edev,
put_net:
ethosu_network_put(inf->net);
- ethosu_buffer_put(inf->ofm);
+
+put_ofm:
+ while (inf->ofm_count-- > 0)
+ ethosu_buffer_put(inf->ofm[inf->ofm_count]);
put_ifm:
- ethosu_buffer_put(inf->ifm);
+ while (inf->ifm_count-- > 0)
+ ethosu_buffer_put(inf->ifm[inf->ifm_count]);
-free_inf:
devm_kfree(edev->dev, inf);
return ret;
@@ -314,14 +333,21 @@ void ethosu_inference_rsp(struct ethosu_device *edev,
inf->pending = false;
- if (rsp->status == ETHOSU_CORE_STATUS_OK) {
+ if (rsp->status == ETHOSU_CORE_STATUS_OK &&
+ inf->ofm_count <= ETHOSU_CORE_BUFFER_MAX) {
+ uint32_t i;
+
inf->status = ETHOSU_UAPI_STATUS_OK;
- ret = ethosu_buffer_resize(inf->ofm,
- inf->ofm->size + rsp->ofm_size,
- inf->ofm->offset);
- if (ret)
- inf->status = ETHOSU_UAPI_STATUS_ERROR;
+ for (i = 0; i < inf->ofm_count; i++) {
+ struct ethosu_buffer *ofm = inf->ofm[i];
+
+ ret = ethosu_buffer_resize(
+ ofm, ofm->size + rsp->ofm_size[i],
+ ofm->offset);
+ if (ret)
+ inf->status = ETHOSU_UAPI_STATUS_ERROR;
+ }
} else {
inf->status = ETHOSU_UAPI_STATUS_ERROR;
}