aboutsummaryrefslogtreecommitdiff
path: root/src/backends
diff options
context:
space:
mode:
authorSadik Armagan <sadik.armagan@arm.com>2020-11-26 10:38:11 +0000
committerSadik Armagan <sadik.armagan@arm.com>2020-11-26 10:38:11 +0000
commitdea8fb6b96663de5d3df2f9fceb9bd09432fd7aa (patch)
treed041bf8e9d406c80ce089fc5b8d84b44381332e1 /src/backends
parentf4f150c30d3c34e9f26757ca43e4a2694b882bce (diff)
downloadarmnn-dea8fb6b96663de5d3df2f9fceb9bd09432fd7aa.tar.gz
IVGCVSW-5481 'Add ClCompileContext to ClWorkloadFactory'
* Introduced CLCompileContext to ClWorkloadFactory Signed-off-by: Sadik Armagan <sadik.armagan@arm.com> Change-Id: Ied38f4336210502e5f518b9955ae6a5ba3d242b3
Diffstat (limited to 'src/backends')
-rw-r--r--src/backends/backendsCommon/WorkloadFactory.hpp2
-rw-r--r--src/backends/cl/ClBackendModelContext.cpp29
-rw-r--r--src/backends/cl/ClBackendModelContext.hpp9
-rw-r--r--src/backends/cl/ClWorkloadFactory.cpp44
-rw-r--r--src/backends/cl/ClWorkloadFactory.hpp7
5 files changed, 90 insertions, 1 deletions
diff --git a/src/backends/backendsCommon/WorkloadFactory.hpp b/src/backends/backendsCommon/WorkloadFactory.hpp
index df08b9a81d..2e813e9945 100644
--- a/src/backends/backendsCommon/WorkloadFactory.hpp
+++ b/src/backends/backendsCommon/WorkloadFactory.hpp
@@ -23,6 +23,8 @@ class IWorkloadFactory
public:
virtual ~IWorkloadFactory() { }
+ virtual void AfterWorkloadsCreated() {};
+
virtual const BackendId& GetBackendId() const = 0;
static bool IsLayerSupported(const BackendId& backendId,
diff --git a/src/backends/cl/ClBackendModelContext.cpp b/src/backends/cl/ClBackendModelContext.cpp
index 0ef26b64d2..b685bc296c 100644
--- a/src/backends/cl/ClBackendModelContext.cpp
+++ b/src/backends/cl/ClBackendModelContext.cpp
@@ -17,13 +17,22 @@ bool ParseBool(const armnn::BackendOptions::Var& value, bool defaultValue)
return defaultValue;
}
+std::string ParseFile(const armnn::BackendOptions::Var& value, std::string defaultValue)
+{
+ if (value.IsString())
+ {
+ return value.AsString();
+ }
+ return defaultValue;
+}
+
} // namespace anonymous
namespace armnn
{
ClBackendModelContext::ClBackendModelContext(const ModelOptions& modelOptions)
- : m_IsFastMathEnabled(false)
+ : m_CachedNetworkFilePath(""), m_IsFastMathEnabled(false), m_SaveCachedNetwork(false)
{
if (!modelOptions.empty())
{
@@ -33,13 +42,31 @@ ClBackendModelContext::ClBackendModelContext(const ModelOptions& modelOptions)
{
m_IsFastMathEnabled |= ParseBool(value, false);
}
+ if (name == "SaveCachedNetwork")
+ {
+ m_SaveCachedNetwork |= ParseBool(value, false);
+ }
+ if (name == "CachedNetworkFilePath")
+ {
+ m_CachedNetworkFilePath = ParseFile(value, "");
+ }
});
}
}
+std::string ClBackendModelContext::GetCachedNetworkFilePath() const
+{
+ return m_CachedNetworkFilePath;
+}
+
bool ClBackendModelContext::IsFastMathEnabled() const
{
return m_IsFastMathEnabled;
}
+bool ClBackendModelContext::SaveCachedNetwork() const
+{
+ return m_SaveCachedNetwork;
+}
+
} // namespace armnn \ No newline at end of file
diff --git a/src/backends/cl/ClBackendModelContext.hpp b/src/backends/cl/ClBackendModelContext.hpp
index 577649aafb..c84cdbbfcf 100644
--- a/src/backends/cl/ClBackendModelContext.hpp
+++ b/src/backends/cl/ClBackendModelContext.hpp
@@ -6,6 +6,8 @@
#include <armnn/backends/IBackendContext.hpp>
+#include<string>
+
namespace armnn
{
@@ -19,10 +21,17 @@ class ClBackendModelContext : public IBackendModelContext
public:
ClBackendModelContext(const ModelOptions& modelOptions);
+ std::string GetCachedNetworkFilePath() const;
+
bool IsFastMathEnabled() const;
+ bool SaveCachedNetwork() const;
+
private:
+ std::string m_CachedNetworkFilePath;
bool m_IsFastMathEnabled;
+ bool m_SaveCachedNetwork;
+
};
} // namespace armnn \ No newline at end of file
diff --git a/src/backends/cl/ClWorkloadFactory.cpp b/src/backends/cl/ClWorkloadFactory.cpp
index c8c1fb71ec..41b779f64a 100644
--- a/src/backends/cl/ClWorkloadFactory.cpp
+++ b/src/backends/cl/ClWorkloadFactory.cpp
@@ -27,6 +27,8 @@
#include <arm_compute/runtime/CL/CLBufferAllocator.h>
#include <arm_compute/runtime/CL/CLScheduler.h>
+#include <Filesystem.hpp>
+
namespace armnn
{
@@ -55,6 +57,23 @@ const BackendId& ClWorkloadFactory::GetBackendId() const
return s_Id;
}
+void ClWorkloadFactory::AfterWorkloadsCreated()
+{
+ if(m_ModelContextPtr)
+ {
+ auto modelOptions = dynamic_cast<ClBackendModelContext*>(m_ModelContextPtr.get());
+ if (modelOptions->SaveCachedNetwork())
+ {
+ // Save map to a filepath provided in ModelOptions
+ auto filePath = modelOptions->GetCachedNetworkFilePath();
+ if (filePath != "" && fs::exists(filePath) && fs::is_regular_file(filePath))
+ {
+ /// Saving will be implemented within IVGCVSW-5483 story.
+ }
+ }
+ }
+}
+
template <typename FloatWorkload, typename Uint8Workload, typename QueueDescriptorType, typename... Args>
std::unique_ptr<IWorkload> ClWorkloadFactory::MakeWorkload(const QueueDescriptorType& descriptor,
const WorkloadInfo& info,
@@ -85,15 +104,40 @@ std::unique_ptr<IWorkload> ClWorkloadFactory::MakeWorkload(const QueueDescriptor
}
}
+void ClWorkloadFactory::InitializeCLCompileContext()
+{
+ // Initialize our m_CLCompileContext using default device and context
+ cl::Device device = cl::Device::getDefault();
+ cl::Context context = cl::Context(device);
+
+ m_CLCompileContext = arm_compute::CLCompileContext(context, device);
+
+ if (m_ModelContextPtr)
+ {
+ // Load saved programs if the user has set a filepath
+ auto modelOptions = dynamic_cast<ClBackendModelContext*>(m_ModelContextPtr.get());
+ auto filePath = modelOptions->GetCachedNetworkFilePath();
+ if (filePath != ""
+ && fs::exists(filePath)
+ && fs::is_regular_file(filePath)
+ && !(modelOptions->SaveCachedNetwork()))
+ {
+ /// Loading will be implemented within IVGCVSW-5483 story.
+ }
+ }
+}
+
ClWorkloadFactory::ClWorkloadFactory(const std::shared_ptr<ClMemoryManager>& memoryManager)
: m_MemoryManager(memoryManager), m_ModelContextPtr(IBackendInternal::IBackendSpecificModelContextPtr{})
{
+ InitializeCLCompileContext();
}
ClWorkloadFactory::ClWorkloadFactory(const std::shared_ptr<ClMemoryManager>& memoryManager,
const IBackendInternal::IBackendSpecificModelContextPtr& modelContextPtr)
: m_MemoryManager(memoryManager), m_ModelContextPtr(modelContextPtr)
{
+ InitializeCLCompileContext();
}
std::unique_ptr<ITensorHandle> ClWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
diff --git a/src/backends/cl/ClWorkloadFactory.hpp b/src/backends/cl/ClWorkloadFactory.hpp
index 84eae5076a..c8812cfe1b 100644
--- a/src/backends/cl/ClWorkloadFactory.hpp
+++ b/src/backends/cl/ClWorkloadFactory.hpp
@@ -12,6 +12,8 @@
#include <backendsCommon/WorkloadFactoryBase.hpp>
#include <aclCommon/BaseMemoryManager.hpp>
+#include <arm_compute/core/CL/CLCompileContext.h>
+
namespace armnn
{
@@ -24,6 +26,8 @@ public:
ClWorkloadFactory(const std::shared_ptr<ClMemoryManager>& memoryManager,
const IBackendInternal::IBackendSpecificModelContextPtr& modelContextPtr);
+ void AfterWorkloadsCreated() override;
+
const BackendId& GetBackendId() const override;
static bool IsLayerSupported(const Layer& layer,
@@ -254,8 +258,11 @@ private:
const WorkloadInfo& info,
Args&&... args);
+ void InitializeCLCompileContext();
+
mutable std::shared_ptr<ClMemoryManager> m_MemoryManager;
const IBackendInternal::IBackendSpecificModelContextPtr m_ModelContextPtr;
+ arm_compute::CLCompileContext m_CLCompileContext;
};
} // namespace armnn