diff options
Diffstat (limited to 'ArmnnPreparedModel.cpp')
-rw-r--r-- | ArmnnPreparedModel.cpp | 22 |
1 files changed, 14 insertions, 8 deletions
diff --git a/ArmnnPreparedModel.cpp b/ArmnnPreparedModel.cpp index 3256836e..462970aa 100644 --- a/ArmnnPreparedModel.cpp +++ b/ArmnnPreparedModel.cpp @@ -87,9 +87,8 @@ using namespace android::hardware; namespace armnn_driver { - template<typename HalVersion> -RequestThread<ArmnnPreparedModel, HalVersion> ArmnnPreparedModel<HalVersion>::m_RequestThread; +RequestThread<ArmnnPreparedModel, HalVersion, ArmnnCallback_1_0> ArmnnPreparedModel<HalVersion>::m_RequestThread; template<typename HalVersion> template <typename TensorBindingCollection> @@ -218,10 +217,17 @@ Return<ErrorStatus> ArmnnPreparedModel<HalVersion>::execute(const Request& reque } ALOGV("ArmnnPreparedModel::execute(...) before PostMsg"); + + auto cb = [callback](ErrorStatus errorStatus, std::string callingFunction) + { + NotifyCallbackAndCheck(callback, errorStatus, callingFunction); + }; + + ArmnnCallback_1_0 armnnCb; + armnnCb.callback = cb; // post the request for asynchronous execution - m_RequestThread.PostMsg(this, pMemPools, pInputTensors, pOutputTensors, callback); + m_RequestThread.PostMsg(this, pMemPools, pInputTensors, pOutputTensors, armnnCb); ALOGV("ArmnnPreparedModel::execute(...) after PostMsg"); - return ErrorStatus::NONE; // successfully queued } @@ -230,7 +236,7 @@ void ArmnnPreparedModel<HalVersion>::ExecuteGraph( std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools, std::shared_ptr<armnn::InputTensors>& pInputTensors, std::shared_ptr<armnn::OutputTensors>& pOutputTensors, - const ::android::sp<V1_0::IExecutionCallback>& callback) + ArmnnCallback_1_0 cb) { ALOGV("ArmnnPreparedModel::ExecuteGraph(...)"); @@ -243,14 +249,14 @@ void ArmnnPreparedModel<HalVersion>::ExecuteGraph( if (status != armnn::Status::Success) { ALOGW("EnqueueWorkload failed"); - NotifyCallbackAndCheck(callback, ErrorStatus::GENERAL_FAILURE, "ArmnnPreparedModel::ExecuteGraph"); + cb.callback(ErrorStatus::GENERAL_FAILURE, "ArmnnPreparedModel::ExecuteGraph"); return; } } catch (armnn::Exception& e) { ALOGW("armnn::Exception caught from EnqueueWorkload: %s", e.what()); - NotifyCallbackAndCheck(callback, ErrorStatus::GENERAL_FAILURE, "ArmnnPreparedModel::ExecuteGraph"); + cb.callback(ErrorStatus::GENERAL_FAILURE, "ArmnnPreparedModel::ExecuteGraph"); return; } @@ -264,7 +270,7 @@ void ArmnnPreparedModel<HalVersion>::ExecuteGraph( pool.update(); } - NotifyCallbackAndCheck(callback, ErrorStatus::NONE, "ExecuteGraph"); + cb.callback(ErrorStatus::NONE, "ExecuteGraph"); } template<typename HalVersion> |