diff options
Diffstat (limited to 'ArmnnPreparedModel.cpp')
-rw-r--r-- | ArmnnPreparedModel.cpp | 55 |
1 files changed, 29 insertions, 26 deletions
diff --git a/ArmnnPreparedModel.cpp b/ArmnnPreparedModel.cpp index 0899430c..2cd560d7 100644 --- a/ArmnnPreparedModel.cpp +++ b/ArmnnPreparedModel.cpp @@ -11,12 +11,8 @@ #include <boost/format.hpp> #include <log/log.h> #include <OperationsUtils.h> - -#if defined(ARMNN_ANDROID_P) || defined(ARMNN_ANDROID_Q) -// The headers of the ML framework have changed between Android O and Android P. -// The validation functions have been moved into their own header, ValidateHal.h. #include <ValidateHal.h> -#endif + #include <cassert> #include <cinttypes> @@ -27,7 +23,7 @@ namespace { using namespace armnn_driver; -void NotifyCallbackAndCheck(const ::android::sp<V1_0::IExecutionCallback>& callback, ErrorStatus errorStatus, +void NotifyCallbackAndCheck(const ::android::sp<V1_0::IExecutionCallback>& callback, V1_0::ErrorStatus errorStatus, std::string callingFunction) { Return<void> returned = callback->notify(errorStatus); @@ -139,21 +135,22 @@ ArmnnPreparedModel<HalVersion>::~ArmnnPreparedModel() } template<typename HalVersion> -Return<ErrorStatus> ArmnnPreparedModel<HalVersion>::execute(const Request& request, - const ::android::sp<V1_0::IExecutionCallback>& callback) +Return<V1_0::ErrorStatus> ArmnnPreparedModel<HalVersion>::execute( + const V1_0::Request& request, + const ::android::sp<V1_0::IExecutionCallback>& callback) { ALOGV("ArmnnPreparedModel::execute(): %s", GetModelSummary(m_Model).c_str()); m_RequestCount++; if (callback.get() == nullptr) { ALOGE("ArmnnPreparedModel::execute invalid callback passed"); - return ErrorStatus::INVALID_ARGUMENT; + return V1_0::ErrorStatus::INVALID_ARGUMENT; } if (!android::nn::validateRequest(request, m_Model)) { - NotifyCallbackAndCheck(callback, ErrorStatus::INVALID_ARGUMENT, "ArmnnPreparedModel::execute"); - return ErrorStatus::INVALID_ARGUMENT; + NotifyCallbackAndCheck(callback, V1_0::ErrorStatus::INVALID_ARGUMENT, "ArmnnPreparedModel::execute"); + return V1_0::ErrorStatus::INVALID_ARGUMENT; } if (!m_RequestInputsAndOutputsDumpDir.empty()) @@ -170,8 +167,8 @@ Return<ErrorStatus> ArmnnPreparedModel<HalVersion>::execute(const Request& reque auto pMemPools = std::make_shared<std::vector<android::nn::RunTimePoolInfo>>(); if (!setRunTimePoolInfosFromHidlMemories(pMemPools.get(), request.pools)) { - NotifyCallbackAndCheck(callback, ErrorStatus::GENERAL_FAILURE, "ArmnnPreparedModel::execute"); - return ErrorStatus::GENERAL_FAILURE; + NotifyCallbackAndCheck(callback, V1_0::ErrorStatus::GENERAL_FAILURE, "ArmnnPreparedModel::execute"); + return V1_0::ErrorStatus::GENERAL_FAILURE; } // add the inputs and outputs with their data @@ -187,7 +184,7 @@ Return<ErrorStatus> ArmnnPreparedModel<HalVersion>::execute(const Request& reque if (inputTensor.GetMemoryArea() == nullptr) { ALOGE("Cannot execute request. Error converting request input %u to tensor", i); - return ErrorStatus::GENERAL_FAILURE; + return V1_0::ErrorStatus::GENERAL_FAILURE; } pInputTensors->emplace_back(i, inputTensor); @@ -203,7 +200,7 @@ Return<ErrorStatus> ArmnnPreparedModel<HalVersion>::execute(const Request& reque if (outputTensor.GetMemoryArea() == nullptr) { ALOGE("Cannot execute request. Error converting request output %u to tensor", i); - return ErrorStatus::GENERAL_FAILURE; + return V1_0::ErrorStatus::GENERAL_FAILURE; } pOutputTensors->emplace_back(i, outputTensor); @@ -212,19 +209,19 @@ Return<ErrorStatus> ArmnnPreparedModel<HalVersion>::execute(const Request& reque catch (armnn::Exception& e) { ALOGW("armnn::Exception caught while preparing for EnqueueWorkload: %s", e.what()); - NotifyCallbackAndCheck(callback, ErrorStatus::GENERAL_FAILURE, "ArmnnPreparedModel::execute"); - return ErrorStatus::GENERAL_FAILURE; + NotifyCallbackAndCheck(callback, V1_0::ErrorStatus::GENERAL_FAILURE, "ArmnnPreparedModel::execute"); + return V1_0::ErrorStatus::GENERAL_FAILURE; } catch (std::exception& e) { ALOGE("std::exception caught while preparing for EnqueueWorkload: %s", e.what()); - NotifyCallbackAndCheck(callback, ErrorStatus::GENERAL_FAILURE, "ArmnnPreparedModel::execute"); - return ErrorStatus::GENERAL_FAILURE; + NotifyCallbackAndCheck(callback, V1_0::ErrorStatus::GENERAL_FAILURE, "ArmnnPreparedModel::execute"); + return V1_0::ErrorStatus::GENERAL_FAILURE; } ALOGV("ArmnnPreparedModel::execute(...) before PostMsg"); - auto cb = [callback](ErrorStatus errorStatus, std::string callingFunction) + auto cb = [callback](V1_0::ErrorStatus errorStatus, std::string callingFunction) { NotifyCallbackAndCheck(callback, errorStatus, callingFunction); }; @@ -234,7 +231,7 @@ Return<ErrorStatus> ArmnnPreparedModel<HalVersion>::execute(const Request& reque // post the request for asynchronous execution m_RequestThread.PostMsg(this, pMemPools, pInputTensors, pOutputTensors, armnnCb); ALOGV("ArmnnPreparedModel::execute(...) after PostMsg"); - return ErrorStatus::NONE; // successfully queued + return V1_0::ErrorStatus::NONE; // successfully queued } template<typename HalVersion> @@ -255,20 +252,20 @@ void ArmnnPreparedModel<HalVersion>::ExecuteGraph( if (status != armnn::Status::Success) { ALOGW("EnqueueWorkload failed"); - cb.callback(ErrorStatus::GENERAL_FAILURE, "ArmnnPreparedModel::ExecuteGraph"); + cb.callback(V1_0::ErrorStatus::GENERAL_FAILURE, "ArmnnPreparedModel::ExecuteGraph"); return; } } catch (armnn::Exception& e) { ALOGW("armnn::Exception caught from EnqueueWorkload: %s", e.what()); - cb.callback(ErrorStatus::GENERAL_FAILURE, "ArmnnPreparedModel::ExecuteGraph"); + cb.callback(V1_0::ErrorStatus::GENERAL_FAILURE, "ArmnnPreparedModel::ExecuteGraph"); return; } catch (std::exception& e) { ALOGE("std::exception caught from EnqueueWorkload: %s", e.what()); - cb.callback(ErrorStatus::GENERAL_FAILURE, "ArmnnPreparedModel::ExecuteGraph"); + cb.callback(V1_0::ErrorStatus::GENERAL_FAILURE, "ArmnnPreparedModel::ExecuteGraph"); return; } @@ -279,10 +276,16 @@ void ArmnnPreparedModel<HalVersion>::ExecuteGraph( // this is simpler and is what the CpuExecutor does. for (android::nn::RunTimePoolInfo& pool : *pMemPools) { - pool.update(); + // Type android::nn::RunTimePoolInfo has changed between Android P & Q and Android R, where + // update() has been removed and flush() added. + #if defined(ARMNN_ANDROID_R) // Use the new Android implementation. + pool.flush(); + #else + pool.update(); + #endif } - cb.callback(ErrorStatus::NONE, "ExecuteGraph"); + cb.callback(V1_0::ErrorStatus::NONE, "ExecuteGraph"); } template<typename HalVersion> |