From 66dbf5b9a3f1bf9d96c0dcaff824b159d1b072a9 Mon Sep 17 00:00:00 2001 From: David Monahan Date: Thu, 11 Mar 2021 11:34:54 +0000 Subject: IVGCVSW-5726 Implement Memory Import Functions in CltensorHandle * Contains a dummy import function as that will be implemented separately Signed-off-by: David Monahan Change-Id: If551b69e832c045c76775a7e5fa25647c2313908 --- src/backends/cl/ClTensorHandle.hpp | 57 +++++++++++++++++++++++++++++++++++--- 1 file changed, 53 insertions(+), 4 deletions(-) (limited to 'src/backends/cl/ClTensorHandle.hpp') diff --git a/src/backends/cl/ClTensorHandle.hpp b/src/backends/cl/ClTensorHandle.hpp index 0302ef5790..2798de6d0c 100644 --- a/src/backends/cl/ClTensorHandle.hpp +++ b/src/backends/cl/ClTensorHandle.hpp @@ -35,23 +35,43 @@ class ClTensorHandle : public IClTensorHandle { public: ClTensorHandle(const TensorInfo& tensorInfo) + : m_ImportFlags(static_cast(MemorySource::Undefined)), + m_Imported(false), + m_IsImportEnabled(false) { armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo); } - ClTensorHandle(const TensorInfo& tensorInfo, DataLayout dataLayout) + ClTensorHandle(const TensorInfo& tensorInfo, + DataLayout dataLayout, + MemorySourceFlags importFlags = static_cast(MemorySource::Undefined)) + : m_ImportFlags(importFlags), + m_Imported(false), + m_IsImportEnabled(false) { armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout); } arm_compute::CLTensor& GetTensor() override { return m_Tensor; } arm_compute::CLTensor const& GetTensor() const override { return m_Tensor; } - virtual void Allocate() override {armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);} + virtual void Allocate() override + { + // If we have enabled Importing, don't allocate the tensor + if (!m_IsImportEnabled) + { + armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor); + } + + } virtual void Manage() override { - assert(m_MemoryGroup != nullptr); - m_MemoryGroup->manage(&m_Tensor); + // If we have enabled Importing, don't manage the tensor + if (!m_IsImportEnabled) + { + assert(m_MemoryGroup != nullptr); + m_MemoryGroup->manage(&m_Tensor); + } } virtual const void* Map(bool blocking = true) const override @@ -84,6 +104,32 @@ public: return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape()); } + void SetImportFlags(MemorySourceFlags importFlags) + { + m_ImportFlags = importFlags; + } + + MemorySourceFlags GetImportFlags() const override + { + return m_ImportFlags; + } + + void SetImportEnabledFlag(bool importEnabledFlag) + { + m_IsImportEnabled = importEnabledFlag; + } + + virtual bool Import(void* memory, MemorySource source) override + { + armnn::IgnoreUnused(memory); + if (m_ImportFlags & static_cast(source)) + { + throw MemoryImportException("ClTensorHandle::Incorrect import flag"); + } + m_Imported = false; + return false; + } + private: // Only used for testing void CopyOutTo(void* memory) const override @@ -169,6 +215,9 @@ private: arm_compute::CLTensor m_Tensor; std::shared_ptr m_MemoryGroup; + MemorySourceFlags m_ImportFlags; + bool m_Imported; + bool m_IsImportEnabled; }; class ClSubTensorHandle : public IClTensorHandle -- cgit v1.2.1