From dea8fb6b96663de5d3df2f9fceb9bd09432fd7aa Mon Sep 17 00:00:00 2001 From: Sadik Armagan Date: Thu, 26 Nov 2020 10:38:11 +0000 Subject: IVGCVSW-5481 'Add ClCompileContext to ClWorkloadFactory' * Introduced CLCompileContext to ClWorkloadFactory Signed-off-by: Sadik Armagan Change-Id: Ied38f4336210502e5f518b9955ae6a5ba3d242b3 --- src/backends/backendsCommon/WorkloadFactory.hpp | 2 ++ src/backends/cl/ClBackendModelContext.cpp | 29 +++++++++++++++- src/backends/cl/ClBackendModelContext.hpp | 9 +++++ src/backends/cl/ClWorkloadFactory.cpp | 44 +++++++++++++++++++++++++ src/backends/cl/ClWorkloadFactory.hpp | 7 ++++ 5 files changed, 90 insertions(+), 1 deletion(-) (limited to 'src/backends') diff --git a/src/backends/backendsCommon/WorkloadFactory.hpp b/src/backends/backendsCommon/WorkloadFactory.hpp index df08b9a81d..2e813e9945 100644 --- a/src/backends/backendsCommon/WorkloadFactory.hpp +++ b/src/backends/backendsCommon/WorkloadFactory.hpp @@ -23,6 +23,8 @@ class IWorkloadFactory public: virtual ~IWorkloadFactory() { } + virtual void AfterWorkloadsCreated() {}; + virtual const BackendId& GetBackendId() const = 0; static bool IsLayerSupported(const BackendId& backendId, diff --git a/src/backends/cl/ClBackendModelContext.cpp b/src/backends/cl/ClBackendModelContext.cpp index 0ef26b64d2..b685bc296c 100644 --- a/src/backends/cl/ClBackendModelContext.cpp +++ b/src/backends/cl/ClBackendModelContext.cpp @@ -17,13 +17,22 @@ bool ParseBool(const armnn::BackendOptions::Var& value, bool defaultValue) return defaultValue; } +std::string ParseFile(const armnn::BackendOptions::Var& value, std::string defaultValue) +{ + if (value.IsString()) + { + return value.AsString(); + } + return defaultValue; +} + } // namespace anonymous namespace armnn { ClBackendModelContext::ClBackendModelContext(const ModelOptions& modelOptions) - : m_IsFastMathEnabled(false) + : m_CachedNetworkFilePath(""), m_IsFastMathEnabled(false), m_SaveCachedNetwork(false) { if (!modelOptions.empty()) { @@ -33,13 +42,31 @@ ClBackendModelContext::ClBackendModelContext(const ModelOptions& modelOptions) { m_IsFastMathEnabled |= ParseBool(value, false); } + if (name == "SaveCachedNetwork") + { + m_SaveCachedNetwork |= ParseBool(value, false); + } + if (name == "CachedNetworkFilePath") + { + m_CachedNetworkFilePath = ParseFile(value, ""); + } }); } } +std::string ClBackendModelContext::GetCachedNetworkFilePath() const +{ + return m_CachedNetworkFilePath; +} + bool ClBackendModelContext::IsFastMathEnabled() const { return m_IsFastMathEnabled; } +bool ClBackendModelContext::SaveCachedNetwork() const +{ + return m_SaveCachedNetwork; +} + } // namespace armnn \ No newline at end of file diff --git a/src/backends/cl/ClBackendModelContext.hpp b/src/backends/cl/ClBackendModelContext.hpp index 577649aafb..c84cdbbfcf 100644 --- a/src/backends/cl/ClBackendModelContext.hpp +++ b/src/backends/cl/ClBackendModelContext.hpp @@ -6,6 +6,8 @@ #include +#include + namespace armnn { @@ -19,10 +21,17 @@ class ClBackendModelContext : public IBackendModelContext public: ClBackendModelContext(const ModelOptions& modelOptions); + std::string GetCachedNetworkFilePath() const; + bool IsFastMathEnabled() const; + bool SaveCachedNetwork() const; + private: + std::string m_CachedNetworkFilePath; bool m_IsFastMathEnabled; + bool m_SaveCachedNetwork; + }; } // namespace armnn \ No newline at end of file diff --git a/src/backends/cl/ClWorkloadFactory.cpp b/src/backends/cl/ClWorkloadFactory.cpp index c8c1fb71ec..41b779f64a 100644 --- a/src/backends/cl/ClWorkloadFactory.cpp +++ b/src/backends/cl/ClWorkloadFactory.cpp @@ -27,6 +27,8 @@ #include #include +#include + namespace armnn { @@ -55,6 +57,23 @@ const BackendId& ClWorkloadFactory::GetBackendId() const return s_Id; } +void ClWorkloadFactory::AfterWorkloadsCreated() +{ + if(m_ModelContextPtr) + { + auto modelOptions = dynamic_cast(m_ModelContextPtr.get()); + if (modelOptions->SaveCachedNetwork()) + { + // Save map to a filepath provided in ModelOptions + auto filePath = modelOptions->GetCachedNetworkFilePath(); + if (filePath != "" && fs::exists(filePath) && fs::is_regular_file(filePath)) + { + /// Saving will be implemented within IVGCVSW-5483 story. + } + } + } +} + template std::unique_ptr ClWorkloadFactory::MakeWorkload(const QueueDescriptorType& descriptor, const WorkloadInfo& info, @@ -85,15 +104,40 @@ std::unique_ptr ClWorkloadFactory::MakeWorkload(const QueueDescriptor } } +void ClWorkloadFactory::InitializeCLCompileContext() +{ + // Initialize our m_CLCompileContext using default device and context + cl::Device device = cl::Device::getDefault(); + cl::Context context = cl::Context(device); + + m_CLCompileContext = arm_compute::CLCompileContext(context, device); + + if (m_ModelContextPtr) + { + // Load saved programs if the user has set a filepath + auto modelOptions = dynamic_cast(m_ModelContextPtr.get()); + auto filePath = modelOptions->GetCachedNetworkFilePath(); + if (filePath != "" + && fs::exists(filePath) + && fs::is_regular_file(filePath) + && !(modelOptions->SaveCachedNetwork())) + { + /// Loading will be implemented within IVGCVSW-5483 story. + } + } +} + ClWorkloadFactory::ClWorkloadFactory(const std::shared_ptr& memoryManager) : m_MemoryManager(memoryManager), m_ModelContextPtr(IBackendInternal::IBackendSpecificModelContextPtr{}) { + InitializeCLCompileContext(); } ClWorkloadFactory::ClWorkloadFactory(const std::shared_ptr& memoryManager, const IBackendInternal::IBackendSpecificModelContextPtr& modelContextPtr) : m_MemoryManager(memoryManager), m_ModelContextPtr(modelContextPtr) { + InitializeCLCompileContext(); } std::unique_ptr ClWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo, diff --git a/src/backends/cl/ClWorkloadFactory.hpp b/src/backends/cl/ClWorkloadFactory.hpp index 84eae5076a..c8812cfe1b 100644 --- a/src/backends/cl/ClWorkloadFactory.hpp +++ b/src/backends/cl/ClWorkloadFactory.hpp @@ -12,6 +12,8 @@ #include #include +#include + namespace armnn { @@ -24,6 +26,8 @@ public: ClWorkloadFactory(const std::shared_ptr& memoryManager, const IBackendInternal::IBackendSpecificModelContextPtr& modelContextPtr); + void AfterWorkloadsCreated() override; + const BackendId& GetBackendId() const override; static bool IsLayerSupported(const Layer& layer, @@ -254,8 +258,11 @@ private: const WorkloadInfo& info, Args&&... args); + void InitializeCLCompileContext(); + mutable std::shared_ptr m_MemoryManager; const IBackendInternal::IBackendSpecificModelContextPtr m_ModelContextPtr; + arm_compute::CLCompileContext m_CLCompileContext; }; } // namespace armnn -- cgit v1.2.1