aboutsummaryrefslogtreecommitdiff
path: root/src/backends/cl
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/cl')
-rw-r--r--src/backends/cl/ClTensorHandleFactory.cpp18
-rw-r--r--src/backends/cl/ClTensorHandleFactory.hpp6
-rw-r--r--src/backends/cl/ClWorkloadFactory.cpp6
-rw-r--r--src/backends/cl/ClWorkloadFactory.hpp6
4 files changed, 24 insertions, 12 deletions
diff --git a/src/backends/cl/ClTensorHandleFactory.cpp b/src/backends/cl/ClTensorHandleFactory.cpp
index 87ecdfe1ba..3d9908a1ac 100644
--- a/src/backends/cl/ClTensorHandleFactory.cpp
+++ b/src/backends/cl/ClTensorHandleFactory.cpp
@@ -45,20 +45,26 @@ std::unique_ptr<ITensorHandle> ClTensorHandleFactory::CreateSubTensorHandle(ITen
boost::polymorphic_downcast<IClTensorHandle *>(&parent), shape, coords);
}
-std::unique_ptr<ITensorHandle> ClTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo) const
+std::unique_ptr<ITensorHandle> ClTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
+ const bool IsMemoryManaged) const
{
std::unique_ptr<ClTensorHandle> tensorHandle = std::make_unique<ClTensorHandle>(tensorInfo);
- tensorHandle->SetMemoryGroup(m_MemoryManager->GetInterLayerMemoryGroup());
-
+ if (IsMemoryManaged)
+ {
+ tensorHandle->SetMemoryGroup(m_MemoryManager->GetInterLayerMemoryGroup());
+ }
return tensorHandle;
}
std::unique_ptr<ITensorHandle> ClTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
- DataLayout dataLayout) const
+ DataLayout dataLayout,
+ const bool IsMemoryManaged) const
{
std::unique_ptr<ClTensorHandle> tensorHandle = std::make_unique<ClTensorHandle>(tensorInfo, dataLayout);
- tensorHandle->SetMemoryGroup(m_MemoryManager->GetInterLayerMemoryGroup());
-
+ if (IsMemoryManaged)
+ {
+ tensorHandle->SetMemoryGroup(m_MemoryManager->GetInterLayerMemoryGroup());
+ }
return tensorHandle;
}
diff --git a/src/backends/cl/ClTensorHandleFactory.hpp b/src/backends/cl/ClTensorHandleFactory.hpp
index 7c3b49bee5..ea3728f7f7 100644
--- a/src/backends/cl/ClTensorHandleFactory.hpp
+++ b/src/backends/cl/ClTensorHandleFactory.hpp
@@ -28,10 +28,12 @@ public:
const TensorShape& subTensorShape,
const unsigned int* subTensorOrigin) const override;
- std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const override;
+ std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
+ const bool IsMemoryManaged = true) const override;
std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
- DataLayout dataLayout) const override;
+ DataLayout dataLayout,
+ const bool IsMemoryManaged = true) const override;
static const FactoryId& GetIdStatic();
diff --git a/src/backends/cl/ClWorkloadFactory.cpp b/src/backends/cl/ClWorkloadFactory.cpp
index 8210be251c..536d4dd058 100644
--- a/src/backends/cl/ClWorkloadFactory.cpp
+++ b/src/backends/cl/ClWorkloadFactory.cpp
@@ -82,7 +82,8 @@ ClWorkloadFactory::ClWorkloadFactory(const std::shared_ptr<ClMemoryManager>& mem
{
}
-std::unique_ptr<ITensorHandle> ClWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo) const
+std::unique_ptr<ITensorHandle> ClWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
+ const bool IsMemoryManaged) const
{
std::unique_ptr<ClTensorHandle> tensorHandle = std::make_unique<ClTensorHandle>(tensorInfo);
tensorHandle->SetMemoryGroup(m_MemoryManager->GetInterLayerMemoryGroup());
@@ -91,7 +92,8 @@ std::unique_ptr<ITensorHandle> ClWorkloadFactory::CreateTensorHandle(const Tenso
}
std::unique_ptr<ITensorHandle> ClWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
- DataLayout dataLayout) const
+ DataLayout dataLayout,
+ const bool IsMemoryManaged) const
{
std::unique_ptr<ClTensorHandle> tensorHandle = std::make_unique<ClTensorHandle>(tensorInfo, dataLayout);
tensorHandle->SetMemoryGroup(m_MemoryManager->GetInterLayerMemoryGroup());
diff --git a/src/backends/cl/ClWorkloadFactory.hpp b/src/backends/cl/ClWorkloadFactory.hpp
index 493f659eab..c8d58dbb5c 100644
--- a/src/backends/cl/ClWorkloadFactory.hpp
+++ b/src/backends/cl/ClWorkloadFactory.hpp
@@ -31,10 +31,12 @@ public:
TensorShape const& subTensorShape,
unsigned int const* subTensorOrigin) const override;
- std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const override;
+ std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
+ const bool IsMemoryManaged = true) const override;
std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
- DataLayout dataLayout) const override;
+ DataLayout dataLayout,
+ const bool IsMemoryManaged = true) const override;
std::unique_ptr<IWorkload> CreateInput(const InputQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;