diff options
Diffstat (limited to 'src/backends/cl/ClTensorHandle.hpp')
-rw-r--r-- | src/backends/cl/ClTensorHandle.hpp | 57 |
1 files changed, 53 insertions, 4 deletions
diff --git a/src/backends/cl/ClTensorHandle.hpp b/src/backends/cl/ClTensorHandle.hpp index 0302ef5790..2798de6d0c 100644 --- a/src/backends/cl/ClTensorHandle.hpp +++ b/src/backends/cl/ClTensorHandle.hpp @@ -35,23 +35,43 @@ class ClTensorHandle : public IClTensorHandle { public: ClTensorHandle(const TensorInfo& tensorInfo) + : m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Undefined)), + m_Imported(false), + m_IsImportEnabled(false) { armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo); } - ClTensorHandle(const TensorInfo& tensorInfo, DataLayout dataLayout) + ClTensorHandle(const TensorInfo& tensorInfo, + DataLayout dataLayout, + MemorySourceFlags importFlags = static_cast<MemorySourceFlags>(MemorySource::Undefined)) + : m_ImportFlags(importFlags), + m_Imported(false), + m_IsImportEnabled(false) { armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout); } arm_compute::CLTensor& GetTensor() override { return m_Tensor; } arm_compute::CLTensor const& GetTensor() const override { return m_Tensor; } - virtual void Allocate() override {armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);} + virtual void Allocate() override + { + // If we have enabled Importing, don't allocate the tensor + if (!m_IsImportEnabled) + { + armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor); + } + + } virtual void Manage() override { - assert(m_MemoryGroup != nullptr); - m_MemoryGroup->manage(&m_Tensor); + // If we have enabled Importing, don't manage the tensor + if (!m_IsImportEnabled) + { + assert(m_MemoryGroup != nullptr); + m_MemoryGroup->manage(&m_Tensor); + } } virtual const void* Map(bool blocking = true) const override @@ -84,6 +104,32 @@ public: return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape()); } + void SetImportFlags(MemorySourceFlags importFlags) + { + m_ImportFlags = importFlags; + } + + MemorySourceFlags GetImportFlags() const override + { + return m_ImportFlags; + } + + void SetImportEnabledFlag(bool importEnabledFlag) + { + m_IsImportEnabled = importEnabledFlag; + } + + virtual bool Import(void* memory, MemorySource source) override + { + armnn::IgnoreUnused(memory); + if (m_ImportFlags & static_cast<MemorySourceFlags>(source)) + { + throw MemoryImportException("ClTensorHandle::Incorrect import flag"); + } + m_Imported = false; + return false; + } + private: // Only used for testing void CopyOutTo(void* memory) const override @@ -169,6 +215,9 @@ private: arm_compute::CLTensor m_Tensor; std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup; + MemorySourceFlags m_ImportFlags; + bool m_Imported; + bool m_IsImportEnabled; }; class ClSubTensorHandle : public IClTensorHandle |