diff options
-rw-r--r-- | ArmnnDriverImpl.hpp | 7 | ||||
-rw-r--r-- | ArmnnPreparedModel.cpp | 17 | ||||
-rw-r--r-- | ArmnnPreparedModel.hpp | 12 | ||||
-rw-r--r-- | ArmnnPreparedModel_1_2.cpp | 649 | ||||
-rw-r--r-- | ArmnnPreparedModel_1_2.hpp | 54 | ||||
-rw-r--r-- | RequestThread.cpp | 40 | ||||
-rw-r--r-- | RequestThread.hpp | 11 |
7 files changed, 377 insertions, 413 deletions
diff --git a/ArmnnDriverImpl.hpp b/ArmnnDriverImpl.hpp index c5b17781..dfaafb34 100644 --- a/ArmnnDriverImpl.hpp +++ b/ArmnnDriverImpl.hpp @@ -23,6 +23,13 @@ namespace V1_2 = ::android::hardware::neuralnetworks::V1_2; namespace armnn_driver { +template <typename Callback, typename Context> +struct CallbackContext +{ + Callback callback; + Context ctx; +}; + template<typename HalPolicy> class ArmnnDriverImpl { diff --git a/ArmnnPreparedModel.cpp b/ArmnnPreparedModel.cpp index 2cd560d7..d095e419 100644 --- a/ArmnnPreparedModel.cpp +++ b/ArmnnPreparedModel.cpp @@ -84,7 +84,8 @@ using namespace android::hardware; namespace armnn_driver { template<typename HalVersion> -RequestThread<ArmnnPreparedModel, HalVersion, ArmnnCallback_1_0> ArmnnPreparedModel<HalVersion>::m_RequestThread; +RequestThread<ArmnnPreparedModel, HalVersion, CallbackContext_1_0> + ArmnnPreparedModel<HalVersion>::m_RequestThread; template<typename HalVersion> template <typename TensorBindingCollection> @@ -226,7 +227,7 @@ Return<V1_0::ErrorStatus> ArmnnPreparedModel<HalVersion>::execute( NotifyCallbackAndCheck(callback, errorStatus, callingFunction); }; - ArmnnCallback_1_0 armnnCb; + CallbackContext_1_0 armnnCb; armnnCb.callback = cb; // post the request for asynchronous execution m_RequestThread.PostMsg(this, pMemPools, pInputTensors, pOutputTensors, armnnCb); @@ -237,18 +238,18 @@ Return<V1_0::ErrorStatus> ArmnnPreparedModel<HalVersion>::execute( template<typename HalVersion> void ArmnnPreparedModel<HalVersion>::ExecuteGraph( std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools, - std::shared_ptr<armnn::InputTensors>& pInputTensors, - std::shared_ptr<armnn::OutputTensors>& pOutputTensors, - ArmnnCallback_1_0 cb) + armnn::InputTensors& inputTensors, + armnn::OutputTensors& outputTensors, + CallbackContext_1_0 cb) { ALOGV("ArmnnPreparedModel::ExecuteGraph(...)"); - DumpTensorsIfRequired("Input", *pInputTensors); + DumpTensorsIfRequired("Input", inputTensors); // run it try { - armnn::Status status = m_Runtime->EnqueueWorkload(m_NetworkId, *pInputTensors, *pOutputTensors); + armnn::Status status = m_Runtime->EnqueueWorkload(m_NetworkId, inputTensors, outputTensors); if (status != armnn::Status::Success) { ALOGW("EnqueueWorkload failed"); @@ -269,7 +270,7 @@ void ArmnnPreparedModel<HalVersion>::ExecuteGraph( return; } - DumpTensorsIfRequired("Output", *pOutputTensors); + DumpTensorsIfRequired("Output", outputTensors); // Commit output buffers. // Note that we update *all* pools, even if they aren't actually used as outputs - diff --git a/ArmnnPreparedModel.hpp b/ArmnnPreparedModel.hpp index 270a9339..89f6226f 100644 --- a/ArmnnPreparedModel.hpp +++ b/ArmnnPreparedModel.hpp @@ -24,6 +24,10 @@ struct ArmnnCallback_1_0 armnnExecuteCallback_1_0 callback; }; +struct ExecutionContext_1_0 {}; + +using CallbackContext_1_0 = CallbackContext<armnnExecuteCallback_1_0, ExecutionContext_1_0>; + template <typename HalVersion> class ArmnnPreparedModel : public V1_0::IPreparedModel { @@ -43,9 +47,9 @@ public: /// execute the graph prepared from the request void ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools, - std::shared_ptr<armnn::InputTensors>& pInputTensors, - std::shared_ptr<armnn::OutputTensors>& pOutputTensors, - ArmnnCallback_1_0 callback); + armnn::InputTensors& inputTensors, + armnn::OutputTensors& outputTensors, + CallbackContext_1_0 callback); /// Executes this model with dummy inputs (e.g. all zeroes). /// \return false on failure, otherwise true @@ -60,7 +64,7 @@ private: HalModel m_Model; // There must be a single RequestThread for all ArmnnPreparedModel objects to ensure serial execution of workloads // It is specific to this class, so it is declared as static here - static RequestThread<ArmnnPreparedModel, HalVersion, ArmnnCallback_1_0> m_RequestThread; + static RequestThread<ArmnnPreparedModel, HalVersion, CallbackContext_1_0> m_RequestThread; uint32_t m_RequestCount; const std::string& m_RequestInputsAndOutputsDumpDir; const bool m_GpuProfilingEnabled; diff --git a/ArmnnPreparedModel_1_2.cpp b/ArmnnPreparedModel_1_2.cpp index 9b790443..5031c5ff 100644 --- a/ArmnnPreparedModel_1_2.cpp +++ b/ArmnnPreparedModel_1_2.cpp @@ -120,7 +120,7 @@ namespace armnn_driver { template<typename HalVersion> -RequestThread<ArmnnPreparedModel_1_2, HalVersion, ArmnnCallback_1_2> +RequestThread<ArmnnPreparedModel_1_2, HalVersion, CallbackContext_1_2> ArmnnPreparedModel_1_2<HalVersion>::m_RequestThread; template<typename HalVersion> @@ -215,339 +215,256 @@ Return <V1_0::ErrorStatus> ArmnnPreparedModel_1_2<HalVersion>::execute_1_2( return Execute(request, measureTiming, cb); } -template<typename HalVersion> -Return<void> ArmnnPreparedModel_1_2<HalVersion>::executeSynchronously(const V1_0::Request& request, - MeasureTiming measureTiming, - executeSynchronously_cb cb) +OutputShape ComputeShape(const armnn::TensorInfo& info) { - ALOGV("ArmnnPreparedModel_1_2::executeSynchronously(): %s", GetModelSummary(m_Model).c_str()); - m_RequestCount++; - - if (cb == nullptr) - { - ALOGE("ArmnnPreparedModel_1_2::executeSynchronously invalid callback passed"); - return Void(); - } + OutputShape shape; - TimePoint driverStart, driverEnd, deviceStart, deviceEnd; + hidl_vec<uint32_t> dimensions; - if (measureTiming == MeasureTiming::YES) - { - driverStart = Now(); - } + armnn::TensorShape tensorShape = info.GetShape(); + const unsigned int numDims = tensorShape.GetNumDimensions(); + dimensions.resize(numDims); - if (!android::nn::validateRequest(request, m_Model)) + for (unsigned int outputIdx = 0u; outputIdx < numDims; ++outputIdx) { - ALOGE("ArmnnPreparedModel_1_2::executeSynchronously invalid request model"); - cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {}, g_NoTiming); - return Void(); + dimensions[outputIdx] = tensorShape[outputIdx]; } - // allocate the tensors on the heap, as they are passed to the request thread - auto pInputTensors = std::make_shared<armnn::InputTensors>(); - auto pOutputTensors = std::make_shared<armnn::OutputTensors>(); - - // map the memory pool into shared pointers - // use a shared memory pools vector on the heap, as it is passed to the request thread - auto pMemPools = std::make_shared<std::vector<android::nn::RunTimePoolInfo>>(); + shape.dimensions = dimensions; + shape.isSufficient = true; - if (!setRunTimePoolInfosFromHidlMemories(pMemPools.get(), request.pools)) - { - cb(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming); - return Void(); - } - std::vector<OutputShape> outputShapes(request.outputs.size()); + return shape; +} - try +template<typename HalVersion> +Return<V1_0::ErrorStatus> ArmnnPreparedModel_1_2<HalVersion>::PrepareMemoryForInputs( + armnn::InputTensors& inputs, + const V1_0::Request& request, + const std::vector<android::nn::RunTimePoolInfo>& memPools) +{ + inputs.reserve(request.inputs.size()); + for (unsigned int i = 0; i < request.inputs.size(); i++) { - pInputTensors->reserve(request.inputs.size()); - for (unsigned int i = 0; i < request.inputs.size(); i++) - { - const auto& inputArg = request.inputs[i]; - - const armnn::TensorInfo inputTensorInfo = m_Runtime->GetInputTensorInfo(m_NetworkId, i); - const armnn::Tensor inputTensor = GetTensorForRequestArgument(inputArg, inputTensorInfo, *pMemPools); + const auto& inputArg = request.inputs[i]; - if (inputTensor.GetMemoryArea() == nullptr) - { - ALOGE("Cannot execute request. Error converting request input %u to tensor", i); - cb(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming); - return Void(); - } + const armnn::TensorInfo inputTensorInfo = m_Runtime->GetInputTensorInfo(m_NetworkId, i); + const armnn::Tensor inputTensor = GetTensorForRequestArgument(inputArg, inputTensorInfo, memPools); - pInputTensors->emplace_back(i, inputTensor); + if (inputTensor.GetMemoryArea() == nullptr) + { + ALOGE("Cannot execute request. Error converting request input %u to tensor", i); + return V1_0::ErrorStatus::GENERAL_FAILURE; } - pOutputTensors->reserve(request.outputs.size()); - for (unsigned int i = 0; i < request.outputs.size(); i++) - { - const auto& outputArg = request.outputs[i]; + inputs.emplace_back(i, inputTensor); + } - const armnn::TensorInfo outputTensorInfo = m_Runtime->GetOutputTensorInfo(m_NetworkId, i); - const armnn::Tensor outputTensor = GetTensorForRequestArgument(outputArg, outputTensorInfo, *pMemPools); + return V1_0::ErrorStatus::NONE; +} - if (outputTensor.GetMemoryArea() == nullptr) - { - ALOGE("Cannot execute request. Error converting request output %u to tensor", i); - cb(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming); - return Void(); - } - const size_t outputSize = outputTensorInfo.GetNumBytes(); - const size_t bufferSize = pMemPools->at(outputArg.location.poolIndex).getHidlMemory().size(); +template<typename HalVersion> +Return<V1_0::ErrorStatus> ArmnnPreparedModel_1_2<HalVersion>::PrepareMemoryForOutputs( + armnn::OutputTensors& outputs, + std::vector<OutputShape> &outputShapes, + const V1_0::Request& request, + const std::vector<android::nn::RunTimePoolInfo>& memPools) +{ + outputs.reserve(request.outputs.size()); + for (unsigned int i = 0; i < request.outputs.size(); i++) + { + const auto& outputArg = request.outputs[i]; - hidl_vec<uint32_t> dimensions; + const armnn::TensorInfo outputTensorInfo = m_Runtime->GetOutputTensorInfo(m_NetworkId, i); + const armnn::Tensor outputTensor = GetTensorForRequestArgument(outputArg, outputTensorInfo, memPools); + if (outputTensor.GetMemoryArea() == nullptr) + { + ALOGE("Cannot execute request. Error converting request output %u to tensor", i); + return V1_0::ErrorStatus::GENERAL_FAILURE; + } - armnn::TensorShape tensorShape = outputTensorInfo.GetShape(); - const unsigned int numDims = tensorShape.GetNumDimensions(); - dimensions.resize(numDims); + const size_t outputSize = outputTensorInfo.GetNumBytes(); + const size_t bufferSize = memPools.at(outputArg.location.poolIndex).getHidlMemory().size(); + if (bufferSize < outputSize) + { + ALOGW("ArmnnPreparedModel_1_2::Execute failed"); + return V1_0::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE; + } - for (unsigned int outputIdx = 0u; outputIdx < numDims; ++outputIdx) - { - dimensions[outputIdx] = tensorShape[outputIdx]; - } - outputShapes[i].dimensions = dimensions; - outputShapes[i].isSufficient = bufferSize >= outputSize; + outputs.emplace_back(i, outputTensor); + outputShapes[i] = ComputeShape(outputTensorInfo); + } - if (bufferSize < outputSize) - { - ALOGW("ArmnnPreparedModel_1_2::Execute failed"); - cb(V1_0::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE, outputShapes, g_NoTiming); - return Void(); - } + return V1_0::ErrorStatus::NONE; +} - pOutputTensors->emplace_back(i, outputTensor); - } - } - catch (armnn::Exception& e) - { - ALOGW("armnn::Exception caught while preparing for EnqueueWorkload: %s", e.what()); - cb(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming); - return Void(); - } - catch (std::exception& e) +template<typename HalVersion> +Return<V1_0::ErrorStatus> ArmnnPreparedModel_1_2<HalVersion>::PrepareMemoryForIO( + armnn::InputTensors& inputs, + armnn::OutputTensors& outputs, + std::vector<android::nn::RunTimePoolInfo>& memPools, + const V1_0::Request& request, + CallbackAsync_1_2 callback) +{ + if (!setRunTimePoolInfosFromHidlMemories(&memPools, request.pools)) { - ALOGE("std::exception caught while preparing for EnqueueWorkload: %s", e.what()); - cb(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming); - return Void(); + callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::execute"); + return V1_0::ErrorStatus::GENERAL_FAILURE; } - ALOGV("ArmnnPreparedModel_1_2::executeSynchronously() before Execution"); - - DumpTensorsIfRequired("Input", *pInputTensors); - // run it + // add the inputs and outputs with their data try { - if (measureTiming == MeasureTiming::YES) + if (PrepareMemoryForInputs(inputs, request, memPools) != V1_0::ErrorStatus::NONE) { - deviceStart = Now(); + callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::execute"); + return V1_0::ErrorStatus::GENERAL_FAILURE; } - armnn::Status status = m_Runtime->EnqueueWorkload(m_NetworkId, *pInputTensors, *pOutputTensors); - - if (measureTiming == MeasureTiming::YES) - { - deviceEnd = Now(); - } + std::vector<OutputShape> outputShapes(request.outputs.size()); - if (status != armnn::Status::Success) + auto errorStatus = PrepareMemoryForOutputs(outputs, outputShapes, request, memPools); + if (errorStatus != V1_0::ErrorStatus::NONE) { - ALOGW("EnqueueWorkload failed"); - cb(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming); - return Void(); + callback(errorStatus, + outputShapes, + g_NoTiming, + "ArmnnPreparedModel_1_2::Execute"); + return errorStatus; } } catch (armnn::Exception& e) { - ALOGW("armnn::Exception caught from EnqueueWorkload: %s", e.what()); - cb(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming); - return Void(); + ALOGW("armnn::Exception caught while preparing for EnqueueWorkload: %s", e.what()); + callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::execute"); + return V1_0::ErrorStatus::GENERAL_FAILURE; } catch (std::exception& e) { - ALOGE("std::exception caught from EnqueueWorkload: %s", e.what()); - cb(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming); - return Void(); + ALOGE("std::exception caught while preparing for EnqueueWorkload: %s", e.what()); + callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::execute"); + return V1_0::ErrorStatus::GENERAL_FAILURE; } - DumpTensorsIfRequired("Output", *pOutputTensors); + return V1_0::ErrorStatus::NONE; +} +void CommitPools(std::vector<::android::nn::RunTimePoolInfo>& memPools) +{ + if (memPools.empty()) + { + return; + } // Commit output buffers. // Note that we update *all* pools, even if they aren't actually used as outputs - // this is simpler and is what the CpuExecutor does. - for (android::nn::RunTimePoolInfo& pool : *pMemPools) + for (auto& pool : memPools) { // Type android::nn::RunTimePoolInfo has changed between Android P & Q and Android R, where // update() has been removed and flush() added. - #if defined(ARMNN_ANDROID_R) // Use the new Android implementation. - pool.flush(); - #else - pool.update(); - #endif - } - - ALOGV("ArmnnPreparedModel_1_2::executeSynchronously() after Execution"); - - if (measureTiming == MeasureTiming::YES) - { - driverEnd = Now(); - Timing timing; - timing.timeOnDevice = MicrosecondsDuration(deviceEnd, deviceStart); - timing.timeInDriver = MicrosecondsDuration(driverEnd, driverStart); - ALOGV("ArmnnPreparedModel_1_2::executeSynchronously timing Device = %lu Driver = %lu", timing.timeOnDevice, - timing.timeInDriver); - cb(V1_0::ErrorStatus::NONE, outputShapes, timing); - } - else - { - cb(V1_0::ErrorStatus::NONE, outputShapes, g_NoTiming); +#if defined(ARMNN_ANDROID_R) // Use the new Android implementation. + pool.flush(); +#else + pool.update(); +#endif } - return Void(); } -/// This class is strongly inspired by the default implementation in Android named DefaultBurstExecutorWithCache. -/// The original code is licensed under Apache-2.0 and can be found at the following link: -/// https://android.googlesource.com/platform/frameworks/ -/// ml/+/refs/tags/android-10.0.0_r20/nn/common/ExecutionBurstServer.cpp -class ArmnnBurstExecutorWithCache : public ExecutionBurstServer::IBurstExecutorWithCache { -public: - ArmnnBurstExecutorWithCache(V1_2::IPreparedModel* preparedModel) - : m_PreparedModel(preparedModel) - {} +template<typename HalVersion> +Return<void> ArmnnPreparedModel_1_2<HalVersion>::executeSynchronously(const V1_0::Request& request, + MeasureTiming measureTiming, + executeSynchronously_cb cb) +{ + ALOGV("ArmnnPreparedModel_1_2::executeSynchronously(): %s", GetModelSummary(m_Model).c_str()); + m_RequestCount++; - bool isCacheEntryPresent(int32_t slot) const override + if (cb == nullptr) { - const auto it = m_MemoryCache.find(slot); - return (it != m_MemoryCache.end()) && it->second.valid(); + ALOGE("ArmnnPreparedModel_1_2::executeSynchronously invalid callback passed"); + return Void(); } - void addCacheEntry(const hidl_memory& memory, int32_t slot) override - { - m_MemoryCache[slot] = memory; - } + TimePoint driverStart; - void removeCacheEntry(int32_t slot) override + if (measureTiming == MeasureTiming::YES) { - m_MemoryCache.erase(slot); + driverStart = Now(); } - std::tuple<V1_0::ErrorStatus, hidl_vec<OutputShape>, Timing> execute( - const V1_0::Request& request, const std::vector<int32_t>& slots, - MeasureTiming measure) override + if (!android::nn::validateRequest(request, m_Model)) { - ALOGV("ArmnnPreparedModel_1_2::BurstExecutorWithCache::execute"); - hidl_vec<hidl_memory> pools(slots.size()); - - std::transform(slots.begin(), slots.end(), pools.begin(), [this](int32_t slot) - { - return m_MemoryCache[slot]; - }); - - V1_0::Request fullRequest = request; - fullRequest.pools = std::move(pools); + ALOGE("ArmnnPreparedModel_1_2::executeSynchronously invalid request model"); + cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {}, g_NoTiming); + return Void(); + } - // Setup Callback - V1_0::ErrorStatus returnedStatus = V1_0::ErrorStatus::GENERAL_FAILURE; - hidl_vec<OutputShape> returnedOutputShapes; - Timing returnedTiming; - auto cb = [&returnedStatus, &returnedOutputShapes, &returnedTiming](V1_0::ErrorStatus status, - const hidl_vec<OutputShape>& outputShapes, - const Timing& timing) + auto cbWrapper = [cb](V1_0::ErrorStatus errorStatus, + std::vector<OutputShape> outputShapes, + const Timing& timing, + std::string) { - returnedStatus = status; - returnedOutputShapes = outputShapes; - returnedTiming = timing; + cb(errorStatus, outputShapes, timing); }; - // Execute - ALOGV("ArmnnPreparedModel_1_2::BurstExecutorWithCache executing"); - const Return<void> ret = m_PreparedModel->executeSynchronously(fullRequest, measure, cb); + // map the memory pool into shared pointers + // use a shared memory pools vector on the heap, as it is passed to the request thread + auto memPools = std::make_shared<std::vector<android::nn::RunTimePoolInfo>>(); - if (!ret.isOk() || returnedStatus != V1_0::ErrorStatus::NONE) - { - ALOGE("ArmnnPreparedModel_1_2::BurstExecutorWithCache::error executing"); - } - return std::make_tuple(returnedStatus, std::move(returnedOutputShapes), returnedTiming); - } + // allocate the tensors on the heap, as they are passed to the request thread + auto inputs = std::make_shared<armnn::InputTensors>(); + auto outputs = std::make_shared<armnn::OutputTensors>(); -private: - V1_2::IPreparedModel* const m_PreparedModel; - std::map<int, hidl_memory> m_MemoryCache; -}; + auto prepareStatus = PrepareMemoryForIO(*inputs, *outputs, *memPools, request, cbWrapper); + if (prepareStatus != V1_0::ErrorStatus::NONE) + { + return Void(); + } + ALOGV("ArmnnPreparedModel_1_2::executeSynchronously() before Execution"); -template<typename HalVersion> -Return<void> ArmnnPreparedModel_1_2<HalVersion>::configureExecutionBurst( - const sp<V1_2::IBurstCallback>& callback, - const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel, - const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel, - V1_2::IPreparedModel::configureExecutionBurst_cb cb) -{ - ALOGV("ArmnnPreparedModel_1_2::configureExecutionBurst"); - const std::shared_ptr<ArmnnBurstExecutorWithCache> executorWithCache = - std::make_shared<ArmnnBurstExecutorWithCache>(this); - const sp<V1_2::IBurstContext> burst = ExecutionBurstServer::create(callback, - requestChannel, - resultChannel, - executorWithCache); + CallbackContext_1_2 cbCtx; + cbCtx.callback = cbWrapper; + cbCtx.ctx.measureTimings = measureTiming; + cbCtx.ctx.driverStart = driverStart; + ExecuteGraph(memPools, *inputs, *outputs, cbCtx); - if (burst == nullptr) - { - cb(V1_0::ErrorStatus::GENERAL_FAILURE, {}); - } - else - { - cb(V1_0::ErrorStatus::NONE, burst); - } return Void(); } template<typename HalVersion> -void ArmnnPreparedModel_1_2<HalVersion>::ExecuteGraph( +template<typename CallbackContext> +bool ArmnnPreparedModel_1_2<HalVersion>::ExecuteGraph( std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools, - std::shared_ptr<armnn::InputTensors>& pInputTensors, - std::shared_ptr<armnn::OutputTensors>& pOutputTensors, - ArmnnCallback_1_2 cb) + armnn::InputTensors& inputTensors, + armnn::OutputTensors& outputTensors, + CallbackContext cb) { ALOGV("ArmnnPreparedModel_1_2::ExecuteGraph(...)"); TimePoint driverEnd, deviceStart, deviceEnd; - DumpTensorsIfRequired("Input", *pInputTensors); + DumpTensorsIfRequired("Input", inputTensors); - std::vector<std::pair<int, armnn::Tensor> > outputTensors = *pOutputTensors.get(); std::vector<OutputShape> outputShapes(outputTensors.size()); - for (unsigned int i = 0; i < outputTensors.size(); i++) { std::pair<int, armnn::Tensor> outputTensorPair = outputTensors[i]; const armnn::Tensor outputTensor = outputTensorPair.second; const armnn::TensorInfo outputTensorInfo = outputTensor.GetInfo(); - hidl_vec<uint32_t> dimensions; - - armnn::TensorShape tensorShape = outputTensorInfo.GetShape(); - const unsigned int numDims = tensorShape.GetNumDimensions(); - dimensions.resize(numDims); - - for (unsigned int outputIdx = 0u; outputIdx < numDims; ++outputIdx) - { - dimensions[outputIdx] = tensorShape[outputIdx]; - } - outputShapes[i].dimensions = dimensions; - outputShapes[i].isSufficient = true; + outputShapes[i] = ComputeShape(outputTensorInfo); } // run it try { - if (cb.measureTiming == MeasureTiming::YES) + if (cb.ctx.measureTimings == MeasureTiming::YES) { deviceStart = Now(); } - armnn::Status status = m_Runtime->EnqueueWorkload(m_NetworkId, *pInputTensors, *pOutputTensors); + armnn::Status status = m_Runtime->EnqueueWorkload(m_NetworkId, inputTensors, outputTensors); - if (cb.measureTiming == MeasureTiming::YES) + if (cb.ctx.measureTimings == MeasureTiming::YES) { deviceEnd = Now(); } @@ -556,48 +473,40 @@ void ArmnnPreparedModel_1_2<HalVersion>::ExecuteGraph( ALOGW("EnqueueWorkload failed"); cb.callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::ExecuteGraph"); - return; + return false; } } catch (armnn::Exception& e) { ALOGW("armnn:Exception caught from EnqueueWorkload: %s", e.what()); cb.callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::ExecuteGraph"); - return; + return false; } catch (std::exception& e) { ALOGE("std::exception caught from EnqueueWorkload: %s", e.what()); cb.callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::ExecuteGraph"); - return; + return false; } - DumpTensorsIfRequired("Output", *pOutputTensors); + CommitPools(*pMemPools); - // Commit output buffers. - // Note that we update *all* pools, even if they aren't actually used as outputs - - // this is simpler and is what the CpuExecutor does. - for (android::nn::RunTimePoolInfo& pool : *pMemPools) - { - // Type android::nn::RunTimePoolInfo has changed between Android P & Q and Android R, where - // update() has been removed and flush() added. - #if defined(ARMNN_ANDROID_R) // Use the new Android implementation. - pool.flush(); - #else - pool.update(); - #endif - } + DumpTensorsIfRequired("Output", outputTensors); - if (cb.measureTiming == MeasureTiming::YES) + if (cb.ctx.measureTimings == MeasureTiming::YES) { driverEnd = Now(); Timing timing; timing.timeOnDevice = MicrosecondsDuration(deviceEnd, deviceStart); - timing.timeInDriver = MicrosecondsDuration(driverEnd, cb.driverStart); - cb.callback(V1_0::ErrorStatus::NONE, outputShapes, timing, "ExecuteGraph"); + timing.timeInDriver = MicrosecondsDuration(driverEnd, cb.ctx.driverStart); + ALOGV("ArmnnPreparedModel_1_2::execute timing - Device = %lu Driver = %lu", timing.timeOnDevice, + timing.timeInDriver); + cb.callback(V1_0::ErrorStatus::NONE, outputShapes, timing, "ArmnnPreparedModel_1_2::ExecuteGraph"); } else { - cb.callback(V1_0::ErrorStatus::NONE, outputShapes, g_NoTiming, "ExecuteGraph"); + cb.callback(V1_0::ErrorStatus::NONE, outputShapes, g_NoTiming, "ArmnnPreparedModel_1_2::ExecuteGraph"); } + + return true; } template<typename HalVersion> @@ -624,38 +533,27 @@ bool ArmnnPreparedModel_1_2<HalVersion>::ExecuteWithDummyInputs() outputTensors.emplace_back(i, outputTensor); } - try - { - armnn::Status status = m_Runtime->EnqueueWorkload(m_NetworkId, inputTensors, outputTensors); - if (status != armnn::Status::Success) - { - ALOGW("ExecuteWithDummyInputs: EnqueueWorkload failed"); - return false; - } - } - catch (armnn::Exception& e) - { - ALOGW("ExecuteWithDummyInputs: armnn::Exception caught from EnqueueWorkload: %s", e.what()); - return false; - } - catch (std::exception& e) - { - ALOGE("ExecuteWithDummyInputs: std::exception caught from EnqueueWorkload: %s", e.what()); - return false; - } - return true; + auto nullCallback = [](V1_0::ErrorStatus, std::vector<OutputShape>, const Timing&, std::string) {}; + CallbackContext_1_2 callbackContext; + callbackContext.callback = nullCallback; + callbackContext.ctx.measureTimings = MeasureTiming::NO; + auto memPools = std::make_shared<std::vector<::android::nn::RunTimePoolInfo>>(); + return ExecuteGraph(memPools, + inputTensors, + outputTensors, + callbackContext); } template<typename HalVersion> Return <V1_0::ErrorStatus> ArmnnPreparedModel_1_2<HalVersion>::Execute(const V1_0::Request& request, MeasureTiming measureTiming, - armnnExecuteCallback_1_2 callback) + CallbackAsync_1_2 callback) { - TimePoint driverStart; - + ExecutionContext_1_2 ctx; if (measureTiming == MeasureTiming::YES) { - driverStart = Now(); + ctx.measureTimings = measureTiming; + ctx.driverStart = Now(); } ALOGV("ArmnnPreparedModel_1_2::execute(): %s", GetModelSummary(m_Model).c_str()); @@ -672,111 +570,142 @@ Return <V1_0::ErrorStatus> ArmnnPreparedModel_1_2<HalVersion>::Execute(const V1_ ALOGD("Dumping inputs and outputs for request %" PRIuPTR, reinterpret_cast<std::uintptr_t>(&callback)); } - // allocate the tensors on the heap, as they are passed to the request thread - auto pInputTensors = std::make_shared<armnn::InputTensors>(); - auto pOutputTensors = std::make_shared<armnn::OutputTensors>(); - // map the memory pool into shared pointers // use a shared memory pools vector on the heap, as it is passed to the request thread - auto pMemPools = std::make_shared<std::vector<android::nn::RunTimePoolInfo>>(); + auto memPools = std::make_shared<std::vector<android::nn::RunTimePoolInfo>>(); - if (!setRunTimePoolInfosFromHidlMemories(pMemPools.get(), request.pools)) + // allocate the tensors on the heap, as they are passed to the request thread + auto inputTensors = std::make_shared<armnn::InputTensors>(); + auto outputTensors = std::make_shared<armnn::OutputTensors>(); + + auto prepareStatus = PrepareMemoryForIO(*inputTensors, *outputTensors, *memPools, request, callback); + switch(prepareStatus) { - callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::execute"); - return V1_0::ErrorStatus::GENERAL_FAILURE; + case V1_0::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE: + return V1_0::ErrorStatus::NONE; + case V1_0::ErrorStatus::GENERAL_FAILURE: + return V1_0::ErrorStatus::GENERAL_FAILURE; + default: + {} } - // add the inputs and outputs with their data - try - { - pInputTensors->reserve(request.inputs.size()); - for (unsigned int i = 0; i < request.inputs.size(); i++) - { - const auto& inputArg = request.inputs[i]; + ALOGV("ArmnnPreparedModel_1_2::execute(...) before PostMsg"); - const armnn::TensorInfo inputTensorInfo = m_Runtime->GetInputTensorInfo(m_NetworkId, i); - const armnn::Tensor inputTensor = GetTensorForRequestArgument(inputArg, inputTensorInfo, *pMemPools); + // post the request for asynchronous execution + CallbackContext_1_2 cb; + cb.callback = callback; + cb.ctx = ctx; + m_RequestThread.PostMsg(this, memPools, inputTensors, outputTensors, cb); + ALOGV("ArmnnPreparedModel_1_2::execute(...) after PostMsg"); + return V1_0::ErrorStatus::NONE; +} - if (inputTensor.GetMemoryArea() == nullptr) - { - ALOGE("Cannot execute request. Error converting request input %u to tensor", i); - callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::execute"); - return V1_0::ErrorStatus::GENERAL_FAILURE; - } - pInputTensors->emplace_back(i, inputTensor); - } +/// This class is strongly inspired by the default implementation in Android named DefaultBurstExecutorWithCache. +/// The original code is licensed under Apache-2.0 and can be found at the following link: +/// https://android.googlesource.com/platform/frameworks/ +/// ml/+/refs/tags/android-10.0.0_r20/nn/common/ExecutionBurstServer.cpp +class ArmnnBurstExecutorWithCache : public ExecutionBurstServer::IBurstExecutorWithCache { +public: + ArmnnBurstExecutorWithCache(V1_2::IPreparedModel* preparedModel) + : m_PreparedModel(preparedModel) + {} - pOutputTensors->reserve(request.outputs.size()); - std::vector<OutputShape> outputShapes(request.outputs.size()); + bool isCacheEntryPresent(int32_t slot) const override + { + const auto it = m_MemoryCache.find(slot); + return (it != m_MemoryCache.end()) && it->second.valid(); + } - for (unsigned int i = 0; i < request.outputs.size(); i++) - { - const auto& outputArg = request.outputs[i]; + void addCacheEntry(const hidl_memory& memory, int32_t slot) override + { + m_MemoryCache[slot] = memory; + } - const armnn::TensorInfo outputTensorInfo = m_Runtime->GetOutputTensorInfo(m_NetworkId, i); - const armnn::Tensor outputTensor = GetTensorForRequestArgument(outputArg, outputTensorInfo, *pMemPools); - if (outputTensor.GetMemoryArea() == nullptr) - { - ALOGE("Cannot execute request. Error converting request output %u to tensor", i); - callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::execute"); - return V1_0::ErrorStatus::GENERAL_FAILURE; - } + void removeCacheEntry(int32_t slot) override + { + m_MemoryCache.erase(slot); + } - const size_t outputSize = outputTensorInfo.GetNumBytes(); - const size_t bufferSize = pMemPools->at(outputArg.location.poolIndex).getHidlMemory().size(); - pOutputTensors->emplace_back(i, outputTensor); + std::tuple<V1_0::ErrorStatus, hidl_vec<OutputShape>, Timing> execute( + const V1_0::Request& request, const std::vector<int32_t>& slots, + MeasureTiming measure) override + { + ALOGV("ArmnnPreparedModel_1_2::BurstExecutorWithCache::execute"); + hidl_vec<hidl_memory> pools(slots.size()); - hidl_vec<uint32_t> dimensions; + std::transform(slots.begin(), slots.end(), pools.begin(), [this](int32_t slot) + { + return m_MemoryCache[slot]; + }); - armnn::TensorShape tensorShape = outputTensorInfo.GetShape(); - const unsigned int numDims = tensorShape.GetNumDimensions(); - dimensions.resize(numDims); + V1_0::Request fullRequest = request; + fullRequest.pools = std::move(pools); - for (unsigned int outputIdx = 0u; outputIdx < numDims; ++outputIdx) + // Setup Callback + V1_0::ErrorStatus returnedStatus = V1_0::ErrorStatus::GENERAL_FAILURE; + hidl_vec<OutputShape> returnedOutputShapes; + Timing returnedTiming; + auto cb = [&returnedStatus, &returnedOutputShapes, &returnedTiming](V1_0::ErrorStatus status, + const hidl_vec<OutputShape>& outputShapes, + const Timing& timing) { - dimensions[outputIdx] = tensorShape[outputIdx]; - } - outputShapes[i].dimensions = dimensions; - outputShapes[i].isSufficient = bufferSize >= outputSize; + returnedStatus = status; + returnedOutputShapes = outputShapes; + returnedTiming = timing; + }; - if (bufferSize < outputSize) - { - ALOGW("ArmnnPreparedModel_1_2::Execute failed"); - callback(V1_0::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE, - outputShapes, - g_NoTiming, - "ArmnnPreparedModel_1_2::Execute"); - return V1_0::ErrorStatus::NONE; - } + // Execute + ALOGV("ArmnnPreparedModel_1_2::BurstExecutorWithCache executing"); + const Return<void> ret = m_PreparedModel->executeSynchronously(fullRequest, measure, cb); + + if (!ret.isOk() || returnedStatus != V1_0::ErrorStatus::NONE) + { + ALOGE("ArmnnPreparedModel_1_2::BurstExecutorWithCache::error executing"); } + return std::make_tuple(returnedStatus, std::move(returnedOutputShapes), returnedTiming); } - catch (armnn::Exception& e) + +private: + V1_2::IPreparedModel* const m_PreparedModel; + std::map<int, hidl_memory> m_MemoryCache; +}; + +template<typename HalVersion> +Return<void> ArmnnPreparedModel_1_2<HalVersion>::configureExecutionBurst( + const sp<V1_2::IBurstCallback>& callback, + const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel, + const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel, + V1_2::IPreparedModel::configureExecutionBurst_cb cb) +{ + ALOGV("ArmnnPreparedModel_1_2::configureExecutionBurst"); + const std::shared_ptr<ArmnnBurstExecutorWithCache> executorWithCache = + std::make_shared<ArmnnBurstExecutorWithCache>(this); + const sp<V1_2::IBurstContext> burst = ExecutionBurstServer::create(callback, + requestChannel, + resultChannel, + executorWithCache); + + if (burst == nullptr) { - ALOGW("armnn::Exception caught while preparing for EnqueueWorkload: %s", e.what()); - callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::execute"); - return V1_0::ErrorStatus::GENERAL_FAILURE; + cb(V1_0::ErrorStatus::GENERAL_FAILURE, {}); } - catch (std::exception& e) + else { - ALOGE("std::exception caught while preparing for EnqueueWorkload: %s", e.what()); - callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::execute"); - return V1_0::ErrorStatus::GENERAL_FAILURE; + cb(V1_0::ErrorStatus::NONE, burst); } - - ALOGV("ArmnnPreparedModel_1_2::execute(...) before PostMsg"); - // post the request for asynchronous execution - ArmnnCallback_1_2 armnnCb; - armnnCb.callback = callback; - armnnCb.measureTiming = measureTiming; - armnnCb.driverStart = driverStart; - m_RequestThread.PostMsg(this, pMemPools, pInputTensors, pOutputTensors, armnnCb); - ALOGV("ArmnnPreparedModel_1_2::execute(...) after PostMsg"); - return V1_0::ErrorStatus::NONE; + return Void(); } + + #ifdef ARMNN_ANDROID_NN_V1_2 template class ArmnnPreparedModel_1_2<hal_1_2::HalPolicy>; +template bool ArmnnPreparedModel_1_2<hal_1_2::HalPolicy>::ExecuteGraph<CallbackContext_1_2>( + std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools, + armnn::InputTensors& pInputTensors, + armnn::OutputTensors& pOutputTensors, + CallbackContext_1_2 cb); #endif } // namespace armnn_driver diff --git a/ArmnnPreparedModel_1_2.hpp b/ArmnnPreparedModel_1_2.hpp index f609ef7e..e68614a0 100644 --- a/ArmnnPreparedModel_1_2.hpp +++ b/ArmnnPreparedModel_1_2.hpp @@ -19,18 +19,21 @@ namespace armnn_driver { -typedef std::function<void(::android::hardware::neuralnetworks::V1_0::ErrorStatus status, - std::vector<::android::hardware::neuralnetworks::V1_2::OutputShape> outputShapes, - const ::android::hardware::neuralnetworks::V1_2::Timing& timing, - std::string callingFunction)> armnnExecuteCallback_1_2; +using CallbackAsync_1_2 = std::function< + void(V1_0::ErrorStatus errorStatus, + std::vector<::android::hardware::neuralnetworks::V1_2::OutputShape> outputShapes, + const ::android::hardware::neuralnetworks::V1_2::Timing& timing, + std::string callingFunction)>; -struct ArmnnCallback_1_2 +struct ExecutionContext_1_2 { - armnnExecuteCallback_1_2 callback; + ::android::hardware::neuralnetworks::V1_2::MeasureTiming measureTimings = + ::android::hardware::neuralnetworks::V1_2::MeasureTiming::NO; TimePoint driverStart; - MeasureTiming measureTiming; }; +using CallbackContext_1_2 = CallbackContext<CallbackAsync_1_2, ExecutionContext_1_2>; + template <typename HalVersion> class ArmnnPreparedModel_1_2 : public V1_2::IPreparedModel { @@ -62,19 +65,38 @@ public: configureExecutionBurst_cb cb) override; /// execute the graph prepared from the request - void ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools, - std::shared_ptr<armnn::InputTensors>& pInputTensors, - std::shared_ptr<armnn::OutputTensors>& pOutputTensors, - ArmnnCallback_1_2 callbackDescriptor); + template<typename CallbackContext> + bool ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools, + armnn::InputTensors& inputTensors, + armnn::OutputTensors& outputTensors, + CallbackContext callback); /// Executes this model with dummy inputs (e.g. all zeroes). /// \return false on failure, otherwise true bool ExecuteWithDummyInputs(); private: - Return <V1_0::ErrorStatus> Execute(const V1_0::Request& request, - MeasureTiming measureTiming, - armnnExecuteCallback_1_2 callback); + Return<V1_0::ErrorStatus> Execute(const V1_0::Request& request, + MeasureTiming measureTiming, + CallbackAsync_1_2 callback); + + Return<V1_0::ErrorStatus> PrepareMemoryForInputs( + armnn::InputTensors& inputs, + const V1_0::Request& request, + const std::vector<android::nn::RunTimePoolInfo>& memPools); + + Return<V1_0::ErrorStatus> PrepareMemoryForOutputs( + armnn::OutputTensors& outputs, + std::vector<OutputShape> &outputShapes, + const V1_0::Request& request, + const std::vector<android::nn::RunTimePoolInfo>& memPools); + + Return <V1_0::ErrorStatus> PrepareMemoryForIO( + armnn::InputTensors& inputs, + armnn::OutputTensors& outputs, + std::vector<android::nn::RunTimePoolInfo>& memPools, + const V1_0::Request& request, + CallbackAsync_1_2 callback); template <typename TensorBindingCollection> void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings); @@ -84,7 +106,9 @@ private: V1_2::Model m_Model; // There must be a single RequestThread for all ArmnnPreparedModel objects to ensure serial execution of workloads // It is specific to this class, so it is declared as static here - static RequestThread<ArmnnPreparedModel_1_2, HalVersion, ArmnnCallback_1_2> m_RequestThread; + static RequestThread<ArmnnPreparedModel_1_2, + HalVersion, + CallbackContext_1_2> m_RequestThread; uint32_t m_RequestCount; const std::string& m_RequestInputsAndOutputsDumpDir; const bool m_GpuProfilingEnabled; diff --git a/RequestThread.cpp b/RequestThread.cpp index 052c5c11..22a3ac37 100644 --- a/RequestThread.cpp +++ b/RequestThread.cpp @@ -21,15 +21,15 @@ using namespace android; namespace armnn_driver { -template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename Callback> -RequestThread<PreparedModel, HalVersion, Callback>::RequestThread() +template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename CallbackContext> +RequestThread<PreparedModel, HalVersion, CallbackContext>::RequestThread() { ALOGV("RequestThread::RequestThread()"); m_Thread = std::make_unique<std::thread>(&RequestThread::Process, this); } -template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename Callback> -RequestThread<PreparedModel, HalVersion, Callback>::~RequestThread() +template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename CallbackContext> +RequestThread<PreparedModel, HalVersion, CallbackContext>::~RequestThread() { ALOGV("RequestThread::~RequestThread()"); @@ -54,25 +54,25 @@ RequestThread<PreparedModel, HalVersion, Callback>::~RequestThread() catch (const std::exception&) { } // Swallow any exception. } -template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename Callback> -void RequestThread<PreparedModel, HalVersion, Callback>::PostMsg(PreparedModel<HalVersion>* model, +template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename CallbackContext> +void RequestThread<PreparedModel, HalVersion, CallbackContext>::PostMsg(PreparedModel<HalVersion>* model, std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& memPools, std::shared_ptr<armnn::InputTensors>& inputTensors, std::shared_ptr<armnn::OutputTensors>& outputTensors, - Callback callback) + CallbackContext callbackContext) { ALOGV("RequestThread::PostMsg(...)"); auto data = std::make_shared<AsyncExecuteData>(model, memPools, inputTensors, outputTensors, - callback); + callbackContext); auto pMsg = std::make_shared<ThreadMsg>(ThreadMsgType::REQUEST, data); PostMsg(pMsg); } -template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename Callback> -void RequestThread<PreparedModel, HalVersion, Callback>::PostMsg(std::shared_ptr<ThreadMsg>& pMsg) +template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename CallbackContext> +void RequestThread<PreparedModel, HalVersion, CallbackContext>::PostMsg(std::shared_ptr<ThreadMsg>& pMsg) { ALOGV("RequestThread::PostMsg(pMsg)"); // Add a message to the queue and notify the request thread @@ -81,8 +81,8 @@ void RequestThread<PreparedModel, HalVersion, Callback>::PostMsg(std::shared_ptr m_Cv.notify_one(); } -template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename Callback> -void RequestThread<PreparedModel, HalVersion, Callback>::Process() +template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename CallbackContext> +void RequestThread<PreparedModel, HalVersion, CallbackContext>::Process() { ALOGV("RequestThread::Process()"); while (true) @@ -109,9 +109,9 @@ void RequestThread<PreparedModel, HalVersion, Callback>::Process() // invoke the asynchronous execution method PreparedModel<HalVersion>* model = pMsg->data->m_Model; model->ExecuteGraph(pMsg->data->m_MemPools, - pMsg->data->m_InputTensors, - pMsg->data->m_OutputTensors, - pMsg->data->m_Callback); + *(pMsg->data->m_InputTensors), + *(pMsg->data->m_OutputTensors), + pMsg->data->m_CallbackContext); break; } @@ -139,16 +139,16 @@ void RequestThread<PreparedModel, HalVersion, Callback>::Process() /// Class template specializations /// -template class RequestThread<ArmnnPreparedModel, hal_1_0::HalPolicy, ArmnnCallback_1_0>; +template class RequestThread<ArmnnPreparedModel, hal_1_0::HalPolicy, CallbackContext_1_0>; #ifdef ARMNN_ANDROID_NN_V1_1 -template class RequestThread<armnn_driver::ArmnnPreparedModel, hal_1_1::HalPolicy, ArmnnCallback_1_0>; +template class RequestThread<armnn_driver::ArmnnPreparedModel, hal_1_1::HalPolicy, CallbackContext_1_0>; #endif #ifdef ARMNN_ANDROID_NN_V1_2 -template class RequestThread<ArmnnPreparedModel, hal_1_1::HalPolicy, ArmnnCallback_1_0>; -template class RequestThread<ArmnnPreparedModel, hal_1_2::HalPolicy, ArmnnCallback_1_0>; -template class RequestThread<ArmnnPreparedModel_1_2, hal_1_2::HalPolicy, ArmnnCallback_1_2>; +template class RequestThread<ArmnnPreparedModel, hal_1_1::HalPolicy, CallbackContext_1_0>; +template class RequestThread<ArmnnPreparedModel, hal_1_2::HalPolicy, CallbackContext_1_0>; +template class RequestThread<ArmnnPreparedModel_1_2, hal_1_2::HalPolicy, CallbackContext_1_2>; #endif } // namespace armnn_driver diff --git a/RequestThread.hpp b/RequestThread.hpp index 253d104c..79f309a3 100644 --- a/RequestThread.hpp +++ b/RequestThread.hpp @@ -21,7 +21,7 @@ namespace armnn_driver using TimePoint = std::chrono::steady_clock::time_point; static const TimePoint g_Min = std::chrono::steady_clock::time_point::min(); -template<template <typename HalVersion> class PreparedModel, typename HalVersion, typename Callback> +template<template <typename HalVersion> class PreparedModel, typename HalVersion, typename CallbackContext> class RequestThread { public: @@ -41,7 +41,7 @@ public: std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& memPools, std::shared_ptr<armnn::InputTensors>& inputTensors, std::shared_ptr<armnn::OutputTensors>& outputTensors, - Callback callback); + CallbackContext callbackContext); private: RequestThread(const RequestThread&) = delete; @@ -54,12 +54,12 @@ private: std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& memPools, std::shared_ptr<armnn::InputTensors>& inputTensors, std::shared_ptr<armnn::OutputTensors>& outputTensors, - Callback callback) + CallbackContext callbackContext) : m_Model(model) , m_MemPools(memPools) , m_InputTensors(inputTensors) , m_OutputTensors(outputTensors) - , m_Callback(callback) + , m_CallbackContext(callbackContext) { } @@ -67,9 +67,8 @@ private: std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> m_MemPools; std::shared_ptr<armnn::InputTensors> m_InputTensors; std::shared_ptr<armnn::OutputTensors> m_OutputTensors; - Callback m_Callback; + CallbackContext m_CallbackContext; }; - enum class ThreadMsgType { EXIT, // exit the thread |