aboutsummaryrefslogtreecommitdiff
path: root/kernel/ethosu_inference.c
diff options
context:
space:
mode:
Diffstat (limited to 'kernel/ethosu_inference.c')
-rw-r--r--kernel/ethosu_inference.c70
1 files changed, 48 insertions, 22 deletions
diff --git a/kernel/ethosu_inference.c b/kernel/ethosu_inference.c
index 8efc22d..e9530cf 100644
--- a/kernel/ethosu_inference.c
+++ b/kernel/ethosu_inference.c
@@ -86,8 +86,10 @@ static int ethosu_inference_send(struct ethosu_inference *inf)
inf->status = ETHOSU_UAPI_STATUS_ERROR;
- ret = ethosu_mailbox_inference(&inf->edev->mailbox, inf, inf->ifm,
- inf->ofm, inf->net->buf);
+ ret = ethosu_mailbox_inference(&inf->edev->mailbox, inf,
+ inf->ifm_count, inf->ifm,
+ inf->ofm_count, inf->ofm,
+ inf->net->buf);
if (ret)
return ret;
@@ -126,8 +128,13 @@ static void ethosu_inference_kref_destroy(struct kref *kref)
inf, inf->status);
list_del(&inf->list);
- ethosu_buffer_put(inf->ifm);
- ethosu_buffer_put(inf->ofm);
+
+ while (inf->ifm_count-- > 0)
+ ethosu_buffer_put(inf->ifm[inf->ifm_count]);
+
+ while (inf->ofm_count-- > 0)
+ ethosu_buffer_put(inf->ofm[inf->ofm_count]);
+
ethosu_network_put(inf->net);
devm_kfree(inf->edev->dev, inf);
}
@@ -199,6 +206,7 @@ int ethosu_inference_create(struct ethosu_device *edev,
struct ethosu_uapi_inference_create *uapi)
{
struct ethosu_inference *inf;
+ uint32_t i;
int fd;
int ret = -ENOMEM;
@@ -213,18 +221,26 @@ int ethosu_inference_create(struct ethosu_device *edev,
kref_init(&inf->kref);
init_waitqueue_head(&inf->waitq);
- /* Get pointer to IFM buffer */
- inf->ifm = ethosu_buffer_get_from_fd(uapi->ifm_fd);
- if (IS_ERR(inf->ifm)) {
- ret = PTR_ERR(inf->ifm);
- goto free_inf;
+ /* Get pointer to IFM buffers */
+ for (i = 0; i < uapi->ifm_count; i++) {
+ inf->ifm[i] = ethosu_buffer_get_from_fd(uapi->ifm_fd[i]);
+ if (IS_ERR(inf->ifm[i])) {
+ ret = PTR_ERR(inf->ifm[i]);
+ goto put_ifm;
+ }
+
+ inf->ifm_count++;
}
/* Get pointer to OFM buffer */
- inf->ofm = ethosu_buffer_get_from_fd(uapi->ofm_fd);
- if (IS_ERR(inf->ofm)) {
- ret = PTR_ERR(inf->ofm);
- goto put_ifm;
+ for (i = 0; i < uapi->ofm_count; i++) {
+ inf->ofm[i] = ethosu_buffer_get_from_fd(uapi->ofm_fd[i]);
+ if (IS_ERR(inf->ofm[i])) {
+ ret = PTR_ERR(inf->ofm[i]);
+ goto put_ofm;
+ }
+
+ inf->ofm_count++;
}
/* Increment network reference count */
@@ -253,12 +269,15 @@ int ethosu_inference_create(struct ethosu_device *edev,
put_net:
ethosu_network_put(inf->net);
- ethosu_buffer_put(inf->ofm);
+
+put_ofm:
+ while (inf->ofm_count-- > 0)
+ ethosu_buffer_put(inf->ofm[inf->ofm_count]);
put_ifm:
- ethosu_buffer_put(inf->ifm);
+ while (inf->ifm_count-- > 0)
+ ethosu_buffer_put(inf->ifm[inf->ifm_count]);
-free_inf:
devm_kfree(edev->dev, inf);
return ret;
@@ -314,14 +333,21 @@ void ethosu_inference_rsp(struct ethosu_device *edev,
inf->pending = false;
- if (rsp->status == ETHOSU_CORE_STATUS_OK) {
+ if (rsp->status == ETHOSU_CORE_STATUS_OK &&
+ inf->ofm_count <= ETHOSU_CORE_BUFFER_MAX) {
+ uint32_t i;
+
inf->status = ETHOSU_UAPI_STATUS_OK;
- ret = ethosu_buffer_resize(inf->ofm,
- inf->ofm->size + rsp->ofm_size,
- inf->ofm->offset);
- if (ret)
- inf->status = ETHOSU_UAPI_STATUS_ERROR;
+ for (i = 0; i < inf->ofm_count; i++) {
+ struct ethosu_buffer *ofm = inf->ofm[i];
+
+ ret = ethosu_buffer_resize(
+ ofm, ofm->size + rsp->ofm_size[i],
+ ofm->offset);
+ if (ret)
+ inf->status = ETHOSU_UAPI_STATUS_ERROR;
+ }
} else {
inf->status = ETHOSU_UAPI_STATUS_ERROR;
}