diff options
Diffstat (limited to 'src/backends/cl/ClImportTensorHandle.hpp')
-rw-r--r-- | src/backends/cl/ClImportTensorHandle.hpp | 45 |
1 files changed, 44 insertions, 1 deletions
diff --git a/src/backends/cl/ClImportTensorHandle.hpp b/src/backends/cl/ClImportTensorHandle.hpp index 48fb2f7d30..a24ab5656e 100644 --- a/src/backends/cl/ClImportTensorHandle.hpp +++ b/src/backends/cl/ClImportTensorHandle.hpp @@ -46,7 +46,7 @@ public: ClImportTensorHandle(const TensorInfo& tensorInfo, DataLayout dataLayout, MemorySourceFlags importFlags) - : m_ImportFlags(importFlags) + : m_ImportFlags(importFlags), m_Imported(false) { armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout); } @@ -139,6 +139,48 @@ public: return ClImport(importProperties, memory, true); } + // Case for importing memory allocated by OpenCl externally directly into the tensor + else if (source == MemorySource::Gralloc) + { + // m_Tensor not yet Allocated + if (!m_Imported && !m_Tensor.buffer()) + { + // Importing memory allocated by OpenCl into the tensor directly. + arm_compute::Status status = + m_Tensor.allocator()->import_memory(cl::Buffer(static_cast<cl_mem>(memory))); + m_Imported = bool(status); + if (!m_Imported) + { + throw MemoryImportException(status.error_description()); + } + return m_Imported; + } + + // m_Tensor.buffer() initially allocated with Allocate(). + else if (!m_Imported && m_Tensor.buffer()) + { + throw MemoryImportException( + "ClImportTensorHandle::Import Attempting to import on an already allocated tensor"); + } + + // m_Tensor.buffer() previously imported. + else if (m_Imported) + { + // Importing memory allocated by OpenCl into the tensor directly. + arm_compute::Status status = + m_Tensor.allocator()->import_memory(cl::Buffer(static_cast<cl_mem>(memory))); + m_Imported = bool(status); + if (!m_Imported) + { + throw MemoryImportException(status.error_description()); + } + return m_Imported; + } + else + { + throw MemoryImportException("ClImportTensorHandle::Failed to Import Gralloc Memory"); + } + } else { throw MemoryImportException("ClImportTensorHandle::Import flag is not supported"); @@ -276,6 +318,7 @@ private: arm_compute::CLTensor m_Tensor; MemorySourceFlags m_ImportFlags; + bool m_Imported; }; class ClImportSubTensorHandle : public IClImportTensorHandle |