aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Monahan <david.monahan@arm.com>2021-03-11 11:34:54 +0000
committerDavid Monahan <david.monahan@arm.com>2021-03-11 11:37:47 +0000
commit66dbf5b9a3f1bf9d96c0dcaff824b159d1b072a9 (patch)
treef87cc72c2254b0f6d48c27f2a89522c179ff9af4
parentcb89717c4858cc97fb62123d95df7f6af4118244 (diff)
downloadarmnn-66dbf5b9a3f1bf9d96c0dcaff824b159d1b072a9.tar.gz
IVGCVSW-5726 Implement Memory Import Functions in CltensorHandle
* Contains a dummy import function as that will be implemented separately Signed-off-by: David Monahan <david.monahan@arm.com> Change-Id: If551b69e832c045c76775a7e5fa25647c2313908
-rw-r--r--src/backends/cl/ClTensorHandle.hpp57
1 files changed, 53 insertions, 4 deletions
diff --git a/src/backends/cl/ClTensorHandle.hpp b/src/backends/cl/ClTensorHandle.hpp
index 0302ef5790..2798de6d0c 100644
--- a/src/backends/cl/ClTensorHandle.hpp
+++ b/src/backends/cl/ClTensorHandle.hpp
@@ -35,23 +35,43 @@ class ClTensorHandle : public IClTensorHandle
{
public:
ClTensorHandle(const TensorInfo& tensorInfo)
+ : m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Undefined)),
+ m_Imported(false),
+ m_IsImportEnabled(false)
{
armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
}
- ClTensorHandle(const TensorInfo& tensorInfo, DataLayout dataLayout)
+ ClTensorHandle(const TensorInfo& tensorInfo,
+ DataLayout dataLayout,
+ MemorySourceFlags importFlags = static_cast<MemorySourceFlags>(MemorySource::Undefined))
+ : m_ImportFlags(importFlags),
+ m_Imported(false),
+ m_IsImportEnabled(false)
{
armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
}
arm_compute::CLTensor& GetTensor() override { return m_Tensor; }
arm_compute::CLTensor const& GetTensor() const override { return m_Tensor; }
- virtual void Allocate() override {armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);}
+ virtual void Allocate() override
+ {
+ // If we have enabled Importing, don't allocate the tensor
+ if (!m_IsImportEnabled)
+ {
+ armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
+ }
+
+ }
virtual void Manage() override
{
- assert(m_MemoryGroup != nullptr);
- m_MemoryGroup->manage(&m_Tensor);
+ // If we have enabled Importing, don't manage the tensor
+ if (!m_IsImportEnabled)
+ {
+ assert(m_MemoryGroup != nullptr);
+ m_MemoryGroup->manage(&m_Tensor);
+ }
}
virtual const void* Map(bool blocking = true) const override
@@ -84,6 +104,32 @@ public:
return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
}
+ void SetImportFlags(MemorySourceFlags importFlags)
+ {
+ m_ImportFlags = importFlags;
+ }
+
+ MemorySourceFlags GetImportFlags() const override
+ {
+ return m_ImportFlags;
+ }
+
+ void SetImportEnabledFlag(bool importEnabledFlag)
+ {
+ m_IsImportEnabled = importEnabledFlag;
+ }
+
+ virtual bool Import(void* memory, MemorySource source) override
+ {
+ armnn::IgnoreUnused(memory);
+ if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
+ {
+ throw MemoryImportException("ClTensorHandle::Incorrect import flag");
+ }
+ m_Imported = false;
+ return false;
+ }
+
private:
// Only used for testing
void CopyOutTo(void* memory) const override
@@ -169,6 +215,9 @@ private:
arm_compute::CLTensor m_Tensor;
std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
+ MemorySourceFlags m_ImportFlags;
+ bool m_Imported;
+ bool m_IsImportEnabled;
};
class ClSubTensorHandle : public IClTensorHandle