From 1b61be517387a20cd869e30587de2140b6d2252d Mon Sep 17 00:00:00 2001 From: David Beck Date: Thu, 8 Nov 2018 09:19:14 +0000 Subject: IVGCVSW-2056+IVGCVSW-2064 : move ClContextControl to the ClBackend * add IBackendContext interface * add ClBackendContext implementation Change-Id: I13e4d12b73d4c7775069587675276f7cee7d630b --- src/armnn/Runtime.cpp | 82 +++++++++------ src/armnn/Runtime.hpp | 5 +- src/backends/backendsCommon/BackendRegistry.hpp | 7 +- src/backends/backendsCommon/CMakeLists.txt | 1 + src/backends/backendsCommon/IBackendContext.hpp | 32 ++++++ src/backends/backendsCommon/IBackendInternal.hpp | 5 + .../backendsCommon/test/BackendRegistryTests.cpp | 1 + src/backends/cl/CMakeLists.txt | 2 + src/backends/cl/ClBackend.cpp | 8 ++ src/backends/cl/ClBackend.hpp | 4 +- src/backends/cl/ClBackendContext.cpp | 113 +++++++++++++++++++++ src/backends/cl/ClBackendContext.hpp | 36 +++++++ src/backends/cl/backend.mk | 1 + src/backends/neon/NeonBackend.hpp | 6 ++ src/backends/reference/RefBackend.hpp | 8 +- tests/CMakeLists.txt | 12 +-- tests/InferenceModel.hpp | 2 +- 17 files changed, 275 insertions(+), 50 deletions(-) create mode 100644 src/backends/backendsCommon/IBackendContext.hpp create mode 100644 src/backends/cl/ClBackendContext.cpp create mode 100644 src/backends/cl/ClBackendContext.hpp diff --git a/src/armnn/Runtime.cpp b/src/armnn/Runtime.cpp index 37e25a7fb6..09be92c709 100644 --- a/src/armnn/Runtime.cpp +++ b/src/armnn/Runtime.cpp @@ -6,15 +6,10 @@ #include #include +#include #include -#ifdef ARMCOMPUTECL_ENABLED -#include -#include -#include -#endif - #include #include @@ -55,6 +50,14 @@ Status Runtime::LoadNetwork(NetworkId& networkIdOut, std::string & errorMessage) { IOptimizedNetwork* rawNetwork = inNetwork.release(); + + networkIdOut = GenerateNetworkId(); + + for (auto&& context : m_BackendContexts) + { + context.second->BeforeLoadNetwork(networkIdOut); + } + unique_ptr loadedNetwork = LoadedNetwork::MakeLoadedNetwork( std::unique_ptr(boost::polymorphic_downcast(rawNetwork)), errorMessage); @@ -64,8 +67,6 @@ Status Runtime::LoadNetwork(NetworkId& networkIdOut, return Status::Failure; } - networkIdOut = GenerateNetworkId(); - { std::lock_guard lockGuard(m_Mutex); @@ -73,28 +74,28 @@ Status Runtime::LoadNetwork(NetworkId& networkIdOut, m_LoadedNetworks[networkIdOut] = std::move(loadedNetwork); } + for (auto&& context : m_BackendContexts) + { + context.second->AfterLoadNetwork(networkIdOut); + } + return Status::Success; } Status Runtime::UnloadNetwork(NetworkId networkId) { -#ifdef ARMCOMPUTECL_ENABLED - if (arm_compute::CLScheduler::get().context()() != NULL) + bool unloadOk = true; + for (auto&& context : m_BackendContexts) { - // Waits for all queued CL requests to finish before unloading the network they may be using. - try - { - // Coverity fix: arm_compute::CLScheduler::sync() may throw an exception of type cl::Error. - arm_compute::CLScheduler::get().sync(); - } - catch (const cl::Error&) - { - BOOST_LOG_TRIVIAL(warning) << "WARNING: Runtime::UnloadNetwork(): an error occurred while waiting for " - "the queued CL requests to finish"; - return Status::Failure; - } + unloadOk &= context.second->BeforeUnloadNetwork(networkId); + } + + if (!unloadOk) + { + BOOST_LOG_TRIVIAL(warning) << "Runtime::UnloadNetwork(): failed to unload " + "network with ID:" << networkId << " because BeforeUnloadNetwork failed"; + return Status::Failure; } -#endif { std::lock_guard lockGuard(m_Mutex); @@ -104,14 +105,11 @@ Status Runtime::UnloadNetwork(NetworkId networkId) BOOST_LOG_TRIVIAL(warning) << "WARNING: Runtime::UnloadNetwork(): " << networkId << " not found!"; return Status::Failure; } + } -#ifdef ARMCOMPUTECL_ENABLED - if (arm_compute::CLScheduler::get().context()() != NULL && m_LoadedNetworks.empty()) - { - // There are no loaded networks left, so clear the CL cache to free up memory - m_ClContextControl.ClearClCache(); - } -#endif + for (auto&& context : m_BackendContexts) + { + context.second->AfterUnloadNetwork(networkId); } BOOST_LOG_TRIVIAL(debug) << "Runtime::UnloadNetwork(): Unloaded network with ID: " << networkId; @@ -131,12 +129,30 @@ const std::shared_ptr Runtime::GetProfiler(NetworkId networkId) const } Runtime::Runtime(const CreationOptions& options) - : m_ClContextControl(options.m_GpuAccTunedParameters.get(), - options.m_EnableGpuProfiling) - , m_NetworkIdCounter(0) + : m_NetworkIdCounter(0) , m_DeviceSpec{BackendRegistryInstance().GetBackendIds()} { BOOST_LOG_TRIVIAL(info) << "ArmNN v" << ARMNN_VERSION << "\n"; + + for (const auto& id : BackendRegistryInstance().GetBackendIds()) + { + // Store backend contexts for the supported ones + if (m_DeviceSpec.GetSupportedBackends().count(id) > 0) + { + auto factoryFun = BackendRegistryInstance().GetFactory(id); + auto backend = factoryFun(); + BOOST_ASSERT(backend.get() != nullptr); + + auto context = backend->CreateBackendContext(options); + + // backends are allowed to return nullptrs if they + // don't wish to create a backend specific context + if (context) + { + m_BackendContexts.emplace(std::make_pair(id, std::move(context))); + } + } + } } Runtime::~Runtime() diff --git a/src/armnn/Runtime.hpp b/src/armnn/Runtime.hpp index e4d4d4ddb9..a3f4a3930b 100644 --- a/src/armnn/Runtime.hpp +++ b/src/armnn/Runtime.hpp @@ -11,8 +11,6 @@ #include #include -#include - #include #include @@ -89,8 +87,7 @@ private: mutable std::mutex m_Mutex; std::unordered_map> m_LoadedNetworks; - - ClContextControl m_ClContextControl; + std::unordered_map m_BackendContexts; int m_NetworkIdCounter; diff --git a/src/backends/backendsCommon/BackendRegistry.hpp b/src/backends/backendsCommon/BackendRegistry.hpp index 145da8819c..4b20cacbe0 100644 --- a/src/backends/backendsCommon/BackendRegistry.hpp +++ b/src/backends/backendsCommon/BackendRegistry.hpp @@ -4,14 +4,13 @@ // #pragma once -#include "IBackendInternal.hpp" #include "RegistryCommon.hpp" - #include namespace armnn { - +class IBackendInternal; +using IBackendInternalUniquePtr = std::unique_ptr; using BackendRegistry = RegistryCommon; BackendRegistry& BackendRegistryInstance(); @@ -22,4 +21,4 @@ struct RegisteredTypeName static const char * Name() { return "IBackend"; } }; -} // namespace armnn +} // namespace armnn \ No newline at end of file diff --git a/src/backends/backendsCommon/CMakeLists.txt b/src/backends/backendsCommon/CMakeLists.txt index 9dd9b92fe3..f4ab45f8b4 100644 --- a/src/backends/backendsCommon/CMakeLists.txt +++ b/src/backends/backendsCommon/CMakeLists.txt @@ -10,6 +10,7 @@ list(APPEND armnnBackendsCommon_sources CpuTensorHandleFwd.hpp CpuTensorHandle.hpp IBackendInternal.hpp + IBackendContext.hpp ILayerSupport.cpp ITensorHandle.hpp LayerSupportRegistry.cpp diff --git a/src/backends/backendsCommon/IBackendContext.hpp b/src/backends/backendsCommon/IBackendContext.hpp new file mode 100644 index 0000000000..de9824956f --- /dev/null +++ b/src/backends/backendsCommon/IBackendContext.hpp @@ -0,0 +1,32 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// +#pragma once + +#include +#include + +namespace armnn +{ + +class IBackendContext +{ +protected: + IBackendContext(const IRuntime::CreationOptions&) {} + +public: + // Before and after Load network events + virtual bool BeforeLoadNetwork(NetworkId networkId) = 0; + virtual bool AfterLoadNetwork(NetworkId networkId) = 0; + + // Before and after Unload network events + virtual bool BeforeUnloadNetwork(NetworkId networkId) = 0; + virtual bool AfterUnloadNetwork(NetworkId networkId) = 0; + + virtual ~IBackendContext() {} +}; + +using IBackendContextUniquePtr = std::unique_ptr; + +} // namespace armnn \ No newline at end of file diff --git a/src/backends/backendsCommon/IBackendInternal.hpp b/src/backends/backendsCommon/IBackendInternal.hpp index 7e44dbd676..fede366475 100644 --- a/src/backends/backendsCommon/IBackendInternal.hpp +++ b/src/backends/backendsCommon/IBackendInternal.hpp @@ -5,10 +5,12 @@ #pragma once #include +#include namespace armnn { class IWorkloadFactory; +class IBackendContext; class IBackendInternal : public IBackend { @@ -23,7 +25,10 @@ public: ~IBackendInternal() override = default; using IWorkloadFactoryPtr = std::unique_ptr; + using IBackendContextPtr = std::unique_ptr; + virtual IWorkloadFactoryPtr CreateWorkloadFactory() const = 0; + virtual IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const = 0; }; using IBackendInternalUniquePtr = std::unique_ptr; diff --git a/src/backends/backendsCommon/test/BackendRegistryTests.cpp b/src/backends/backendsCommon/test/BackendRegistryTests.cpp index 0bc655be09..26175e015f 100644 --- a/src/backends/backendsCommon/test/BackendRegistryTests.cpp +++ b/src/backends/backendsCommon/test/BackendRegistryTests.cpp @@ -6,6 +6,7 @@ #include #include +#include #include diff --git a/src/backends/cl/CMakeLists.txt b/src/backends/cl/CMakeLists.txt index d751854c92..dd2a4a12b1 100644 --- a/src/backends/cl/CMakeLists.txt +++ b/src/backends/cl/CMakeLists.txt @@ -7,6 +7,8 @@ if(ARMCOMPUTECL) list(APPEND armnnClBackend_sources ClBackend.cpp ClBackend.hpp + ClBackendContext.cpp + ClBackendContext.hpp ClBackendId.hpp ClContextControl.cpp ClContextControl.hpp diff --git a/src/backends/cl/ClBackend.cpp b/src/backends/cl/ClBackend.cpp index b1857a3678..c07fa66457 100644 --- a/src/backends/cl/ClBackend.cpp +++ b/src/backends/cl/ClBackend.cpp @@ -6,6 +6,7 @@ #include "ClBackend.hpp" #include "ClBackendId.hpp" #include "ClWorkloadFactory.hpp" +#include "ClBackendContext.hpp" #include @@ -38,4 +39,11 @@ IBackendInternal::IWorkloadFactoryPtr ClBackend::CreateWorkloadFactory() const return std::make_unique(); } +IBackendInternal::IBackendContextPtr +ClBackend::CreateBackendContext(const IRuntime::CreationOptions& options) const +{ + return IBackendContextPtr{new ClBackendContext{options}}; +} + + } // namespace armnn diff --git a/src/backends/cl/ClBackend.hpp b/src/backends/cl/ClBackend.hpp index 223aeb3095..f8a6253c22 100644 --- a/src/backends/cl/ClBackend.hpp +++ b/src/backends/cl/ClBackend.hpp @@ -4,6 +4,7 @@ // #pragma once +#include #include namespace armnn @@ -18,7 +19,8 @@ public: static const BackendId& GetIdStatic(); const BackendId& GetId() const override { return GetIdStatic(); } - IWorkloadFactoryPtr CreateWorkloadFactory() const override; + IBackendInternal::IWorkloadFactoryPtr CreateWorkloadFactory() const override; + IBackendInternal::IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const override; }; } // namespace armnn \ No newline at end of file diff --git a/src/backends/cl/ClBackendContext.cpp b/src/backends/cl/ClBackendContext.cpp new file mode 100644 index 0000000000..a2c1b87359 --- /dev/null +++ b/src/backends/cl/ClBackendContext.cpp @@ -0,0 +1,113 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "ClBackendContext.hpp" +#include "ClContextControl.hpp" + +#include + +#ifdef ARMCOMPUTECL_ENABLED +#include +#include +#include +#endif + +namespace armnn +{ + +struct ClBackendContext::ClContextControlWrapper +{ + ClContextControlWrapper(IGpuAccTunedParameters* clTunedParameters, + bool profilingEnabled) + : m_ClContextControl(clTunedParameters, profilingEnabled) + {} + + bool Sync() + { +#ifdef ARMCOMPUTECL_ENABLED + if (arm_compute::CLScheduler::get().context()() != NULL) + { + // Waits for all queued CL requests to finish before unloading the network they may be using. + try + { + // Coverity fix: arm_compute::CLScheduler::sync() may throw an exception of type cl::Error. + arm_compute::CLScheduler::get().sync(); + } + catch (const cl::Error&) + { + BOOST_LOG_TRIVIAL(warning) << "WARNING: Runtime::UnloadNetwork(): an error occurred while waiting for " + "the queued CL requests to finish"; + return false; + } + } +#endif + return true; + } + + void ClearClCache() + { +#ifdef ARMCOMPUTECL_ENABLED + if (arm_compute::CLScheduler::get().context()() != NULL) + { + // There are no loaded networks left, so clear the CL cache to free up memory + m_ClContextControl.ClearClCache(); + } +#endif + } + + + ClContextControl m_ClContextControl; +}; + + +ClBackendContext::ClBackendContext(const IRuntime::CreationOptions& options) + : IBackendContext(options) + , m_ClContextControlWrapper( + std::make_unique(options.m_GpuAccTunedParameters.get(), + options.m_EnableGpuProfiling)) +{ +} + +bool ClBackendContext::BeforeLoadNetwork(NetworkId) +{ + return true; +} + +bool ClBackendContext::AfterLoadNetwork(NetworkId networkId) +{ + { + std::lock_guard lockGuard(m_Mutex); + m_NetworkIds.insert(networkId); + } + return true; +} + +bool ClBackendContext::BeforeUnloadNetwork(NetworkId) +{ + return m_ClContextControlWrapper->Sync(); +} + +bool ClBackendContext::AfterUnloadNetwork(NetworkId networkId) +{ + bool clearCache = false; + { + std::lock_guard lockGuard(m_Mutex); + m_NetworkIds.erase(networkId); + clearCache = m_NetworkIds.empty(); + } + + if (clearCache) + { + m_ClContextControlWrapper->ClearClCache(); + } + + return true; +} + +ClBackendContext::~ClBackendContext() +{ +} + +} // namespace armnn \ No newline at end of file diff --git a/src/backends/cl/ClBackendContext.hpp b/src/backends/cl/ClBackendContext.hpp new file mode 100644 index 0000000000..24497c2249 --- /dev/null +++ b/src/backends/cl/ClBackendContext.hpp @@ -0,0 +1,36 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// +#pragma once + +#include +#include +#include + +namespace armnn +{ + +class ClBackendContext : public IBackendContext +{ +public: + ClBackendContext(const IRuntime::CreationOptions& options); + + bool BeforeLoadNetwork(NetworkId networkId) override; + bool AfterLoadNetwork(NetworkId networkId) override; + + bool BeforeUnloadNetwork(NetworkId networkId) override; + bool AfterUnloadNetwork(NetworkId networkId) override; + + ~ClBackendContext() override; + +private: + std::mutex m_Mutex; + struct ClContextControlWrapper; + std::unique_ptr m_ClContextControlWrapper; + + std::unordered_set m_NetworkIds; + +}; + +} // namespace armnn \ No newline at end of file diff --git a/src/backends/cl/backend.mk b/src/backends/cl/backend.mk index 97df8e4903..fd7ea80f33 100644 --- a/src/backends/cl/backend.mk +++ b/src/backends/cl/backend.mk @@ -9,6 +9,7 @@ BACKEND_SOURCES := \ ClBackend.cpp \ + ClBackendContext.cpp \ ClContextControl.cpp \ ClLayerSupport.cpp \ ClWorkloadFactory.cpp \ diff --git a/src/backends/neon/NeonBackend.hpp b/src/backends/neon/NeonBackend.hpp index b8bbd781a4..9ee8b238b3 100644 --- a/src/backends/neon/NeonBackend.hpp +++ b/src/backends/neon/NeonBackend.hpp @@ -4,6 +4,7 @@ // #pragma once +#include #include namespace armnn @@ -19,6 +20,11 @@ public: const BackendId& GetId() const override { return GetIdStatic(); } IWorkloadFactoryPtr CreateWorkloadFactory() const override; + + IBackendInternal::IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const override + { + return IBackendContextPtr{}; + } }; } // namespace armnn \ No newline at end of file diff --git a/src/backends/reference/RefBackend.hpp b/src/backends/reference/RefBackend.hpp index 48a9d529d5..025a4826b2 100644 --- a/src/backends/reference/RefBackend.hpp +++ b/src/backends/reference/RefBackend.hpp @@ -4,6 +4,7 @@ // #pragma once +#include #include namespace armnn @@ -18,7 +19,12 @@ public: static const BackendId& GetIdStatic(); const BackendId& GetId() const override { return GetIdStatic(); } - IWorkloadFactoryPtr CreateWorkloadFactory() const override; + IBackendInternal::IWorkloadFactoryPtr CreateWorkloadFactory() const override; + + IBackendInternal::IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const override + { + return IBackendContextPtr{}; + } }; } // namespace armnn \ No newline at end of file diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index d6475c263b..981553702e 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -14,13 +14,13 @@ set(inference_test_sources InferenceTestImage.cpp) add_library_ex(inferenceTest STATIC ${inference_test_sources}) target_include_directories(inferenceTest PRIVATE ../src/armnnUtils) -target_include_directories(inferenceTest PRIVATE ../src) +target_include_directories(inferenceTest PRIVATE ../src/backends) if(BUILD_CAFFE_PARSER) macro(CaffeParserTest testName sources) add_executable_ex(${testName} ${sources}) target_include_directories(${testName} PRIVATE ../src/armnnUtils) - target_include_directories(${testName} PRIVATE ../src) + target_include_directories(${testName} PRIVATE ../src/backends) set_target_properties(${testName} PROPERTIES COMPILE_FLAGS "${CAFFE_PARSER_TEST_ADDITIONAL_COMPILE_FLAGS}") target_link_libraries(${testName} inferenceTest) @@ -91,7 +91,7 @@ if(BUILD_TF_PARSER) macro(TfParserTest testName sources) add_executable_ex(${testName} ${sources}) target_include_directories(${testName} PRIVATE ../src/armnnUtils) - target_include_directories(${testName} PRIVATE ../src) + target_include_directories(${testName} PRIVATE ../src/backends) target_link_libraries(${testName} inferenceTest) target_link_libraries(${testName} armnnTfParser) @@ -142,7 +142,7 @@ if (BUILD_TF_LITE_PARSER) macro(TfLiteParserTest testName sources) add_executable_ex(${testName} ${sources}) target_include_directories(${testName} PRIVATE ../src/armnnUtils) - target_include_directories(${testName} PRIVATE ../src) + target_include_directories(${testName} PRIVATE ../src/backends) target_link_libraries(${testName} inferenceTest) target_link_libraries(${testName} armnnTfLiteParser) @@ -175,7 +175,7 @@ if (BUILD_ONNX_PARSER) macro(OnnxParserTest testName sources) add_executable_ex(${testName} ${sources}) target_include_directories(${testName} PRIVATE ../src/armnnUtils) - target_include_directories(${testName} PRIVATE ../src) + target_include_directories(${testName} PRIVATE ../src/backends) target_link_libraries(${testName} inferenceTest) target_link_libraries(${testName} armnnOnnxParser) @@ -211,7 +211,7 @@ if (BUILD_CAFFE_PARSER OR BUILD_TF_PARSER OR BUILD_TF_LITE_PARSER OR BUILD_ONNX_ add_executable_ex(ExecuteNetwork ${ExecuteNetwork_sources}) target_include_directories(ExecuteNetwork PRIVATE ../src/armnn) target_include_directories(ExecuteNetwork PRIVATE ../src/armnnUtils) - target_include_directories(ExecuteNetwork PRIVATE ../src) + target_include_directories(ExecuteNetwork PRIVATE ../src/backends) if (BUILD_CAFFE_PARSER) target_link_libraries(ExecuteNetwork armnnCaffeParser) diff --git a/tests/InferenceModel.hpp b/tests/InferenceModel.hpp index 9ca7dfd683..5fefd05619 100644 --- a/tests/InferenceModel.hpp +++ b/tests/InferenceModel.hpp @@ -14,7 +14,7 @@ #include #endif -#include +#include #include #include -- cgit v1.2.1