diff options
Diffstat (limited to 'src/backends')
-rw-r--r-- | src/backends/backendsCommon/WorkloadFactory.hpp | 2 | ||||
-rw-r--r-- | src/backends/cl/ClBackendModelContext.cpp | 29 | ||||
-rw-r--r-- | src/backends/cl/ClBackendModelContext.hpp | 9 | ||||
-rw-r--r-- | src/backends/cl/ClWorkloadFactory.cpp | 44 | ||||
-rw-r--r-- | src/backends/cl/ClWorkloadFactory.hpp | 7 |
5 files changed, 90 insertions, 1 deletion
diff --git a/src/backends/backendsCommon/WorkloadFactory.hpp b/src/backends/backendsCommon/WorkloadFactory.hpp index df08b9a81d..2e813e9945 100644 --- a/src/backends/backendsCommon/WorkloadFactory.hpp +++ b/src/backends/backendsCommon/WorkloadFactory.hpp @@ -23,6 +23,8 @@ class IWorkloadFactory public: virtual ~IWorkloadFactory() { } + virtual void AfterWorkloadsCreated() {}; + virtual const BackendId& GetBackendId() const = 0; static bool IsLayerSupported(const BackendId& backendId, diff --git a/src/backends/cl/ClBackendModelContext.cpp b/src/backends/cl/ClBackendModelContext.cpp index 0ef26b64d2..b685bc296c 100644 --- a/src/backends/cl/ClBackendModelContext.cpp +++ b/src/backends/cl/ClBackendModelContext.cpp @@ -17,13 +17,22 @@ bool ParseBool(const armnn::BackendOptions::Var& value, bool defaultValue) return defaultValue; } +std::string ParseFile(const armnn::BackendOptions::Var& value, std::string defaultValue) +{ + if (value.IsString()) + { + return value.AsString(); + } + return defaultValue; +} + } // namespace anonymous namespace armnn { ClBackendModelContext::ClBackendModelContext(const ModelOptions& modelOptions) - : m_IsFastMathEnabled(false) + : m_CachedNetworkFilePath(""), m_IsFastMathEnabled(false), m_SaveCachedNetwork(false) { if (!modelOptions.empty()) { @@ -33,13 +42,31 @@ ClBackendModelContext::ClBackendModelContext(const ModelOptions& modelOptions) { m_IsFastMathEnabled |= ParseBool(value, false); } + if (name == "SaveCachedNetwork") + { + m_SaveCachedNetwork |= ParseBool(value, false); + } + if (name == "CachedNetworkFilePath") + { + m_CachedNetworkFilePath = ParseFile(value, ""); + } }); } } +std::string ClBackendModelContext::GetCachedNetworkFilePath() const +{ + return m_CachedNetworkFilePath; +} + bool ClBackendModelContext::IsFastMathEnabled() const { return m_IsFastMathEnabled; } +bool ClBackendModelContext::SaveCachedNetwork() const +{ + return m_SaveCachedNetwork; +} + } // namespace armnn
\ No newline at end of file diff --git a/src/backends/cl/ClBackendModelContext.hpp b/src/backends/cl/ClBackendModelContext.hpp index 577649aafb..c84cdbbfcf 100644 --- a/src/backends/cl/ClBackendModelContext.hpp +++ b/src/backends/cl/ClBackendModelContext.hpp @@ -6,6 +6,8 @@ #include <armnn/backends/IBackendContext.hpp> +#include<string> + namespace armnn { @@ -19,10 +21,17 @@ class ClBackendModelContext : public IBackendModelContext public: ClBackendModelContext(const ModelOptions& modelOptions); + std::string GetCachedNetworkFilePath() const; + bool IsFastMathEnabled() const; + bool SaveCachedNetwork() const; + private: + std::string m_CachedNetworkFilePath; bool m_IsFastMathEnabled; + bool m_SaveCachedNetwork; + }; } // namespace armnn
\ No newline at end of file diff --git a/src/backends/cl/ClWorkloadFactory.cpp b/src/backends/cl/ClWorkloadFactory.cpp index c8c1fb71ec..41b779f64a 100644 --- a/src/backends/cl/ClWorkloadFactory.cpp +++ b/src/backends/cl/ClWorkloadFactory.cpp @@ -27,6 +27,8 @@ #include <arm_compute/runtime/CL/CLBufferAllocator.h> #include <arm_compute/runtime/CL/CLScheduler.h> +#include <Filesystem.hpp> + namespace armnn { @@ -55,6 +57,23 @@ const BackendId& ClWorkloadFactory::GetBackendId() const return s_Id; } +void ClWorkloadFactory::AfterWorkloadsCreated() +{ + if(m_ModelContextPtr) + { + auto modelOptions = dynamic_cast<ClBackendModelContext*>(m_ModelContextPtr.get()); + if (modelOptions->SaveCachedNetwork()) + { + // Save map to a filepath provided in ModelOptions + auto filePath = modelOptions->GetCachedNetworkFilePath(); + if (filePath != "" && fs::exists(filePath) && fs::is_regular_file(filePath)) + { + /// Saving will be implemented within IVGCVSW-5483 story. + } + } + } +} + template <typename FloatWorkload, typename Uint8Workload, typename QueueDescriptorType, typename... 
Args> std::unique_ptr<IWorkload> ClWorkloadFactory::MakeWorkload(const QueueDescriptorType& descriptor, const WorkloadInfo& info, @@ -85,15 +104,40 @@ std::unique_ptr<IWorkload> ClWorkloadFactory::MakeWorkload(const QueueDescriptor } } +void ClWorkloadFactory::InitializeCLCompileContext() +{ + // Initialize our m_CLCompileContext using default device and context + cl::Device device = cl::Device::getDefault(); + cl::Context context = cl::Context(device); + + m_CLCompileContext = arm_compute::CLCompileContext(context, device); + + if (m_ModelContextPtr) + { + // Load saved programs if the user has set a filepath + auto modelOptions = dynamic_cast<ClBackendModelContext*>(m_ModelContextPtr.get()); + auto filePath = modelOptions->GetCachedNetworkFilePath(); + if (filePath != "" + && fs::exists(filePath) + && fs::is_regular_file(filePath) + && !(modelOptions->SaveCachedNetwork())) + { + /// Loading will be implemented within IVGCVSW-5483 story. + } + } +} + ClWorkloadFactory::ClWorkloadFactory(const std::shared_ptr<ClMemoryManager>& memoryManager) : m_MemoryManager(memoryManager), m_ModelContextPtr(IBackendInternal::IBackendSpecificModelContextPtr{}) { + InitializeCLCompileContext(); } ClWorkloadFactory::ClWorkloadFactory(const std::shared_ptr<ClMemoryManager>& memoryManager, const IBackendInternal::IBackendSpecificModelContextPtr& modelContextPtr) : m_MemoryManager(memoryManager), m_ModelContextPtr(modelContextPtr) { + InitializeCLCompileContext(); } std::unique_ptr<ITensorHandle> ClWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo, diff --git a/src/backends/cl/ClWorkloadFactory.hpp b/src/backends/cl/ClWorkloadFactory.hpp index 84eae5076a..c8812cfe1b 100644 --- a/src/backends/cl/ClWorkloadFactory.hpp +++ b/src/backends/cl/ClWorkloadFactory.hpp @@ -12,6 +12,8 @@ #include <backendsCommon/WorkloadFactoryBase.hpp> #include <aclCommon/BaseMemoryManager.hpp> +#include <arm_compute/core/CL/CLCompileContext.h> + namespace armnn { @@ -24,6 +26,8 @@ public: 
ClWorkloadFactory(const std::shared_ptr<ClMemoryManager>& memoryManager, const IBackendInternal::IBackendSpecificModelContextPtr& modelContextPtr); + void AfterWorkloadsCreated() override; + const BackendId& GetBackendId() const override; static bool IsLayerSupported(const Layer& layer, @@ -254,8 +258,11 @@ private: const WorkloadInfo& info, Args&&... args); + void InitializeCLCompileContext(); + mutable std::shared_ptr<ClMemoryManager> m_MemoryManager; const IBackendInternal::IBackendSpecificModelContextPtr m_ModelContextPtr; + arm_compute::CLCompileContext m_CLCompileContext; }; } // namespace armnn |