aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/backends/ClTensorHandle.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/backends/ClTensorHandle.hpp')
-rw-r--r--src/armnn/backends/ClTensorHandle.hpp84
1 files changed, 67 insertions, 17 deletions
diff --git a/src/armnn/backends/ClTensorHandle.hpp b/src/armnn/backends/ClTensorHandle.hpp
index 49e18dad59..e3618a3c46 100644
--- a/src/armnn/backends/ClTensorHandle.hpp
+++ b/src/armnn/backends/ClTensorHandle.hpp
@@ -9,9 +9,12 @@
#include <arm_compute/runtime/CL/CLTensor.h>
#include <arm_compute/runtime/CL/CLSubTensor.h>
+#include <arm_compute/runtime/CL/CLMemoryGroup.h>
+#include <arm_compute/runtime/IMemoryGroup.h>
#include <arm_compute/core/TensorShape.h>
#include <arm_compute/core/Coordinates.h>
+#include <boost/polymorphic_pointer_cast.hpp>
namespace armnn
{
@@ -22,9 +25,8 @@ class IClTensorHandle : public ITensorHandle
public:
virtual arm_compute::ICLTensor& GetTensor() = 0;
virtual arm_compute::ICLTensor const& GetTensor() const = 0;
- virtual void Map(bool blocking = true) = 0;
- virtual void UnMap() = 0;
virtual arm_compute::DataType GetDataType() const = 0;
+ virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) = 0;
};
class ClTensorHandle : public IClTensorHandle
@@ -37,50 +39,98 @@ public:
arm_compute::CLTensor& GetTensor() override { return m_Tensor; }
arm_compute::CLTensor const& GetTensor() const override { return m_Tensor; }
- virtual void Allocate() override {armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);};
+ virtual void Allocate() override {armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);}
- virtual void Map(bool blocking = true) override {m_Tensor.map(blocking);}
- virtual void UnMap() override { m_Tensor.unmap();}
+ virtual void Manage() override
+ {
+ assert(m_MemoryGroup != nullptr);
+ m_MemoryGroup->manage(&m_Tensor);
+ }
- virtual ITensorHandle::Type GetType() const override { return ITensorHandle::CL;}
+ virtual const void* Map(bool blocking = true) const override
+ {
+ const_cast<arm_compute::CLTensor*>(&m_Tensor)->map(blocking);
+ return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
+ }
+ virtual void Unmap() const override { const_cast<arm_compute::CLTensor*>(&m_Tensor)->unmap(); }
+
+ virtual ITensorHandle::Type GetType() const override { return ITensorHandle::CL; }
+
+ virtual ITensorHandle* GetParent() const override { return nullptr; }
virtual arm_compute::DataType GetDataType() const override
{
return m_Tensor.info()->data_type();
}
+ virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
+ {
+ m_MemoryGroup = boost::polymorphic_pointer_downcast<arm_compute::CLMemoryGroup>(memoryGroup);
+ }
+
+ TensorShape GetStrides() const override
+ {
+ return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
+ }
+
+ TensorShape GetShape() const override
+ {
+ return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
+ }
private:
arm_compute::CLTensor m_Tensor;
-
+ std::shared_ptr<arm_compute::CLMemoryGroup> m_MemoryGroup;
};
class ClSubTensorHandle : public IClTensorHandle
{
public:
- ClSubTensorHandle(arm_compute::ICLTensor& parent,
- const arm_compute::TensorShape& shape,
- const arm_compute::Coordinates& coords)
- : m_Tensor(&parent, shape, coords)
+ ClSubTensorHandle(IClTensorHandle* parent,
+ const arm_compute::TensorShape& shape,
+ const arm_compute::Coordinates& coords)
+ : m_Tensor(&parent->GetTensor(), shape, coords)
{
+ parentHandle = parent;
}
arm_compute::CLSubTensor& GetTensor() override { return m_Tensor; }
arm_compute::CLSubTensor const& GetTensor() const override { return m_Tensor; }
- virtual void Allocate() override {};
- virtual void Map(bool blocking = true) override {m_Tensor.map(blocking);}
- virtual void UnMap() override { m_Tensor.unmap();}
+ virtual void Allocate() override {}
+ virtual void Manage() override {}
- virtual ITensorHandle::Type GetType() const override { return ITensorHandle::CL;}
+ virtual const void* Map(bool blocking = true) const override
+ {
+ const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->map(blocking);
+ return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
+ }
+ virtual void Unmap() const override { const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->unmap(); }
+
+ virtual ITensorHandle::Type GetType() const override { return ITensorHandle::CL; }
+
+ virtual ITensorHandle* GetParent() const override { return parentHandle; }
virtual arm_compute::DataType GetDataType() const override
{
return m_Tensor.info()->data_type();
}
+ virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
+
+ TensorShape GetStrides() const override
+ {
+ return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
+ }
+
+ TensorShape GetShape() const override
+ {
+ return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
+ }
+
private:
- arm_compute::CLSubTensor m_Tensor;
+ mutable arm_compute::CLSubTensor m_Tensor;
+ ITensorHandle* parentHandle = nullptr;
};
-} \ No newline at end of file
+}