aboutsummaryrefslogtreecommitdiff
path: root/src/backends/cl/ClImportTensorHandle.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/cl/ClImportTensorHandle.hpp')
-rw-r--r--src/backends/cl/ClImportTensorHandle.hpp45
1 files changed, 44 insertions, 1 deletions
diff --git a/src/backends/cl/ClImportTensorHandle.hpp b/src/backends/cl/ClImportTensorHandle.hpp
index 48fb2f7d30..a24ab5656e 100644
--- a/src/backends/cl/ClImportTensorHandle.hpp
+++ b/src/backends/cl/ClImportTensorHandle.hpp
@@ -46,7 +46,7 @@ public:
ClImportTensorHandle(const TensorInfo& tensorInfo,
DataLayout dataLayout,
MemorySourceFlags importFlags)
- : m_ImportFlags(importFlags)
+ : m_ImportFlags(importFlags), m_Imported(false)
{
armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
}
@@ -139,6 +139,48 @@ public:
return ClImport(importProperties, memory, true);
}
+ // Case for importing memory allocated by OpenCl externally directly into the tensor
+ else if (source == MemorySource::Gralloc)
+ {
+ // m_Tensor not yet Allocated
+ if (!m_Imported && !m_Tensor.buffer())
+ {
+ // Importing memory allocated by OpenCl into the tensor directly.
+ arm_compute::Status status =
+ m_Tensor.allocator()->import_memory(cl::Buffer(static_cast<cl_mem>(memory)));
+ m_Imported = bool(status);
+ if (!m_Imported)
+ {
+ throw MemoryImportException(status.error_description());
+ }
+ return m_Imported;
+ }
+
+ // m_Tensor.buffer() initially allocated with Allocate().
+ else if (!m_Imported && m_Tensor.buffer())
+ {
+ throw MemoryImportException(
+ "ClImportTensorHandle::Import Attempting to import on an already allocated tensor");
+ }
+
+ // m_Tensor.buffer() previously imported.
+ else if (m_Imported)
+ {
+ // Importing memory allocated by OpenCl into the tensor directly.
+ arm_compute::Status status =
+ m_Tensor.allocator()->import_memory(cl::Buffer(static_cast<cl_mem>(memory)));
+ m_Imported = bool(status);
+ if (!m_Imported)
+ {
+ throw MemoryImportException(status.error_description());
+ }
+ return m_Imported;
+ }
+ else
+ {
+ throw MemoryImportException("ClImportTensorHandle::Failed to Import Gralloc Memory");
+ }
+ }
else
{
throw MemoryImportException("ClImportTensorHandle::Import flag is not supported");
@@ -276,6 +318,7 @@ private:
arm_compute::CLTensor m_Tensor;
MemorySourceFlags m_ImportFlags;
+ bool m_Imported;
};
class ClImportSubTensorHandle : public IClImportTensorHandle