aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Derek Lamberti <derek.lamberti@arm.com> 2020-03-17 13:40:18 +0000
committer Derek Lamberti <derek.lamberti@arm.com> 2020-03-18 00:04:11 +0000
commit 4de83c5a6a57d0468d9f2f854c94bc4a760b66b6 (patch)
tree 090105cdfeba4e56a46f5f06aa6c032caf1db397
parent 0b7a419c156e6f9eaf36ec166d4a3e0878b16959 (diff)
download android-nn-driver-4de83c5a6a57d0468d9f2f854c94bc4a760b66b6.tar.gz
Less code duplication in HAL 1.2
Signed-off-by: Derek Lamberti <derek.lamberti@arm.com>
Change-Id: Ic2e8964745a4323efb1e06d466c0699f17a70c55
-rw-r--r-- ArmnnDriverImpl.hpp 7
-rw-r--r-- ArmnnPreparedModel.cpp 17
-rw-r--r-- ArmnnPreparedModel.hpp 12
-rw-r--r-- ArmnnPreparedModel_1_2.cpp 649
-rw-r--r-- ArmnnPreparedModel_1_2.hpp 54
-rw-r--r-- RequestThread.cpp 40
-rw-r--r-- RequestThread.hpp 11
7 files changed, 377 insertions, 413 deletions
diff --git a/ArmnnDriverImpl.hpp b/ArmnnDriverImpl.hpp
index c5b1778..dfaafb3 100644
--- a/ArmnnDriverImpl.hpp
+++ b/ArmnnDriverImpl.hpp
@@ -23,6 +23,13 @@ namespace V1_2 = ::android::hardware::neuralnetworks::V1_2;
namespace armnn_driver
{
+template <typename Callback, typename Context>
+struct CallbackContext
+{
+ Callback callback;
+ Context ctx;
+};
+
template<typename HalPolicy>
class ArmnnDriverImpl
{
diff --git a/ArmnnPreparedModel.cpp b/ArmnnPreparedModel.cpp
index 2cd560d..d095e41 100644
--- a/ArmnnPreparedModel.cpp
+++ b/ArmnnPreparedModel.cpp
@@ -84,7 +84,8 @@ using namespace android::hardware;
namespace armnn_driver
{
template<typename HalVersion>
-RequestThread<ArmnnPreparedModel, HalVersion, ArmnnCallback_1_0> ArmnnPreparedModel<HalVersion>::m_RequestThread;
+RequestThread<ArmnnPreparedModel, HalVersion, CallbackContext_1_0>
+ ArmnnPreparedModel<HalVersion>::m_RequestThread;
template<typename HalVersion>
template <typename TensorBindingCollection>
@@ -226,7 +227,7 @@ Return<V1_0::ErrorStatus> ArmnnPreparedModel<HalVersion>::execute(
NotifyCallbackAndCheck(callback, errorStatus, callingFunction);
};
- ArmnnCallback_1_0 armnnCb;
+ CallbackContext_1_0 armnnCb;
armnnCb.callback = cb;
// post the request for asynchronous execution
m_RequestThread.PostMsg(this, pMemPools, pInputTensors, pOutputTensors, armnnCb);
@@ -237,18 +238,18 @@ Return<V1_0::ErrorStatus> ArmnnPreparedModel<HalVersion>::execute(
template<typename HalVersion>
void ArmnnPreparedModel<HalVersion>::ExecuteGraph(
std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
- std::shared_ptr<armnn::InputTensors>& pInputTensors,
- std::shared_ptr<armnn::OutputTensors>& pOutputTensors,
- ArmnnCallback_1_0 cb)
+ armnn::InputTensors& inputTensors,
+ armnn::OutputTensors& outputTensors,
+ CallbackContext_1_0 cb)
{
ALOGV("ArmnnPreparedModel::ExecuteGraph(...)");
- DumpTensorsIfRequired("Input", *pInputTensors);
+ DumpTensorsIfRequired("Input", inputTensors);
// run it
try
{
- armnn::Status status = m_Runtime->EnqueueWorkload(m_NetworkId, *pInputTensors, *pOutputTensors);
+ armnn::Status status = m_Runtime->EnqueueWorkload(m_NetworkId, inputTensors, outputTensors);
if (status != armnn::Status::Success)
{
ALOGW("EnqueueWorkload failed");
@@ -269,7 +270,7 @@ void ArmnnPreparedModel<HalVersion>::ExecuteGraph(
return;
}
- DumpTensorsIfRequired("Output", *pOutputTensors);
+ DumpTensorsIfRequired("Output", outputTensors);
// Commit output buffers.
// Note that we update *all* pools, even if they aren't actually used as outputs -
diff --git a/ArmnnPreparedModel.hpp b/ArmnnPreparedModel.hpp
index 270a933..89f6226 100644
--- a/ArmnnPreparedModel.hpp
+++ b/ArmnnPreparedModel.hpp
@@ -24,6 +24,10 @@ struct ArmnnCallback_1_0
armnnExecuteCallback_1_0 callback;
};
+struct ExecutionContext_1_0 {};
+
+using CallbackContext_1_0 = CallbackContext<armnnExecuteCallback_1_0, ExecutionContext_1_0>;
+
template <typename HalVersion>
class ArmnnPreparedModel : public V1_0::IPreparedModel
{
@@ -43,9 +47,9 @@ public:
/// execute the graph prepared from the request
void ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
- std::shared_ptr<armnn::InputTensors>& pInputTensors,
- std::shared_ptr<armnn::OutputTensors>& pOutputTensors,
- ArmnnCallback_1_0 callback);
+ armnn::InputTensors& inputTensors,
+ armnn::OutputTensors& outputTensors,
+ CallbackContext_1_0 callback);
/// Executes this model with dummy inputs (e.g. all zeroes).
/// \return false on failure, otherwise true
@@ -60,7 +64,7 @@ private:
HalModel m_Model;
// There must be a single RequestThread for all ArmnnPreparedModel objects to ensure serial execution of workloads
// It is specific to this class, so it is declared as static here
- static RequestThread<ArmnnPreparedModel, HalVersion, ArmnnCallback_1_0> m_RequestThread;
+ static RequestThread<ArmnnPreparedModel, HalVersion, CallbackContext_1_0> m_RequestThread;
uint32_t m_RequestCount;
const std::string& m_RequestInputsAndOutputsDumpDir;
const bool m_GpuProfilingEnabled;
diff --git a/ArmnnPreparedModel_1_2.cpp b/ArmnnPreparedModel_1_2.cpp
index 9b79044..5031c5f 100644
--- a/ArmnnPreparedModel_1_2.cpp
+++ b/ArmnnPreparedModel_1_2.cpp
@@ -120,7 +120,7 @@ namespace armnn_driver
{
template<typename HalVersion>
-RequestThread<ArmnnPreparedModel_1_2, HalVersion, ArmnnCallback_1_2>
+RequestThread<ArmnnPreparedModel_1_2, HalVersion, CallbackContext_1_2>
ArmnnPreparedModel_1_2<HalVersion>::m_RequestThread;
template<typename HalVersion>
@@ -215,339 +215,256 @@ Return <V1_0::ErrorStatus> ArmnnPreparedModel_1_2<HalVersion>::execute_1_2(
return Execute(request, measureTiming, cb);
}
-template<typename HalVersion>
-Return<void> ArmnnPreparedModel_1_2<HalVersion>::executeSynchronously(const V1_0::Request& request,
- MeasureTiming measureTiming,
- executeSynchronously_cb cb)
+OutputShape ComputeShape(const armnn::TensorInfo& info)
{
- ALOGV("ArmnnPreparedModel_1_2::executeSynchronously(): %s", GetModelSummary(m_Model).c_str());
- m_RequestCount++;
-
- if (cb == nullptr)
- {
- ALOGE("ArmnnPreparedModel_1_2::executeSynchronously invalid callback passed");
- return Void();
- }
+ OutputShape shape;
- TimePoint driverStart, driverEnd, deviceStart, deviceEnd;
+ hidl_vec<uint32_t> dimensions;
- if (measureTiming == MeasureTiming::YES)
- {
- driverStart = Now();
- }
+ armnn::TensorShape tensorShape = info.GetShape();
+ const unsigned int numDims = tensorShape.GetNumDimensions();
+ dimensions.resize(numDims);
- if (!android::nn::validateRequest(request, m_Model))
+ for (unsigned int outputIdx = 0u; outputIdx < numDims; ++outputIdx)
{
- ALOGE("ArmnnPreparedModel_1_2::executeSynchronously invalid request model");
- cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {}, g_NoTiming);
- return Void();
+ dimensions[outputIdx] = tensorShape[outputIdx];
}
- // allocate the tensors on the heap, as they are passed to the request thread
- auto pInputTensors = std::make_shared<armnn::InputTensors>();
- auto pOutputTensors = std::make_shared<armnn::OutputTensors>();
-
- // map the memory pool into shared pointers
- // use a shared memory pools vector on the heap, as it is passed to the request thread
- auto pMemPools = std::make_shared<std::vector<android::nn::RunTimePoolInfo>>();
+ shape.dimensions = dimensions;
+ shape.isSufficient = true;
- if (!setRunTimePoolInfosFromHidlMemories(pMemPools.get(), request.pools))
- {
- cb(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming);
- return Void();
- }
- std::vector<OutputShape> outputShapes(request.outputs.size());
+ return shape;
+}
- try
+template<typename HalVersion>
+Return<V1_0::ErrorStatus> ArmnnPreparedModel_1_2<HalVersion>::PrepareMemoryForInputs(
+ armnn::InputTensors& inputs,
+ const V1_0::Request& request,
+ const std::vector<android::nn::RunTimePoolInfo>& memPools)
+{
+ inputs.reserve(request.inputs.size());
+ for (unsigned int i = 0; i < request.inputs.size(); i++)
{
- pInputTensors->reserve(request.inputs.size());
- for (unsigned int i = 0; i < request.inputs.size(); i++)
- {
- const auto& inputArg = request.inputs[i];
-
- const armnn::TensorInfo inputTensorInfo = m_Runtime->GetInputTensorInfo(m_NetworkId, i);
- const armnn::Tensor inputTensor = GetTensorForRequestArgument(inputArg, inputTensorInfo, *pMemPools);
+ const auto& inputArg = request.inputs[i];
- if (inputTensor.GetMemoryArea() == nullptr)
- {
- ALOGE("Cannot execute request. Error converting request input %u to tensor", i);
- cb(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming);
- return Void();
- }
+ const armnn::TensorInfo inputTensorInfo = m_Runtime->GetInputTensorInfo(m_NetworkId, i);
+ const armnn::Tensor inputTensor = GetTensorForRequestArgument(inputArg, inputTensorInfo, memPools);
- pInputTensors->emplace_back(i, inputTensor);
+ if (inputTensor.GetMemoryArea() == nullptr)
+ {
+ ALOGE("Cannot execute request. Error converting request input %u to tensor", i);
+ return V1_0::ErrorStatus::GENERAL_FAILURE;
}
- pOutputTensors->reserve(request.outputs.size());
- for (unsigned int i = 0; i < request.outputs.size(); i++)
- {
- const auto& outputArg = request.outputs[i];
+ inputs.emplace_back(i, inputTensor);
+ }
- const armnn::TensorInfo outputTensorInfo = m_Runtime->GetOutputTensorInfo(m_NetworkId, i);
- const armnn::Tensor outputTensor = GetTensorForRequestArgument(outputArg, outputTensorInfo, *pMemPools);
+ return V1_0::ErrorStatus::NONE;
+}
- if (outputTensor.GetMemoryArea() == nullptr)
- {
- ALOGE("Cannot execute request. Error converting request output %u to tensor", i);
- cb(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming);
- return Void();
- }
- const size_t outputSize = outputTensorInfo.GetNumBytes();
- const size_t bufferSize = pMemPools->at(outputArg.location.poolIndex).getHidlMemory().size();
+template<typename HalVersion>
+Return<V1_0::ErrorStatus> ArmnnPreparedModel_1_2<HalVersion>::PrepareMemoryForOutputs(
+ armnn::OutputTensors& outputs,
+ std::vector<OutputShape> &outputShapes,
+ const V1_0::Request& request,
+ const std::vector<android::nn::RunTimePoolInfo>& memPools)
+{
+ outputs.reserve(request.outputs.size());
+ for (unsigned int i = 0; i < request.outputs.size(); i++)
+ {
+ const auto& outputArg = request.outputs[i];
- hidl_vec<uint32_t> dimensions;
+ const armnn::TensorInfo outputTensorInfo = m_Runtime->GetOutputTensorInfo(m_NetworkId, i);
+ const armnn::Tensor outputTensor = GetTensorForRequestArgument(outputArg, outputTensorInfo, memPools);
+ if (outputTensor.GetMemoryArea() == nullptr)
+ {
+ ALOGE("Cannot execute request. Error converting request output %u to tensor", i);
+ return V1_0::ErrorStatus::GENERAL_FAILURE;
+ }
- armnn::TensorShape tensorShape = outputTensorInfo.GetShape();
- const unsigned int numDims = tensorShape.GetNumDimensions();
- dimensions.resize(numDims);
+ const size_t outputSize = outputTensorInfo.GetNumBytes();
+ const size_t bufferSize = memPools.at(outputArg.location.poolIndex).getHidlMemory().size();
+ if (bufferSize < outputSize)
+ {
+ ALOGW("ArmnnPreparedModel_1_2::Execute failed");
+ return V1_0::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE;
+ }
- for (unsigned int outputIdx = 0u; outputIdx < numDims; ++outputIdx)
- {
- dimensions[outputIdx] = tensorShape[outputIdx];
- }
- outputShapes[i].dimensions = dimensions;
- outputShapes[i].isSufficient = bufferSize >= outputSize;
+ outputs.emplace_back(i, outputTensor);
+ outputShapes[i] = ComputeShape(outputTensorInfo);
+ }
- if (bufferSize < outputSize)
- {
- ALOGW("ArmnnPreparedModel_1_2::Execute failed");
- cb(V1_0::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE, outputShapes, g_NoTiming);
- return Void();
- }
+ return V1_0::ErrorStatus::NONE;
+}
- pOutputTensors->emplace_back(i, outputTensor);
- }
- }
- catch (armnn::Exception& e)
- {
- ALOGW("armnn::Exception caught while preparing for EnqueueWorkload: %s", e.what());
- cb(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming);
- return Void();
- }
- catch (std::exception& e)
+template<typename HalVersion>
+Return<V1_0::ErrorStatus> ArmnnPreparedModel_1_2<HalVersion>::PrepareMemoryForIO(
+ armnn::InputTensors& inputs,
+ armnn::OutputTensors& outputs,
+ std::vector<android::nn::RunTimePoolInfo>& memPools,
+ const V1_0::Request& request,
+ CallbackAsync_1_2 callback)
+{
+ if (!setRunTimePoolInfosFromHidlMemories(&memPools, request.pools))
{
- ALOGE("std::exception caught while preparing for EnqueueWorkload: %s", e.what());
- cb(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming);
- return Void();
+ callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::execute");
+ return V1_0::ErrorStatus::GENERAL_FAILURE;
}
- ALOGV("ArmnnPreparedModel_1_2::executeSynchronously() before Execution");
-
- DumpTensorsIfRequired("Input", *pInputTensors);
- // run it
+ // add the inputs and outputs with their data
try
{
- if (measureTiming == MeasureTiming::YES)
+ if (PrepareMemoryForInputs(inputs, request, memPools) != V1_0::ErrorStatus::NONE)
{
- deviceStart = Now();
+ callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::execute");
+ return V1_0::ErrorStatus::GENERAL_FAILURE;
}
- armnn::Status status = m_Runtime->EnqueueWorkload(m_NetworkId, *pInputTensors, *pOutputTensors);
-
- if (measureTiming == MeasureTiming::YES)
- {
- deviceEnd = Now();
- }
+ std::vector<OutputShape> outputShapes(request.outputs.size());
- if (status != armnn::Status::Success)
+ auto errorStatus = PrepareMemoryForOutputs(outputs, outputShapes, request, memPools);
+ if (errorStatus != V1_0::ErrorStatus::NONE)
{
- ALOGW("EnqueueWorkload failed");
- cb(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming);
- return Void();
+ callback(errorStatus,
+ outputShapes,
+ g_NoTiming,
+ "ArmnnPreparedModel_1_2::Execute");
+ return errorStatus;
}
}
catch (armnn::Exception& e)
{
- ALOGW("armnn::Exception caught from EnqueueWorkload: %s", e.what());
- cb(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming);
- return Void();
+ ALOGW("armnn::Exception caught while preparing for EnqueueWorkload: %s", e.what());
+ callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::execute");
+ return V1_0::ErrorStatus::GENERAL_FAILURE;
}
catch (std::exception& e)
{
- ALOGE("std::exception caught from EnqueueWorkload: %s", e.what());
- cb(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming);
- return Void();
+ ALOGE("std::exception caught while preparing for EnqueueWorkload: %s", e.what());
+ callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::execute");
+ return V1_0::ErrorStatus::GENERAL_FAILURE;
}
- DumpTensorsIfRequired("Output", *pOutputTensors);
+ return V1_0::ErrorStatus::NONE;
+}
+void CommitPools(std::vector<::android::nn::RunTimePoolInfo>& memPools)
+{
+ if (memPools.empty())
+ {
+ return;
+ }
// Commit output buffers.
// Note that we update *all* pools, even if they aren't actually used as outputs -
// this is simpler and is what the CpuExecutor does.
- for (android::nn::RunTimePoolInfo& pool : *pMemPools)
+ for (auto& pool : memPools)
{
// Type android::nn::RunTimePoolInfo has changed between Android P & Q and Android R, where
// update() has been removed and flush() added.
- #if defined(ARMNN_ANDROID_R) // Use the new Android implementation.
- pool.flush();
- #else
- pool.update();
- #endif
- }
-
- ALOGV("ArmnnPreparedModel_1_2::executeSynchronously() after Execution");
-
- if (measureTiming == MeasureTiming::YES)
- {
- driverEnd = Now();
- Timing timing;
- timing.timeOnDevice = MicrosecondsDuration(deviceEnd, deviceStart);
- timing.timeInDriver = MicrosecondsDuration(driverEnd, driverStart);
- ALOGV("ArmnnPreparedModel_1_2::executeSynchronously timing Device = %lu Driver = %lu", timing.timeOnDevice,
- timing.timeInDriver);
- cb(V1_0::ErrorStatus::NONE, outputShapes, timing);
- }
- else
- {
- cb(V1_0::ErrorStatus::NONE, outputShapes, g_NoTiming);
+#if defined(ARMNN_ANDROID_R) // Use the new Android implementation.
+ pool.flush();
+#else
+ pool.update();
+#endif
}
- return Void();
}
-/// This class is strongly inspired by the default implementation in Android named DefaultBurstExecutorWithCache.
-/// The original code is licensed under Apache-2.0 and can be found at the following link:
-/// https://android.googlesource.com/platform/frameworks/
-/// ml/+/refs/tags/android-10.0.0_r20/nn/common/ExecutionBurstServer.cpp
-class ArmnnBurstExecutorWithCache : public ExecutionBurstServer::IBurstExecutorWithCache {
-public:
- ArmnnBurstExecutorWithCache(V1_2::IPreparedModel* preparedModel)
- : m_PreparedModel(preparedModel)
- {}
+template<typename HalVersion>
+Return<void> ArmnnPreparedModel_1_2<HalVersion>::executeSynchronously(const V1_0::Request& request,
+ MeasureTiming measureTiming,
+ executeSynchronously_cb cb)
+{
+ ALOGV("ArmnnPreparedModel_1_2::executeSynchronously(): %s", GetModelSummary(m_Model).c_str());
+ m_RequestCount++;
- bool isCacheEntryPresent(int32_t slot) const override
+ if (cb == nullptr)
{
- const auto it = m_MemoryCache.find(slot);
- return (it != m_MemoryCache.end()) && it->second.valid();
+ ALOGE("ArmnnPreparedModel_1_2::executeSynchronously invalid callback passed");
+ return Void();
}
- void addCacheEntry(const hidl_memory& memory, int32_t slot) override
- {
- m_MemoryCache[slot] = memory;
- }
+ TimePoint driverStart;
- void removeCacheEntry(int32_t slot) override
+ if (measureTiming == MeasureTiming::YES)
{
- m_MemoryCache.erase(slot);
+ driverStart = Now();
}
- std::tuple<V1_0::ErrorStatus, hidl_vec<OutputShape>, Timing> execute(
- const V1_0::Request& request, const std::vector<int32_t>& slots,
- MeasureTiming measure) override
+ if (!android::nn::validateRequest(request, m_Model))
{
- ALOGV("ArmnnPreparedModel_1_2::BurstExecutorWithCache::execute");
- hidl_vec<hidl_memory> pools(slots.size());
-
- std::transform(slots.begin(), slots.end(), pools.begin(), [this](int32_t slot)
- {
- return m_MemoryCache[slot];
- });
-
- V1_0::Request fullRequest = request;
- fullRequest.pools = std::move(pools);
+ ALOGE("ArmnnPreparedModel_1_2::executeSynchronously invalid request model");
+ cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {}, g_NoTiming);
+ return Void();
+ }
- // Setup Callback
- V1_0::ErrorStatus returnedStatus = V1_0::ErrorStatus::GENERAL_FAILURE;
- hidl_vec<OutputShape> returnedOutputShapes;
- Timing returnedTiming;
- auto cb = [&returnedStatus, &returnedOutputShapes, &returnedTiming](V1_0::ErrorStatus status,
- const hidl_vec<OutputShape>& outputShapes,
- const Timing& timing)
+ auto cbWrapper = [cb](V1_0::ErrorStatus errorStatus,
+ std::vector<OutputShape> outputShapes,
+ const Timing& timing,
+ std::string)
{
- returnedStatus = status;
- returnedOutputShapes = outputShapes;
- returnedTiming = timing;
+ cb(errorStatus, outputShapes, timing);
};
- // Execute
- ALOGV("ArmnnPreparedModel_1_2::BurstExecutorWithCache executing");
- const Return<void> ret = m_PreparedModel->executeSynchronously(fullRequest, measure, cb);
+ // map the memory pool into shared pointers
+ // use a shared memory pools vector on the heap, as it is passed to the request thread
+ auto memPools = std::make_shared<std::vector<android::nn::RunTimePoolInfo>>();
- if (!ret.isOk() || returnedStatus != V1_0::ErrorStatus::NONE)
- {
- ALOGE("ArmnnPreparedModel_1_2::BurstExecutorWithCache::error executing");
- }
- return std::make_tuple(returnedStatus, std::move(returnedOutputShapes), returnedTiming);
- }
+ // allocate the tensors on the heap, as they are passed to the request thread
+ auto inputs = std::make_shared<armnn::InputTensors>();
+ auto outputs = std::make_shared<armnn::OutputTensors>();
-private:
- V1_2::IPreparedModel* const m_PreparedModel;
- std::map<int, hidl_memory> m_MemoryCache;
-};
+ auto prepareStatus = PrepareMemoryForIO(*inputs, *outputs, *memPools, request, cbWrapper);
+ if (prepareStatus != V1_0::ErrorStatus::NONE)
+ {
+ return Void();
+ }
+ ALOGV("ArmnnPreparedModel_1_2::executeSynchronously() before Execution");
-template<typename HalVersion>
-Return<void> ArmnnPreparedModel_1_2<HalVersion>::configureExecutionBurst(
- const sp<V1_2::IBurstCallback>& callback,
- const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
- const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
- V1_2::IPreparedModel::configureExecutionBurst_cb cb)
-{
- ALOGV("ArmnnPreparedModel_1_2::configureExecutionBurst");
- const std::shared_ptr<ArmnnBurstExecutorWithCache> executorWithCache =
- std::make_shared<ArmnnBurstExecutorWithCache>(this);
- const sp<V1_2::IBurstContext> burst = ExecutionBurstServer::create(callback,
- requestChannel,
- resultChannel,
- executorWithCache);
+ CallbackContext_1_2 cbCtx;
+ cbCtx.callback = cbWrapper;
+ cbCtx.ctx.measureTimings = measureTiming;
+ cbCtx.ctx.driverStart = driverStart;
+ ExecuteGraph(memPools, *inputs, *outputs, cbCtx);
- if (burst == nullptr)
- {
- cb(V1_0::ErrorStatus::GENERAL_FAILURE, {});
- }
- else
- {
- cb(V1_0::ErrorStatus::NONE, burst);
- }
return Void();
}
template<typename HalVersion>
-void ArmnnPreparedModel_1_2<HalVersion>::ExecuteGraph(
+template<typename CallbackContext>
+bool ArmnnPreparedModel_1_2<HalVersion>::ExecuteGraph(
std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
- std::shared_ptr<armnn::InputTensors>& pInputTensors,
- std::shared_ptr<armnn::OutputTensors>& pOutputTensors,
- ArmnnCallback_1_2 cb)
+ armnn::InputTensors& inputTensors,
+ armnn::OutputTensors& outputTensors,
+ CallbackContext cb)
{
ALOGV("ArmnnPreparedModel_1_2::ExecuteGraph(...)");
TimePoint driverEnd, deviceStart, deviceEnd;
- DumpTensorsIfRequired("Input", *pInputTensors);
+ DumpTensorsIfRequired("Input", inputTensors);
- std::vector<std::pair<int, armnn::Tensor> > outputTensors = *pOutputTensors.get();
std::vector<OutputShape> outputShapes(outputTensors.size());
-
for (unsigned int i = 0; i < outputTensors.size(); i++)
{
std::pair<int, armnn::Tensor> outputTensorPair = outputTensors[i];
const armnn::Tensor outputTensor = outputTensorPair.second;
const armnn::TensorInfo outputTensorInfo = outputTensor.GetInfo();
- hidl_vec<uint32_t> dimensions;
-
- armnn::TensorShape tensorShape = outputTensorInfo.GetShape();
- const unsigned int numDims = tensorShape.GetNumDimensions();
- dimensions.resize(numDims);
-
- for (unsigned int outputIdx = 0u; outputIdx < numDims; ++outputIdx)
- {
- dimensions[outputIdx] = tensorShape[outputIdx];
- }
- outputShapes[i].dimensions = dimensions;
- outputShapes[i].isSufficient = true;
+ outputShapes[i] = ComputeShape(outputTensorInfo);
}
// run it
try
{
- if (cb.measureTiming == MeasureTiming::YES)
+ if (cb.ctx.measureTimings == MeasureTiming::YES)
{
deviceStart = Now();
}
- armnn::Status status = m_Runtime->EnqueueWorkload(m_NetworkId, *pInputTensors, *pOutputTensors);
+ armnn::Status status = m_Runtime->EnqueueWorkload(m_NetworkId, inputTensors, outputTensors);
- if (cb.measureTiming == MeasureTiming::YES)
+ if (cb.ctx.measureTimings == MeasureTiming::YES)
{
deviceEnd = Now();
}
@@ -556,48 +473,40 @@ void ArmnnPreparedModel_1_2<HalVersion>::ExecuteGraph(
ALOGW("EnqueueWorkload failed");
cb.callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming,
"ArmnnPreparedModel_1_2::ExecuteGraph");
- return;
+ return false;
}
}
catch (armnn::Exception& e)
{
ALOGW("armnn:Exception caught from EnqueueWorkload: %s", e.what());
cb.callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::ExecuteGraph");
- return;
+ return false;
}
catch (std::exception& e)
{
ALOGE("std::exception caught from EnqueueWorkload: %s", e.what());
cb.callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::ExecuteGraph");
- return;
+ return false;
}
- DumpTensorsIfRequired("Output", *pOutputTensors);
+ CommitPools(*pMemPools);
- // Commit output buffers.
- // Note that we update *all* pools, even if they aren't actually used as outputs -
- // this is simpler and is what the CpuExecutor does.
- for (android::nn::RunTimePoolInfo& pool : *pMemPools)
- {
- // Type android::nn::RunTimePoolInfo has changed between Android P & Q and Android R, where
- // update() has been removed and flush() added.
- #if defined(ARMNN_ANDROID_R) // Use the new Android implementation.
- pool.flush();
- #else
- pool.update();
- #endif
- }
+ DumpTensorsIfRequired("Output", outputTensors);
- if (cb.measureTiming == MeasureTiming::YES)
+ if (cb.ctx.measureTimings == MeasureTiming::YES)
{
driverEnd = Now();
Timing timing;
timing.timeOnDevice = MicrosecondsDuration(deviceEnd, deviceStart);
- timing.timeInDriver = MicrosecondsDuration(driverEnd, cb.driverStart);
- cb.callback(V1_0::ErrorStatus::NONE, outputShapes, timing, "ExecuteGraph");
+ timing.timeInDriver = MicrosecondsDuration(driverEnd, cb.ctx.driverStart);
+ ALOGV("ArmnnPreparedModel_1_2::execute timing - Device = %lu Driver = %lu", timing.timeOnDevice,
+ timing.timeInDriver);
+ cb.callback(V1_0::ErrorStatus::NONE, outputShapes, timing, "ArmnnPreparedModel_1_2::ExecuteGraph");
} else {
- cb.callback(V1_0::ErrorStatus::NONE, outputShapes, g_NoTiming, "ExecuteGraph");
+ cb.callback(V1_0::ErrorStatus::NONE, outputShapes, g_NoTiming, "ArmnnPreparedModel_1_2::ExecuteGraph");
}
+
+ return true;
}
template<typename HalVersion>
@@ -624,38 +533,27 @@ bool ArmnnPreparedModel_1_2<HalVersion>::ExecuteWithDummyInputs()
outputTensors.emplace_back(i, outputTensor);
}
- try
- {
- armnn::Status status = m_Runtime->EnqueueWorkload(m_NetworkId, inputTensors, outputTensors);
- if (status != armnn::Status::Success)
- {
- ALOGW("ExecuteWithDummyInputs: EnqueueWorkload failed");
- return false;
- }
- }
- catch (armnn::Exception& e)
- {
- ALOGW("ExecuteWithDummyInputs: armnn::Exception caught from EnqueueWorkload: %s", e.what());
- return false;
- }
- catch (std::exception& e)
- {
- ALOGE("ExecuteWithDummyInputs: std::exception caught from EnqueueWorkload: %s", e.what());
- return false;
- }
- return true;
+ auto nullCallback = [](V1_0::ErrorStatus, std::vector<OutputShape>, const Timing&, std::string) {};
+ CallbackContext_1_2 callbackContext;
+ callbackContext.callback = nullCallback;
+ callbackContext.ctx.measureTimings = MeasureTiming::NO;
+ auto memPools = std::make_shared<std::vector<::android::nn::RunTimePoolInfo>>();
+ return ExecuteGraph(memPools,
+ inputTensors,
+ outputTensors,
+ callbackContext);
}
template<typename HalVersion>
Return <V1_0::ErrorStatus> ArmnnPreparedModel_1_2<HalVersion>::Execute(const V1_0::Request& request,
MeasureTiming measureTiming,
- armnnExecuteCallback_1_2 callback)
+ CallbackAsync_1_2 callback)
{
- TimePoint driverStart;
-
+ ExecutionContext_1_2 ctx;
if (measureTiming == MeasureTiming::YES)
{
- driverStart = Now();
+ ctx.measureTimings = measureTiming;
+ ctx.driverStart = Now();
}
ALOGV("ArmnnPreparedModel_1_2::execute(): %s", GetModelSummary(m_Model).c_str());
@@ -672,111 +570,142 @@ Return <V1_0::ErrorStatus> ArmnnPreparedModel_1_2<HalVersion>::Execute(const V1_
ALOGD("Dumping inputs and outputs for request %" PRIuPTR, reinterpret_cast<std::uintptr_t>(&callback));
}
- // allocate the tensors on the heap, as they are passed to the request thread
- auto pInputTensors = std::make_shared<armnn::InputTensors>();
- auto pOutputTensors = std::make_shared<armnn::OutputTensors>();
-
// map the memory pool into shared pointers
// use a shared memory pools vector on the heap, as it is passed to the request thread
- auto pMemPools = std::make_shared<std::vector<android::nn::RunTimePoolInfo>>();
+ auto memPools = std::make_shared<std::vector<android::nn::RunTimePoolInfo>>();
- if (!setRunTimePoolInfosFromHidlMemories(pMemPools.get(), request.pools))
+ // allocate the tensors on the heap, as they are passed to the request thread
+ auto inputTensors = std::make_shared<armnn::InputTensors>();
+ auto outputTensors = std::make_shared<armnn::OutputTensors>();
+
+ auto prepareStatus = PrepareMemoryForIO(*inputTensors, *outputTensors, *memPools, request, callback);
+ switch(prepareStatus)
{
- callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::execute");
- return V1_0::ErrorStatus::GENERAL_FAILURE;
+ case V1_0::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE:
+ return V1_0::ErrorStatus::NONE;
+ case V1_0::ErrorStatus::GENERAL_FAILURE:
+ return V1_0::ErrorStatus::GENERAL_FAILURE;
+ default:
+ {}
}
- // add the inputs and outputs with their data
- try
- {
- pInputTensors->reserve(request.inputs.size());
- for (unsigned int i = 0; i < request.inputs.size(); i++)
- {
- const auto& inputArg = request.inputs[i];
+ ALOGV("ArmnnPreparedModel_1_2::execute(...) before PostMsg");
- const armnn::TensorInfo inputTensorInfo = m_Runtime->GetInputTensorInfo(m_NetworkId, i);
- const armnn::Tensor inputTensor = GetTensorForRequestArgument(inputArg, inputTensorInfo, *pMemPools);
+ // post the request for asynchronous execution
+ CallbackContext_1_2 cb;
+ cb.callback = callback;
+ cb.ctx = ctx;
+ m_RequestThread.PostMsg(this, memPools, inputTensors, outputTensors, cb);
+ ALOGV("ArmnnPreparedModel_1_2::execute(...) after PostMsg");
+ return V1_0::ErrorStatus::NONE;
+}
- if (inputTensor.GetMemoryArea() == nullptr)
- {
- ALOGE("Cannot execute request. Error converting request input %u to tensor", i);
- callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::execute");
- return V1_0::ErrorStatus::GENERAL_FAILURE;
- }
- pInputTensors->emplace_back(i, inputTensor);
- }
+/// This class is strongly inspired by the default implementation in Android named DefaultBurstExecutorWithCache.
+/// The original code is licensed under Apache-2.0 and can be found at the following link:
+/// https://android.googlesource.com/platform/frameworks/
+/// ml/+/refs/tags/android-10.0.0_r20/nn/common/ExecutionBurstServer.cpp
+class ArmnnBurstExecutorWithCache : public ExecutionBurstServer::IBurstExecutorWithCache {
+public:
+ ArmnnBurstExecutorWithCache(V1_2::IPreparedModel* preparedModel)
+ : m_PreparedModel(preparedModel)
+ {}
- pOutputTensors->reserve(request.outputs.size());
- std::vector<OutputShape> outputShapes(request.outputs.size());
+ bool isCacheEntryPresent(int32_t slot) const override
+ {
+ const auto it = m_MemoryCache.find(slot);
+ return (it != m_MemoryCache.end()) && it->second.valid();
+ }
- for (unsigned int i = 0; i < request.outputs.size(); i++)
- {
- const auto& outputArg = request.outputs[i];
+ void addCacheEntry(const hidl_memory& memory, int32_t slot) override
+ {
+ m_MemoryCache[slot] = memory;
+ }
- const armnn::TensorInfo outputTensorInfo = m_Runtime->GetOutputTensorInfo(m_NetworkId, i);
- const armnn::Tensor outputTensor = GetTensorForRequestArgument(outputArg, outputTensorInfo, *pMemPools);
- if (outputTensor.GetMemoryArea() == nullptr)
- {
- ALOGE("Cannot execute request. Error converting request output %u to tensor", i);
- callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::execute");
- return V1_0::ErrorStatus::GENERAL_FAILURE;
- }
+ void removeCacheEntry(int32_t slot) override
+ {
+ m_MemoryCache.erase(slot);
+ }
- const size_t outputSize = outputTensorInfo.GetNumBytes();
- const size_t bufferSize = pMemPools->at(outputArg.location.poolIndex).getHidlMemory().size();
- pOutputTensors->emplace_back(i, outputTensor);
+ std::tuple<V1_0::ErrorStatus, hidl_vec<OutputShape>, Timing> execute(
+ const V1_0::Request& request, const std::vector<int32_t>& slots,
+ MeasureTiming measure) override
+ {
+ ALOGV("ArmnnPreparedModel_1_2::BurstExecutorWithCache::execute");
+ hidl_vec<hidl_memory> pools(slots.size());
- hidl_vec<uint32_t> dimensions;
+ std::transform(slots.begin(), slots.end(), pools.begin(), [this](int32_t slot)
+ {
+ return m_MemoryCache[slot];
+ });
- armnn::TensorShape tensorShape = outputTensorInfo.GetShape();
- const unsigned int numDims = tensorShape.GetNumDimensions();
- dimensions.resize(numDims);
+ V1_0::Request fullRequest = request;
+ fullRequest.pools = std::move(pools);
- for (unsigned int outputIdx = 0u; outputIdx < numDims; ++outputIdx)
+ // Setup Callback
+ V1_0::ErrorStatus returnedStatus = V1_0::ErrorStatus::GENERAL_FAILURE;
+ hidl_vec<OutputShape> returnedOutputShapes;
+ Timing returnedTiming;
+ auto cb = [&returnedStatus, &returnedOutputShapes, &returnedTiming](V1_0::ErrorStatus status,
+ const hidl_vec<OutputShape>& outputShapes,
+ const Timing& timing)
{
- dimensions[outputIdx] = tensorShape[outputIdx];
- }
- outputShapes[i].dimensions = dimensions;
- outputShapes[i].isSufficient = bufferSize >= outputSize;
+ returnedStatus = status;
+ returnedOutputShapes = outputShapes;
+ returnedTiming = timing;
+ };
- if (bufferSize < outputSize)
- {
- ALOGW("ArmnnPreparedModel_1_2::Execute failed");
- callback(V1_0::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE,
- outputShapes,
- g_NoTiming,
- "ArmnnPreparedModel_1_2::Execute");
- return V1_0::ErrorStatus::NONE;
- }
+ // Execute
+ ALOGV("ArmnnPreparedModel_1_2::BurstExecutorWithCache executing");
+ const Return<void> ret = m_PreparedModel->executeSynchronously(fullRequest, measure, cb);
+
+ if (!ret.isOk() || returnedStatus != V1_0::ErrorStatus::NONE)
+ {
+ ALOGE("ArmnnPreparedModel_1_2::BurstExecutorWithCache::error executing");
}
+ return std::make_tuple(returnedStatus, std::move(returnedOutputShapes), returnedTiming);
}
- catch (armnn::Exception& e)
+
+private:
+ V1_2::IPreparedModel* const m_PreparedModel;
+ std::map<int, hidl_memory> m_MemoryCache;
+};
+
+template<typename HalVersion>
+Return<void> ArmnnPreparedModel_1_2<HalVersion>::configureExecutionBurst(
+ const sp<V1_2::IBurstCallback>& callback,
+ const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
+ const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
+ V1_2::IPreparedModel::configureExecutionBurst_cb cb)
+{
+ ALOGV("ArmnnPreparedModel_1_2::configureExecutionBurst");
+ const std::shared_ptr<ArmnnBurstExecutorWithCache> executorWithCache =
+ std::make_shared<ArmnnBurstExecutorWithCache>(this);
+ const sp<V1_2::IBurstContext> burst = ExecutionBurstServer::create(callback,
+ requestChannel,
+ resultChannel,
+ executorWithCache);
+
+ if (burst == nullptr)
{
- ALOGW("armnn::Exception caught while preparing for EnqueueWorkload: %s", e.what());
- callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::execute");
- return V1_0::ErrorStatus::GENERAL_FAILURE;
+ cb(V1_0::ErrorStatus::GENERAL_FAILURE, {});
}
- catch (std::exception& e)
+ else
{
- ALOGE("std::exception caught while preparing for EnqueueWorkload: %s", e.what());
- callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::execute");
- return V1_0::ErrorStatus::GENERAL_FAILURE;
+ cb(V1_0::ErrorStatus::NONE, burst);
}
-
- ALOGV("ArmnnPreparedModel_1_2::execute(...) before PostMsg");
- // post the request for asynchronous execution
- ArmnnCallback_1_2 armnnCb;
- armnnCb.callback = callback;
- armnnCb.measureTiming = measureTiming;
- armnnCb.driverStart = driverStart;
- m_RequestThread.PostMsg(this, pMemPools, pInputTensors, pOutputTensors, armnnCb);
- ALOGV("ArmnnPreparedModel_1_2::execute(...) after PostMsg");
- return V1_0::ErrorStatus::NONE;
+ return Void();
}
+
+
#ifdef ARMNN_ANDROID_NN_V1_2
template class ArmnnPreparedModel_1_2<hal_1_2::HalPolicy>;
+template bool ArmnnPreparedModel_1_2<hal_1_2::HalPolicy>::ExecuteGraph<CallbackContext_1_2>(
+ std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
+ armnn::InputTensors& pInputTensors,
+ armnn::OutputTensors& pOutputTensors,
+ CallbackContext_1_2 cb);
#endif
} // namespace armnn_driver
diff --git a/ArmnnPreparedModel_1_2.hpp b/ArmnnPreparedModel_1_2.hpp
index f609ef7..e68614a 100644
--- a/ArmnnPreparedModel_1_2.hpp
+++ b/ArmnnPreparedModel_1_2.hpp
@@ -19,18 +19,21 @@
namespace armnn_driver
{
-typedef std::function<void(::android::hardware::neuralnetworks::V1_0::ErrorStatus status,
- std::vector<::android::hardware::neuralnetworks::V1_2::OutputShape> outputShapes,
- const ::android::hardware::neuralnetworks::V1_2::Timing& timing,
- std::string callingFunction)> armnnExecuteCallback_1_2;
+using CallbackAsync_1_2 = std::function<
+ void(V1_0::ErrorStatus errorStatus,
+ std::vector<::android::hardware::neuralnetworks::V1_2::OutputShape> outputShapes,
+ const ::android::hardware::neuralnetworks::V1_2::Timing& timing,
+ std::string callingFunction)>;
-struct ArmnnCallback_1_2
+struct ExecutionContext_1_2
{
- armnnExecuteCallback_1_2 callback;
+ ::android::hardware::neuralnetworks::V1_2::MeasureTiming measureTimings =
+ ::android::hardware::neuralnetworks::V1_2::MeasureTiming::NO;
TimePoint driverStart;
- MeasureTiming measureTiming;
};
+using CallbackContext_1_2 = CallbackContext<CallbackAsync_1_2, ExecutionContext_1_2>;
+
template <typename HalVersion>
class ArmnnPreparedModel_1_2 : public V1_2::IPreparedModel
{
@@ -62,19 +65,38 @@ public:
configureExecutionBurst_cb cb) override;
/// execute the graph prepared from the request
- void ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
- std::shared_ptr<armnn::InputTensors>& pInputTensors,
- std::shared_ptr<armnn::OutputTensors>& pOutputTensors,
- ArmnnCallback_1_2 callbackDescriptor);
+ template<typename CallbackContext>
+ bool ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
+ armnn::InputTensors& inputTensors,
+ armnn::OutputTensors& outputTensors,
+ CallbackContext callback);
/// Executes this model with dummy inputs (e.g. all zeroes).
/// \return false on failure, otherwise true
bool ExecuteWithDummyInputs();
private:
- Return <V1_0::ErrorStatus> Execute(const V1_0::Request& request,
- MeasureTiming measureTiming,
- armnnExecuteCallback_1_2 callback);
+ Return<V1_0::ErrorStatus> Execute(const V1_0::Request& request,
+ MeasureTiming measureTiming,
+ CallbackAsync_1_2 callback);
+
+ Return<V1_0::ErrorStatus> PrepareMemoryForInputs(
+ armnn::InputTensors& inputs,
+ const V1_0::Request& request,
+ const std::vector<android::nn::RunTimePoolInfo>& memPools);
+
+ Return<V1_0::ErrorStatus> PrepareMemoryForOutputs(
+ armnn::OutputTensors& outputs,
+ std::vector<OutputShape> &outputShapes,
+ const V1_0::Request& request,
+ const std::vector<android::nn::RunTimePoolInfo>& memPools);
+
+ Return <V1_0::ErrorStatus> PrepareMemoryForIO(
+ armnn::InputTensors& inputs,
+ armnn::OutputTensors& outputs,
+ std::vector<android::nn::RunTimePoolInfo>& memPools,
+ const V1_0::Request& request,
+ CallbackAsync_1_2 callback);
template <typename TensorBindingCollection>
void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings);
@@ -84,7 +106,9 @@ private:
V1_2::Model m_Model;
// There must be a single RequestThread for all ArmnnPreparedModel objects to ensure serial execution of workloads
// It is specific to this class, so it is declared as static here
- static RequestThread<ArmnnPreparedModel_1_2, HalVersion, ArmnnCallback_1_2> m_RequestThread;
+ static RequestThread<ArmnnPreparedModel_1_2,
+ HalVersion,
+ CallbackContext_1_2> m_RequestThread;
uint32_t m_RequestCount;
const std::string& m_RequestInputsAndOutputsDumpDir;
const bool m_GpuProfilingEnabled;
diff --git a/RequestThread.cpp b/RequestThread.cpp
index 052c5c1..22a3ac3 100644
--- a/RequestThread.cpp
+++ b/RequestThread.cpp
@@ -21,15 +21,15 @@ using namespace android;
namespace armnn_driver
{
-template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename Callback>
-RequestThread<PreparedModel, HalVersion, Callback>::RequestThread()
+template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename CallbackContext>
+RequestThread<PreparedModel, HalVersion, CallbackContext>::RequestThread()
{
ALOGV("RequestThread::RequestThread()");
m_Thread = std::make_unique<std::thread>(&RequestThread::Process, this);
}
-template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename Callback>
-RequestThread<PreparedModel, HalVersion, Callback>::~RequestThread()
+template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename CallbackContext>
+RequestThread<PreparedModel, HalVersion, CallbackContext>::~RequestThread()
{
ALOGV("RequestThread::~RequestThread()");
@@ -54,25 +54,25 @@ RequestThread<PreparedModel, HalVersion, Callback>::~RequestThread()
catch (const std::exception&) { } // Swallow any exception.
}
-template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename Callback>
-void RequestThread<PreparedModel, HalVersion, Callback>::PostMsg(PreparedModel<HalVersion>* model,
+template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename CallbackContext>
+void RequestThread<PreparedModel, HalVersion, CallbackContext>::PostMsg(PreparedModel<HalVersion>* model,
std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& memPools,
std::shared_ptr<armnn::InputTensors>& inputTensors,
std::shared_ptr<armnn::OutputTensors>& outputTensors,
- Callback callback)
+ CallbackContext callbackContext)
{
ALOGV("RequestThread::PostMsg(...)");
auto data = std::make_shared<AsyncExecuteData>(model,
memPools,
inputTensors,
outputTensors,
- callback);
+ callbackContext);
auto pMsg = std::make_shared<ThreadMsg>(ThreadMsgType::REQUEST, data);
PostMsg(pMsg);
}
-template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename Callback>
-void RequestThread<PreparedModel, HalVersion, Callback>::PostMsg(std::shared_ptr<ThreadMsg>& pMsg)
+template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename CallbackContext>
+void RequestThread<PreparedModel, HalVersion, CallbackContext>::PostMsg(std::shared_ptr<ThreadMsg>& pMsg)
{
ALOGV("RequestThread::PostMsg(pMsg)");
// Add a message to the queue and notify the request thread
@@ -81,8 +81,8 @@ void RequestThread<PreparedModel, HalVersion, Callback>::PostMsg(std::shared_ptr
m_Cv.notify_one();
}
-template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename Callback>
-void RequestThread<PreparedModel, HalVersion, Callback>::Process()
+template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename CallbackContext>
+void RequestThread<PreparedModel, HalVersion, CallbackContext>::Process()
{
ALOGV("RequestThread::Process()");
while (true)
@@ -109,9 +109,9 @@ void RequestThread<PreparedModel, HalVersion, Callback>::Process()
// invoke the asynchronous execution method
PreparedModel<HalVersion>* model = pMsg->data->m_Model;
model->ExecuteGraph(pMsg->data->m_MemPools,
- pMsg->data->m_InputTensors,
- pMsg->data->m_OutputTensors,
- pMsg->data->m_Callback);
+ *(pMsg->data->m_InputTensors),
+ *(pMsg->data->m_OutputTensors),
+ pMsg->data->m_CallbackContext);
break;
}
@@ -139,16 +139,16 @@ void RequestThread<PreparedModel, HalVersion, Callback>::Process()
/// Class template specializations
///
-template class RequestThread<ArmnnPreparedModel, hal_1_0::HalPolicy, ArmnnCallback_1_0>;
+template class RequestThread<ArmnnPreparedModel, hal_1_0::HalPolicy, CallbackContext_1_0>;
#ifdef ARMNN_ANDROID_NN_V1_1
-template class RequestThread<armnn_driver::ArmnnPreparedModel, hal_1_1::HalPolicy, ArmnnCallback_1_0>;
+template class RequestThread<armnn_driver::ArmnnPreparedModel, hal_1_1::HalPolicy, CallbackContext_1_0>;
#endif
#ifdef ARMNN_ANDROID_NN_V1_2
-template class RequestThread<ArmnnPreparedModel, hal_1_1::HalPolicy, ArmnnCallback_1_0>;
-template class RequestThread<ArmnnPreparedModel, hal_1_2::HalPolicy, ArmnnCallback_1_0>;
-template class RequestThread<ArmnnPreparedModel_1_2, hal_1_2::HalPolicy, ArmnnCallback_1_2>;
+template class RequestThread<ArmnnPreparedModel, hal_1_1::HalPolicy, CallbackContext_1_0>;
+template class RequestThread<ArmnnPreparedModel, hal_1_2::HalPolicy, CallbackContext_1_0>;
+template class RequestThread<ArmnnPreparedModel_1_2, hal_1_2::HalPolicy, CallbackContext_1_2>;
#endif
} // namespace armnn_driver
diff --git a/RequestThread.hpp b/RequestThread.hpp
index 253d104..79f309a 100644
--- a/RequestThread.hpp
+++ b/RequestThread.hpp
@@ -21,7 +21,7 @@ namespace armnn_driver
using TimePoint = std::chrono::steady_clock::time_point;
static const TimePoint g_Min = std::chrono::steady_clock::time_point::min();
-template<template <typename HalVersion> class PreparedModel, typename HalVersion, typename Callback>
+template<template <typename HalVersion> class PreparedModel, typename HalVersion, typename CallbackContext>
class RequestThread
{
public:
@@ -41,7 +41,7 @@ public:
std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& memPools,
std::shared_ptr<armnn::InputTensors>& inputTensors,
std::shared_ptr<armnn::OutputTensors>& outputTensors,
- Callback callback);
+ CallbackContext callbackContext);
private:
RequestThread(const RequestThread&) = delete;
@@ -54,12 +54,12 @@ private:
std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& memPools,
std::shared_ptr<armnn::InputTensors>& inputTensors,
std::shared_ptr<armnn::OutputTensors>& outputTensors,
- Callback callback)
+ CallbackContext callbackContext)
: m_Model(model)
, m_MemPools(memPools)
, m_InputTensors(inputTensors)
, m_OutputTensors(outputTensors)
- , m_Callback(callback)
+ , m_CallbackContext(callbackContext)
{
}
@@ -67,9 +67,8 @@ private:
std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> m_MemPools;
std::shared_ptr<armnn::InputTensors> m_InputTensors;
std::shared_ptr<armnn::OutputTensors> m_OutputTensors;
- Callback m_Callback;
+ CallbackContext m_CallbackContext;
};
-
enum class ThreadMsgType
{
EXIT, // exit the thread