diff options
author | Sadik Armagan <sadik.armagan@arm.com> | 2020-11-26 10:38:11 +0000 |
---|---|---|
committer | Sadik Armagan <sadik.armagan@arm.com> | 2020-11-26 10:38:11 +0000 |
commit | dea8fb6b96663de5d3df2f9fceb9bd09432fd7aa (patch) | |
tree | d041bf8e9d406c80ce089fc5b8d84b44381332e1 /src/backends/cl | |
parent | f4f150c30d3c34e9f26757ca43e4a2694b882bce (diff) | |
download | armnn-dea8fb6b96663de5d3df2f9fceb9bd09432fd7aa.tar.gz |
IVGCVSW-5481 'Add ClCompileContext to ClWorkloadFactory'
* Introduced CLCompileContext to ClWorkloadFactory
Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
Change-Id: Ied38f4336210502e5f518b9955ae6a5ba3d242b3
Diffstat (limited to 'src/backends/cl')
-rw-r--r-- | src/backends/cl/ClBackendModelContext.cpp | 29 | ||||
-rw-r--r-- | src/backends/cl/ClBackendModelContext.hpp | 9 | ||||
-rw-r--r-- | src/backends/cl/ClWorkloadFactory.cpp | 44 | ||||
-rw-r--r-- | src/backends/cl/ClWorkloadFactory.hpp | 7 |
4 files changed, 88 insertions, 1 deletions
diff --git a/src/backends/cl/ClBackendModelContext.cpp b/src/backends/cl/ClBackendModelContext.cpp index 0ef26b64d2..b685bc296c 100644 --- a/src/backends/cl/ClBackendModelContext.cpp +++ b/src/backends/cl/ClBackendModelContext.cpp @@ -17,13 +17,22 @@ bool ParseBool(const armnn::BackendOptions::Var& value, bool defaultValue) return defaultValue; } +std::string ParseFile(const armnn::BackendOptions::Var& value, std::string defaultValue) +{ + if (value.IsString()) + { + return value.AsString(); + } + return defaultValue; +} + } // namespace anonymous namespace armnn { ClBackendModelContext::ClBackendModelContext(const ModelOptions& modelOptions) - : m_IsFastMathEnabled(false) + : m_CachedNetworkFilePath(""), m_IsFastMathEnabled(false), m_SaveCachedNetwork(false) { if (!modelOptions.empty()) { @@ -33,13 +42,31 @@ ClBackendModelContext::ClBackendModelContext(const ModelOptions& modelOptions) { m_IsFastMathEnabled |= ParseBool(value, false); } + if (name == "SaveCachedNetwork") + { + m_SaveCachedNetwork |= ParseBool(value, false); + } + if (name == "CachedNetworkFilePath") + { + m_CachedNetworkFilePath = ParseFile(value, ""); + } }); } } +std::string ClBackendModelContext::GetCachedNetworkFilePath() const +{ + return m_CachedNetworkFilePath; +} + bool ClBackendModelContext::IsFastMathEnabled() const { return m_IsFastMathEnabled; } +bool ClBackendModelContext::SaveCachedNetwork() const +{ + return m_SaveCachedNetwork; +} + } // namespace armnn
\ No newline at end of file diff --git a/src/backends/cl/ClBackendModelContext.hpp b/src/backends/cl/ClBackendModelContext.hpp index 577649aafb..c84cdbbfcf 100644 --- a/src/backends/cl/ClBackendModelContext.hpp +++ b/src/backends/cl/ClBackendModelContext.hpp @@ -6,6 +6,8 @@ #include <armnn/backends/IBackendContext.hpp> +#include<string> + namespace armnn { @@ -19,10 +21,17 @@ class ClBackendModelContext : public IBackendModelContext public: ClBackendModelContext(const ModelOptions& modelOptions); + std::string GetCachedNetworkFilePath() const; + bool IsFastMathEnabled() const; + bool SaveCachedNetwork() const; + private: + std::string m_CachedNetworkFilePath; bool m_IsFastMathEnabled; + bool m_SaveCachedNetwork; + }; } // namespace armnn
\ No newline at end of file diff --git a/src/backends/cl/ClWorkloadFactory.cpp b/src/backends/cl/ClWorkloadFactory.cpp index c8c1fb71ec..41b779f64a 100644 --- a/src/backends/cl/ClWorkloadFactory.cpp +++ b/src/backends/cl/ClWorkloadFactory.cpp @@ -27,6 +27,8 @@ #include <arm_compute/runtime/CL/CLBufferAllocator.h> #include <arm_compute/runtime/CL/CLScheduler.h> +#include <Filesystem.hpp> + namespace armnn { @@ -55,6 +57,23 @@ const BackendId& ClWorkloadFactory::GetBackendId() const return s_Id; } +void ClWorkloadFactory::AfterWorkloadsCreated() +{ + if(m_ModelContextPtr) + { + auto modelOptions = dynamic_cast<ClBackendModelContext*>(m_ModelContextPtr.get()); + if (modelOptions->SaveCachedNetwork()) + { + // Save map to a filepath provided in ModelOptions + auto filePath = modelOptions->GetCachedNetworkFilePath(); + if (filePath != "" && fs::exists(filePath) && fs::is_regular_file(filePath)) + { + /// Saving will be implemented within IVGCVSW-5483 story. + } + } + } +} + template <typename FloatWorkload, typename Uint8Workload, typename QueueDescriptorType, typename... Args> std::unique_ptr<IWorkload> ClWorkloadFactory::MakeWorkload(const QueueDescriptorType& descriptor, const WorkloadInfo& info, @@ -85,15 +104,40 @@ std::unique_ptr<IWorkload> ClWorkloadFactory::MakeWorkload(const QueueDescriptor } } +void ClWorkloadFactory::InitializeCLCompileContext() +{ + // Initialize our m_CLCompileContext using default device and context + cl::Device device = cl::Device::getDefault(); + cl::Context context = cl::Context(device); + + m_CLCompileContext = arm_compute::CLCompileContext(context, device); + + if (m_ModelContextPtr) + { + // Load saved programs if the user has set a filepath + auto modelOptions = dynamic_cast<ClBackendModelContext*>(m_ModelContextPtr.get()); + auto filePath = modelOptions->GetCachedNetworkFilePath(); + if (filePath != "" + && fs::exists(filePath) + && fs::is_regular_file(filePath) + && !(modelOptions->SaveCachedNetwork())) + { + /// Loading will be implemented within IVGCVSW-5483 story. + } + } +} + ClWorkloadFactory::ClWorkloadFactory(const std::shared_ptr<ClMemoryManager>& memoryManager) : m_MemoryManager(memoryManager), m_ModelContextPtr(IBackendInternal::IBackendSpecificModelContextPtr{}) { + InitializeCLCompileContext(); } ClWorkloadFactory::ClWorkloadFactory(const std::shared_ptr<ClMemoryManager>& memoryManager, const IBackendInternal::IBackendSpecificModelContextPtr& modelContextPtr) : m_MemoryManager(memoryManager), m_ModelContextPtr(modelContextPtr) { + InitializeCLCompileContext(); } std::unique_ptr<ITensorHandle> ClWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo, diff --git a/src/backends/cl/ClWorkloadFactory.hpp b/src/backends/cl/ClWorkloadFactory.hpp index 84eae5076a..c8812cfe1b 100644 --- a/src/backends/cl/ClWorkloadFactory.hpp +++ b/src/backends/cl/ClWorkloadFactory.hpp @@ -12,6 +12,8 @@ #include <backendsCommon/WorkloadFactoryBase.hpp> #include <aclCommon/BaseMemoryManager.hpp> +#include <arm_compute/core/CL/CLCompileContext.h> + namespace armnn { @@ -24,6 +26,8 @@ public: ClWorkloadFactory(const std::shared_ptr<ClMemoryManager>& memoryManager, const IBackendInternal::IBackendSpecificModelContextPtr& modelContextPtr); + void AfterWorkloadsCreated() override; + const BackendId& GetBackendId() const override; static bool IsLayerSupported(const Layer& layer, @@ -254,8 +258,11 @@ private: const WorkloadInfo& info, Args&&... args); + void InitializeCLCompileContext(); + mutable std::shared_ptr<ClMemoryManager> m_MemoryManager; const IBackendInternal::IBackendSpecificModelContextPtr m_ModelContextPtr; + arm_compute::CLCompileContext m_CLCompileContext; }; } // namespace armnn |