diff options
Diffstat (limited to 'src/runtime/CL/CLTensorAllocator.cpp')
-rw-r--r-- | src/runtime/CL/CLTensorAllocator.cpp | 35 |
1 files changed, 28 insertions, 7 deletions
diff --git a/src/runtime/CL/CLTensorAllocator.cpp b/src/runtime/CL/CLTensorAllocator.cpp index 8112a7148f..ad165fad7d 100644 --- a/src/runtime/CL/CLTensorAllocator.cpp +++ b/src/runtime/CL/CLTensorAllocator.cpp @@ -25,15 +25,21 @@ #include "arm_compute/core/Error.h" #include "arm_compute/core/TensorInfo.h" +#include "arm_compute/runtime/CL/CLMemoryGroup.h" #include "arm_compute/runtime/CL/CLScheduler.h" using namespace arm_compute; -CLTensorAllocator::CLTensorAllocator() - : _buffer(), _mapping(nullptr) +CLTensorAllocator::CLTensorAllocator(CLTensor *owner) + : _associated_memory_group(nullptr), _buffer(), _mapping(nullptr), _owner(owner) { } +CLTensorAllocator::~CLTensorAllocator() +{ + _buffer = cl::Buffer(); +} + uint8_t *CLTensorAllocator::data() { return _mapping; @@ -47,17 +53,32 @@ const cl::Buffer &CLTensorAllocator::cl_data() const void CLTensorAllocator::allocate() { ARM_COMPUTE_ERROR_ON(_buffer.get() != nullptr); - - _buffer = cl::Buffer(CLScheduler::get().context(), CL_MEM_ALLOC_HOST_PTR | CL_MEM_READ_WRITE, info().total_size()); + if(_associated_memory_group == nullptr) + { + _buffer = cl::Buffer(CLScheduler::get().context(), CL_MEM_ALLOC_HOST_PTR | CL_MEM_READ_WRITE, info().total_size()); + } + else + { + _associated_memory_group->finalize_memory(_owner, reinterpret_cast<void **>(&_buffer()), info().total_size()); + } info().set_is_resizable(false); } void CLTensorAllocator::free() { - ARM_COMPUTE_ERROR_ON(_buffer.get() == nullptr); + if(_associated_memory_group == nullptr) + { + _buffer = cl::Buffer(); + info().set_is_resizable(true); + } +} - _buffer = cl::Buffer(); - info().set_is_resizable(true); +void CLTensorAllocator::set_associated_memory_group(CLMemoryGroup *associated_memory_group) +{ + ARM_COMPUTE_ERROR_ON(associated_memory_group == nullptr); + ARM_COMPUTE_ERROR_ON(_associated_memory_group != nullptr); + ARM_COMPUTE_ERROR_ON(_buffer.get() != nullptr); + _associated_memory_group = associated_memory_group; } uint8_t *CLTensorAllocator::lock() |