diff options
Diffstat (limited to 'kernel/ethosu_inference.c')
-rw-r--r-- | kernel/ethosu_inference.c | 70 |
1 files changed, 48 insertions, 22 deletions
diff --git a/kernel/ethosu_inference.c b/kernel/ethosu_inference.c index 8efc22d..e9530cf 100644 --- a/kernel/ethosu_inference.c +++ b/kernel/ethosu_inference.c @@ -86,8 +86,10 @@ static int ethosu_inference_send(struct ethosu_inference *inf) inf->status = ETHOSU_UAPI_STATUS_ERROR; - ret = ethosu_mailbox_inference(&inf->edev->mailbox, inf, inf->ifm, - inf->ofm, inf->net->buf); + ret = ethosu_mailbox_inference(&inf->edev->mailbox, inf, + inf->ifm_count, inf->ifm, + inf->ofm_count, inf->ofm, + inf->net->buf); if (ret) return ret; @@ -126,8 +128,13 @@ static void ethosu_inference_kref_destroy(struct kref *kref) inf, inf->status); list_del(&inf->list); - ethosu_buffer_put(inf->ifm); - ethosu_buffer_put(inf->ofm); + + while (inf->ifm_count-- > 0) + ethosu_buffer_put(inf->ifm[inf->ifm_count]); + + while (inf->ofm_count-- > 0) + ethosu_buffer_put(inf->ofm[inf->ofm_count]); + ethosu_network_put(inf->net); devm_kfree(inf->edev->dev, inf); } @@ -199,6 +206,7 @@ int ethosu_inference_create(struct ethosu_device *edev, struct ethosu_uapi_inference_create *uapi) { struct ethosu_inference *inf; + uint32_t i; int fd; int ret = -ENOMEM; @@ -213,18 +221,26 @@ int ethosu_inference_create(struct ethosu_device *edev, kref_init(&inf->kref); init_waitqueue_head(&inf->waitq); - /* Get pointer to IFM buffer */ - inf->ifm = ethosu_buffer_get_from_fd(uapi->ifm_fd); - if (IS_ERR(inf->ifm)) { - ret = PTR_ERR(inf->ifm); - goto free_inf; + /* Get pointer to IFM buffers */ + for (i = 0; i < uapi->ifm_count; i++) { + inf->ifm[i] = ethosu_buffer_get_from_fd(uapi->ifm_fd[i]); + if (IS_ERR(inf->ifm[i])) { + ret = PTR_ERR(inf->ifm[i]); + goto put_ifm; + } + + inf->ifm_count++; } /* Get pointer to OFM buffer */ - inf->ofm = ethosu_buffer_get_from_fd(uapi->ofm_fd); - if (IS_ERR(inf->ofm)) { - ret = PTR_ERR(inf->ofm); - goto put_ifm; + for (i = 0; i < uapi->ofm_count; i++) { + inf->ofm[i] = ethosu_buffer_get_from_fd(uapi->ofm_fd[i]); + if (IS_ERR(inf->ofm[i])) { + ret = PTR_ERR(inf->ofm[i]); + goto put_ofm; + } + + inf->ofm_count++; } /* Increment network reference count */ @@ -253,12 +269,15 @@ int ethosu_inference_create(struct ethosu_device *edev, put_net: ethosu_network_put(inf->net); - ethosu_buffer_put(inf->ofm); + +put_ofm: + while (inf->ofm_count-- > 0) + ethosu_buffer_put(inf->ofm[inf->ofm_count]); put_ifm: - ethosu_buffer_put(inf->ifm); + while (inf->ifm_count-- > 0) + ethosu_buffer_put(inf->ifm[inf->ifm_count]); -free_inf: devm_kfree(edev->dev, inf); return ret; @@ -314,14 +333,21 @@ void ethosu_inference_rsp(struct ethosu_device *edev, inf->pending = false; - if (rsp->status == ETHOSU_CORE_STATUS_OK) { + if (rsp->status == ETHOSU_CORE_STATUS_OK && + inf->ofm_count <= ETHOSU_CORE_BUFFER_MAX) { + uint32_t i; + inf->status = ETHOSU_UAPI_STATUS_OK; - ret = ethosu_buffer_resize(inf->ofm, - inf->ofm->size + rsp->ofm_size, - inf->ofm->offset); - if (ret) - inf->status = ETHOSU_UAPI_STATUS_ERROR; + for (i = 0; i < inf->ofm_count; i++) { + struct ethosu_buffer *ofm = inf->ofm[i]; + + ret = ethosu_buffer_resize( + ofm, ofm->size + rsp->ofm_size[i], + ofm->offset); + if (ret) + inf->status = ETHOSU_UAPI_STATUS_ERROR; + } } else { inf->status = ETHOSU_UAPI_STATUS_ERROR; } |