Diffstat (limited to 'tests/ExecuteNetwork')
-rw-r--r--  tests/ExecuteNetwork/ExecuteNetwork.cpp | 27
1 file changed, 13 insertions(+), 14 deletions(-)
diff --git a/tests/ExecuteNetwork/ExecuteNetwork.cpp b/tests/ExecuteNetwork/ExecuteNetwork.cpp
index e8d5b1860c..48577c9990 100644
--- a/tests/ExecuteNetwork/ExecuteNetwork.cpp
+++ b/tests/ExecuteNetwork/ExecuteNetwork.cpp
@@ -445,14 +445,8 @@ int MainImpl(const ExecuteNetworkParams& params,
     try
     {
         ARMNN_LOG(info) << "Asynchronous execution with Arm NN thread pool... \n";
-        std::vector<armnn::experimental::IAsyncExecutionCallbackPtr> callbacks;
-
-        // Create callbacks that will be checked post scheduling
-        for (size_t i = 0; i < params.m_SimultaneousIterations; ++i)
-        {
-            // Point to ArmNN example implementation of AsyncExecutionCallback
-            callbacks.emplace_back(std::make_shared<armnn::experimental::AsyncExecutionCallback>());
-        }
+        armnn::AsyncCallbackManager callbackManager;
+        std::unordered_map<armnn::InferenceId, std::vector<TContainer>&> inferenceOutputMap;
 
         // Declare the latest and earliest inference times here to be used when calculating overall time
         std::chrono::high_resolution_clock::time_point earliestStartTime;
@@ -461,15 +455,19 @@ int MainImpl(const ExecuteNetworkParams& params,
 
         // For the asynchronous execution, we are adding a pool of working memory handles (1 per thread) in the
         // LoadedNetwork with each scheduled inference having a specific priority
-        for (size_t i = 0; i < callbacks.size(); ++i)
+        for (size_t i = 0; i < params.m_SimultaneousIterations; ++i)
         {
-            model.RunAsync(inputs[i], outputs[i], callbacks[i]);
+            std::shared_ptr<armnn::AsyncExecutionCallback> cb = callbackManager.GetNewCallback();
+            inferenceOutputMap.insert({cb->GetInferenceId(), outputs[i]});
+            model.RunAsync(inputs[i], outputs[i], cb);
         }
 
         // Check the results
         unsigned int j = 0;
-        for (armnn::experimental::IAsyncExecutionCallbackPtr cb : callbacks)
+        for (size_t iteration = 0; iteration < params.m_SimultaneousIterations; ++iteration)
         {
+            auto cb = callbackManager.GetNotifiedCallback();
+
             // Get the results
             auto endTime = time_point_cast<std::chrono::milliseconds>(cb->GetEndTime());
             auto startTime = time_point_cast<std::chrono::milliseconds>(cb->GetStartTime());
@@ -507,7 +505,7 @@ int MainImpl(const ExecuteNetworkParams& params,
                                            infoOut,
                                            outputTensorFile,
                                            params.m_DequantizeOutput);
-                mapbox::util::apply_visitor(printer, outputs[j][i]);
+                mapbox::util::apply_visitor(printer, inferenceOutputMap.at(cb->GetInferenceId())[i]);
             }
 
             ARMNN_LOG(info) << "\nInference time: " << std::setprecision(2)
@@ -549,7 +547,7 @@ int MainImpl(const ExecuteNetworkParams& params,
     try
     {
         ARMNN_LOG(info) << "Asynchronous Execution with std::launch:async... \n";
-        std::vector<std::future<std::tuple<armnn::profiling::ProfilingGuid,
+        std::vector<std::future<std::tuple<unsigned int,
             std::chrono::duration<double, std::milli>>>> inferenceResults;
         inferenceResults.reserve(params.m_SimultaneousIterations);
 
@@ -567,9 +565,10 @@ int MainImpl(const ExecuteNetworkParams& params,
         for (unsigned int i = 0; i < params.m_SimultaneousIterations; ++i)
         {
             armnn::experimental::IWorkingMemHandle& workingMemHandleRef = *workingMemHandles[i].get();
+
             inferenceResults.push_back(std::async(
                 std::launch::async, [&model, &workingMemHandleRef, &inputs, &outputs, i]() {
-                    return model.RunAsync(workingMemHandleRef, inputs[i], outputs[i]);
+                    return model.RunAsync(workingMemHandleRef, inputs[i], outputs[i], i);
                 }
             ));
         }
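
For review context: the first two hunks replace a pre-built vector of callbacks, drained in scheduling order, with a callback manager that hands callbacks back as inferences finish. Because completion order need not match submission order, each output buffer is now looked up through inferenceOutputMap by inference id instead of by loop index. Below is a minimal, self-contained sketch of that producer/consumer pattern, assuming GetNotifiedCallback() blocks until a worker signals completion, as the results loop in the diff implies. The CallbackManager, Callback, and InferenceId names here are simplified stand-ins written for this note, not the actual armnn::experimental types.

// Minimal sketch of the notified-callback pattern adopted by the diff.
// All names are illustrative stand-ins, not the real Arm NN classes.
#include <condition_variable>
#include <cstdint>
#include <memory>
#include <mutex>
#include <queue>

using InferenceId = std::uint64_t;

class Callback
{
public:
    explicit Callback(InferenceId id) : m_Id(id) {}
    InferenceId GetInferenceId() const { return m_Id; }
private:
    InferenceId m_Id;
};

class CallbackManager
{
public:
    // Hand out a fresh callback with a unique inference id; the caller
    // passes it to the asynchronous run and keys its output buffer by id.
    std::shared_ptr<Callback> GetNewCallback()
    {
        return std::make_shared<Callback>(m_NextId++);
    }

    // Invoked by a worker thread when its inference completes.
    void Notify(std::shared_ptr<Callback> cb)
    {
        {
            std::lock_guard<std::mutex> lock(m_Mutex);
            m_Notified.push(std::move(cb));
        }
        m_Condition.notify_one();
    }

    // Block until some inference has finished and return its callback.
    // Callbacks come back in completion order, not submission order,
    // which is why outputs must be keyed by inference id rather than
    // indexed through a pre-built callbacks vector.
    std::shared_ptr<Callback> GetNotifiedCallback()
    {
        std::unique_lock<std::mutex> lock(m_Mutex);
        m_Condition.wait(lock, [this] { return !m_Notified.empty(); });
        std::shared_ptr<Callback> cb = std::move(m_Notified.front());
        m_Notified.pop();
        return cb;
    }

private:
    std::mutex m_Mutex;
    std::condition_variable m_Condition;
    std::queue<std::shared_ptr<Callback>> m_Notified;
    InferenceId m_NextId = 0;
};

Under this pattern the results loop wakes exactly once per finished inference regardless of the order in which the thread pool completes them, which also explains the change in the final hunk: the std::launch::async path now returns the iteration index i from RunAsync so each future's result can be matched to its inputs without relying on completion order.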