diff options
Diffstat (limited to 'src/backends/cl/ClWorkloadFactory.cpp')
-rw-r--r-- | src/backends/cl/ClWorkloadFactory.cpp | 44 |
1 files changed, 44 insertions, 0 deletions
diff --git a/src/backends/cl/ClWorkloadFactory.cpp b/src/backends/cl/ClWorkloadFactory.cpp index c8c1fb71ec..41b779f64a 100644 --- a/src/backends/cl/ClWorkloadFactory.cpp +++ b/src/backends/cl/ClWorkloadFactory.cpp @@ -27,6 +27,8 @@ #include <arm_compute/runtime/CL/CLBufferAllocator.h> #include <arm_compute/runtime/CL/CLScheduler.h> +#include <Filesystem.hpp> + namespace armnn { @@ -55,6 +57,23 @@ const BackendId& ClWorkloadFactory::GetBackendId() const return s_Id; } +void ClWorkloadFactory::AfterWorkloadsCreated() +{ + if(m_ModelContextPtr) + { + auto modelOptions = dynamic_cast<ClBackendModelContext*>(m_ModelContextPtr.get()); + if (modelOptions->SaveCachedNetwork()) + { + // Save map to a filepath provided in ModelOptions + auto filePath = modelOptions->GetCachedNetworkFilePath(); + if (filePath != "" && fs::exists(filePath) && fs::is_regular_file(filePath)) + { + /// Saving will be implemented within IVGCVSW-5483 story. + } + } + } +} + template <typename FloatWorkload, typename Uint8Workload, typename QueueDescriptorType, typename... Args> std::unique_ptr<IWorkload> ClWorkloadFactory::MakeWorkload(const QueueDescriptorType& descriptor, const WorkloadInfo& info, @@ -85,15 +104,40 @@ std::unique_ptr<IWorkload> ClWorkloadFactory::MakeWorkload(const QueueDescriptor } } +void ClWorkloadFactory::InitializeCLCompileContext() +{ + // Initialize our m_CLCompileContext using default device and context + cl::Device device = cl::Device::getDefault(); + cl::Context context = cl::Context(device); + + m_CLCompileContext = arm_compute::CLCompileContext(context, device); + + if (m_ModelContextPtr) + { + // Load saved programs if the user has set a filepath + auto modelOptions = dynamic_cast<ClBackendModelContext*>(m_ModelContextPtr.get()); + auto filePath = modelOptions->GetCachedNetworkFilePath(); + if (filePath != "" + && fs::exists(filePath) + && fs::is_regular_file(filePath) + && !(modelOptions->SaveCachedNetwork())) + { + /// Loading will be implemented within IVGCVSW-5483 story. + } + } +} + ClWorkloadFactory::ClWorkloadFactory(const std::shared_ptr<ClMemoryManager>& memoryManager) : m_MemoryManager(memoryManager), m_ModelContextPtr(IBackendInternal::IBackendSpecificModelContextPtr{}) { + InitializeCLCompileContext(); } ClWorkloadFactory::ClWorkloadFactory(const std::shared_ptr<ClMemoryManager>& memoryManager, const IBackendInternal::IBackendSpecificModelContextPtr& modelContextPtr) : m_MemoryManager(memoryManager), m_ModelContextPtr(modelContextPtr) { + InitializeCLCompileContext(); } std::unique_ptr<ITensorHandle> ClWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo, |