Diffstat (limited to 'ArmnnPreparedModel.cpp')
-rw-r--r--  ArmnnPreparedModel.cpp  55
1 file changed, 29 insertions(+), 26 deletions(-)
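Most of this patch fully qualifies the NN HAL types, so Request and ErrorStatus become V1_0::Request and V1_0::ErrorStatus throughout, presumably to keep the names unambiguous as further HAL versions come into scope. A toy sketch (not taken from the driver, the second namespace is purely illustrative) of why the unqualified enum name is fragile once another namespace declares the same type:

namespace V1_0 { enum class ErrorStatus { NONE, GENERAL_FAILURE }; }
namespace V1_3 { enum class ErrorStatus { NONE, GENERAL_FAILURE }; } // illustrative only

using namespace V1_0;
using namespace V1_3;

int main()
{
    // ErrorStatus e = ErrorStatus::NONE;           // ill-formed: 'ErrorStatus' is ambiguous
    V1_0::ErrorStatus e = V1_0::ErrorStatus::NONE;  // explicit qualification is unambiguous
    return e == V1_0::ErrorStatus::NONE ? 0 : 1;
}

With both using-directives in scope, only the qualified spelling compiles, which is what the patch enforces.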
diff --git a/ArmnnPreparedModel.cpp b/ArmnnPreparedModel.cpp
index 0899430c..2cd560d7 100644
--- a/ArmnnPreparedModel.cpp
+++ b/ArmnnPreparedModel.cpp
@@ -11,12 +11,8 @@
#include <boost/format.hpp>
#include <log/log.h>
#include <OperationsUtils.h>
-
-#if defined(ARMNN_ANDROID_P) || defined(ARMNN_ANDROID_Q)
-// The headers of the ML framework have changed between Android O and Android P.
-// The validation functions have been moved into their own header, ValidateHal.h.
#include <ValidateHal.h>
-#endif
+
#include <cassert>
#include <cinttypes>
@@ -27,7 +23,7 @@ namespace
{
using namespace armnn_driver;
-void NotifyCallbackAndCheck(const ::android::sp<V1_0::IExecutionCallback>& callback, ErrorStatus errorStatus,
+void NotifyCallbackAndCheck(const ::android::sp<V1_0::IExecutionCallback>& callback, V1_0::ErrorStatus errorStatus,
std::string callingFunction)
{
Return<void> returned = callback->notify(errorStatus);
@@ -139,21 +135,22 @@ ArmnnPreparedModel<HalVersion>::~ArmnnPreparedModel()
}
template<typename HalVersion>
-Return<ErrorStatus> ArmnnPreparedModel<HalVersion>::execute(const Request& request,
- const ::android::sp<V1_0::IExecutionCallback>& callback)
+Return<V1_0::ErrorStatus> ArmnnPreparedModel<HalVersion>::execute(
+ const V1_0::Request& request,
+ const ::android::sp<V1_0::IExecutionCallback>& callback)
{
ALOGV("ArmnnPreparedModel::execute(): %s", GetModelSummary(m_Model).c_str());
m_RequestCount++;
if (callback.get() == nullptr) {
ALOGE("ArmnnPreparedModel::execute invalid callback passed");
- return ErrorStatus::INVALID_ARGUMENT;
+ return V1_0::ErrorStatus::INVALID_ARGUMENT;
}
if (!android::nn::validateRequest(request, m_Model))
{
- NotifyCallbackAndCheck(callback, ErrorStatus::INVALID_ARGUMENT, "ArmnnPreparedModel::execute");
- return ErrorStatus::INVALID_ARGUMENT;
+ NotifyCallbackAndCheck(callback, V1_0::ErrorStatus::INVALID_ARGUMENT, "ArmnnPreparedModel::execute");
+ return V1_0::ErrorStatus::INVALID_ARGUMENT;
}
if (!m_RequestInputsAndOutputsDumpDir.empty())
@@ -170,8 +167,8 @@ Return<ErrorStatus> ArmnnPreparedModel<HalVersion>::execute(const Request& reque
auto pMemPools = std::make_shared<std::vector<android::nn::RunTimePoolInfo>>();
if (!setRunTimePoolInfosFromHidlMemories(pMemPools.get(), request.pools))
{
- NotifyCallbackAndCheck(callback, ErrorStatus::GENERAL_FAILURE, "ArmnnPreparedModel::execute");
- return ErrorStatus::GENERAL_FAILURE;
+ NotifyCallbackAndCheck(callback, V1_0::ErrorStatus::GENERAL_FAILURE, "ArmnnPreparedModel::execute");
+ return V1_0::ErrorStatus::GENERAL_FAILURE;
}
// add the inputs and outputs with their data
@@ -187,7 +184,7 @@ Return<ErrorStatus> ArmnnPreparedModel<HalVersion>::execute(const Request& reque
if (inputTensor.GetMemoryArea() == nullptr)
{
ALOGE("Cannot execute request. Error converting request input %u to tensor", i);
- return ErrorStatus::GENERAL_FAILURE;
+ return V1_0::ErrorStatus::GENERAL_FAILURE;
}
pInputTensors->emplace_back(i, inputTensor);
@@ -203,7 +200,7 @@ Return<ErrorStatus> ArmnnPreparedModel<HalVersion>::execute(const Request& reque
if (outputTensor.GetMemoryArea() == nullptr)
{
ALOGE("Cannot execute request. Error converting request output %u to tensor", i);
- return ErrorStatus::GENERAL_FAILURE;
+ return V1_0::ErrorStatus::GENERAL_FAILURE;
}
pOutputTensors->emplace_back(i, outputTensor);
@@ -212,19 +209,19 @@ Return<ErrorStatus> ArmnnPreparedModel<HalVersion>::execute(const Request& reque
catch (armnn::Exception& e)
{
ALOGW("armnn::Exception caught while preparing for EnqueueWorkload: %s", e.what());
- NotifyCallbackAndCheck(callback, ErrorStatus::GENERAL_FAILURE, "ArmnnPreparedModel::execute");
- return ErrorStatus::GENERAL_FAILURE;
+ NotifyCallbackAndCheck(callback, V1_0::ErrorStatus::GENERAL_FAILURE, "ArmnnPreparedModel::execute");
+ return V1_0::ErrorStatus::GENERAL_FAILURE;
}
catch (std::exception& e)
{
ALOGE("std::exception caught while preparing for EnqueueWorkload: %s", e.what());
- NotifyCallbackAndCheck(callback, ErrorStatus::GENERAL_FAILURE, "ArmnnPreparedModel::execute");
- return ErrorStatus::GENERAL_FAILURE;
+ NotifyCallbackAndCheck(callback, V1_0::ErrorStatus::GENERAL_FAILURE, "ArmnnPreparedModel::execute");
+ return V1_0::ErrorStatus::GENERAL_FAILURE;
}
ALOGV("ArmnnPreparedModel::execute(...) before PostMsg");
- auto cb = [callback](ErrorStatus errorStatus, std::string callingFunction)
+ auto cb = [callback](V1_0::ErrorStatus errorStatus, std::string callingFunction)
{
NotifyCallbackAndCheck(callback, errorStatus, callingFunction);
};
@@ -234,7 +231,7 @@ Return<ErrorStatus> ArmnnPreparedModel<HalVersion>::execute(const Request& reque
// post the request for asynchronous execution
m_RequestThread.PostMsg(this, pMemPools, pInputTensors, pOutputTensors, armnnCb);
ALOGV("ArmnnPreparedModel::execute(...) after PostMsg");
- return ErrorStatus::NONE; // successfully queued
+ return V1_0::ErrorStatus::NONE; // successfully queued
}
template<typename HalVersion>
@@ -255,20 +252,20 @@ void ArmnnPreparedModel<HalVersion>::ExecuteGraph(
if (status != armnn::Status::Success)
{
ALOGW("EnqueueWorkload failed");
- cb.callback(ErrorStatus::GENERAL_FAILURE, "ArmnnPreparedModel::ExecuteGraph");
+ cb.callback(V1_0::ErrorStatus::GENERAL_FAILURE, "ArmnnPreparedModel::ExecuteGraph");
return;
}
}
catch (armnn::Exception& e)
{
ALOGW("armnn::Exception caught from EnqueueWorkload: %s", e.what());
- cb.callback(ErrorStatus::GENERAL_FAILURE, "ArmnnPreparedModel::ExecuteGraph");
+ cb.callback(V1_0::ErrorStatus::GENERAL_FAILURE, "ArmnnPreparedModel::ExecuteGraph");
return;
}
catch (std::exception& e)
{
ALOGE("std::exception caught from EnqueueWorkload: %s", e.what());
- cb.callback(ErrorStatus::GENERAL_FAILURE, "ArmnnPreparedModel::ExecuteGraph");
+ cb.callback(V1_0::ErrorStatus::GENERAL_FAILURE, "ArmnnPreparedModel::ExecuteGraph");
return;
}
@@ -279,10 +276,16 @@ void ArmnnPreparedModel<HalVersion>::ExecuteGraph(
// this is simpler and is what the CpuExecutor does.
for (android::nn::RunTimePoolInfo& pool : *pMemPools)
{
- pool.update();
+ // Type android::nn::RunTimePoolInfo has changed between Android P & Q and Android R, where
+ // update() has been removed and flush() added.
+ #if defined(ARMNN_ANDROID_R) // Use the new Android implementation.
+ pool.flush();
+ #else
+ pool.update();
+ #endif
}
- cb.callback(ErrorStatus::NONE, "ExecuteGraph");
+ cb.callback(V1_0::ErrorStatus::NONE, "ExecuteGraph");
}
template<typename HalVersion>
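For reference, the last hunk guards the memory-pool write-back because android::nn::RunTimePoolInfo changed between Android P/Q and Android R, with update() removed and flush() added. A minimal sketch of that pattern, assuming the NN framework's CpuExecutor.h provides RunTimePoolInfo on the include path and that the build defines ARMNN_ANDROID_R when targeting Android R:

#include <vector>
#include <CpuExecutor.h> // android::nn::RunTimePoolInfo (assumed include path)

// Write modified pool contents back to the shared memory regions.
void SyncMemPools(std::vector<android::nn::RunTimePoolInfo>& memPools)
{
    for (android::nn::RunTimePoolInfo& pool : memPools)
    {
        // RunTimePoolInfo lost update() and gained flush() in Android R,
        // so the call is selected at compile time.
#if defined(ARMNN_ANDROID_R)
        pool.flush();
#else
        pool.update();
#endif
    }
}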