diff options
Diffstat (limited to 'src/backends/neon/NeonTensorHandle.hpp')
-rw-r--r-- | src/backends/neon/NeonTensorHandle.hpp | 97 |
1 files changed, 92 insertions, 5 deletions
diff --git a/src/backends/neon/NeonTensorHandle.hpp b/src/backends/neon/NeonTensorHandle.hpp index 9077f34888..c3662c1211 100644 --- a/src/backends/neon/NeonTensorHandle.hpp +++ b/src/backends/neon/NeonTensorHandle.hpp @@ -24,11 +24,20 @@ class NeonTensorHandle : public IAclTensorHandle { public: NeonTensorHandle(const TensorInfo& tensorInfo) + : m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc)), + m_Imported(false), + m_IsImportEnabled(false) { armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo); } - NeonTensorHandle(const TensorInfo& tensorInfo, DataLayout dataLayout) + NeonTensorHandle(const TensorInfo& tensorInfo, + DataLayout dataLayout, + MemorySourceFlags importFlags = static_cast<MemorySourceFlags>(MemorySource::Malloc)) + : m_ImportFlags(importFlags), + m_Imported(false), + m_IsImportEnabled(false) + { armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout); } @@ -38,13 +47,21 @@ public: virtual void Allocate() override { - armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor); + // If we have enabled Importing, don't Allocate the tensor + if (!m_IsImportEnabled) + { + armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor); + } }; virtual void Manage() override { - BOOST_ASSERT(m_MemoryGroup != nullptr); - m_MemoryGroup->manage(&m_Tensor); + // If we have enabled Importing, don't manage the tensor + if (!m_IsImportEnabled) + { + BOOST_ASSERT(m_MemoryGroup != nullptr); + m_MemoryGroup->manage(&m_Tensor); + } } virtual ITensorHandle* GetParent() const override { return nullptr; } @@ -63,8 +80,8 @@ public: { return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes()); } - virtual void Unmap() const override {} + virtual void Unmap() const override {} TensorShape GetStrides() const override { @@ -76,6 +93,73 @@ public: return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape()); } + void 
SetImportFlags(MemorySourceFlags importFlags) + { + m_ImportFlags = importFlags; + } + + MemorySourceFlags GetImportFlags() const override + { + return m_ImportFlags; + } + + void SetImportEnabledFlag(bool importEnabledFlag) + { + m_IsImportEnabled = importEnabledFlag; + } + + virtual bool Import(void* memory, MemorySource source) override + { + if (m_ImportFlags & static_cast<MemorySourceFlags>(source)) + { + if (source == MemorySource::Malloc && m_IsImportEnabled) + { + // Checks that the pointer is aligned to sizeof(size_t) bytes + constexpr uintptr_t alignment = sizeof(size_t); + if (reinterpret_cast<uintptr_t>(memory) % alignment) + { + throw MemoryImportException("NeonTensorHandle::Import Attempting to import unaligned memory"); + } + + // m_Tensor not yet Allocated + if (!m_Imported && !m_Tensor.buffer()) + { + arm_compute::Status status = m_Tensor.allocator()->import_memory(memory); + // Use the overloaded bool operator of Status to check if it worked, if not throw an exception + // with the Status error message + m_Imported = bool(status); + if (!m_Imported) + { + throw MemoryImportException(status.error_description()); + } + return m_Imported; + } + + // m_Tensor.buffer() initially allocated with Allocate(). + if (!m_Imported && m_Tensor.buffer()) + { + throw MemoryImportException( + "NeonTensorHandle::Import Attempting to import on an already allocated tensor"); + } + + // m_Tensor.buffer() previously imported. 
+ if (m_Imported) + { + arm_compute::Status status = m_Tensor.allocator()->import_memory(memory); + // Use the overloaded bool operator of Status to check if it worked, if not throw an exception + // with the Status error message + m_Imported = bool(status); + if (!m_Imported) + { + throw MemoryImportException(status.error_description()); + } + return m_Imported; + } + } + } + return false; + } + private: // Only used for testing void CopyOutTo(void* memory) const override @@ -131,6 +215,9 @@ private: arm_compute::Tensor m_Tensor; std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup; + MemorySourceFlags m_ImportFlags; + bool m_Imported; + bool m_IsImportEnabled; }; class NeonSubTensorHandle : public IAclTensorHandle |