12 #include <arm_compute/runtime/CL/CLTensor.h> 13 #include <arm_compute/runtime/CL/CLSubTensor.h> 14 #include <arm_compute/runtime/IMemoryGroup.h> 15 #include <arm_compute/runtime/MemoryGroup.h> 16 #include <arm_compute/core/TensorShape.h> 17 #include <arm_compute/core/Coordinates.h> 19 #include <boost/polymorphic_pointer_cast.hpp> 28 virtual arm_compute::ICLTensor&
GetTensor() = 0;
// Read-only accessor for the wrapped arm_compute::ICLTensor. Pure virtual: each concrete handle
// supplies its own tensor object.
29 virtual arm_compute::ICLTensor
const&
GetTensor()
const = 0;
// Passes a shared arm_compute::IMemoryGroup to the handle. Pure virtual: what a handle does with
// the group is up to the implementation.
31 virtual void SetMemoryGroup(
const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) = 0;
// Body statement of a constructor whose signature is not part of this extract (presumably
// ClTensorHandle(const TensorInfo&) - confirm): configures m_Tensor from tensorInfo.
39 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
// Same set-up, additionally forwarding dataLayout (presumably the two-argument constructor - confirm).
44 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
47 arm_compute::CLTensor&
GetTensor()
override {
return m_Tensor; }
48 arm_compute::CLTensor
const&
GetTensor()
const override {
return m_Tensor; }
49 virtual void Allocate()
override {armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);}
// Body fragment; the enclosing declaration is not part of this extract (presumably Manage() - confirm).
// A memory group must already have been set, otherwise the dereference below would be on null.
53 assert(m_MemoryGroup !=
nullptr);
// Passes the wrapped tensor to the memory group's manage().
54 m_MemoryGroup->manage(&m_Tensor);
// Maps the wrapped CLTensor and returns a read-only pointer to its first element.
// `blocking` (default true) is forwarded unchanged to CLTensor::map().
// The method is const, so the member is accessed through a const_cast to call map().
57 virtual const void*
Map(
bool blocking =
true)
const override 59 const_cast<arm_compute::CLTensor*
>(&m_Tensor)->map(blocking);
// Returned address = mapped buffer base + byte offset of the first element reported by the tensor info.
60 return static_cast<const void*
>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
63 virtual void Unmap()
const override {
const_cast<arm_compute::CLTensor*
>(&m_Tensor)->unmap(); }
// Body fragment (declaration not in this extract): yields the data type stored in the tensor's info object.
69 return m_Tensor.info()->data_type();
// Stores the supplied memory group in m_MemoryGroup.
// The incoming IMemoryGroup pointer is converted to the concrete arm_compute::MemoryGroup with
// boost::polymorphic_pointer_downcast, so callers are expected to pass a MemoryGroup instance.
72 virtual void SetMemoryGroup(
const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup)
override 74 m_MemoryGroup = boost::polymorphic_pointer_downcast<arm_compute::MemoryGroup>(memoryGroup);
// Body fragment (declaration not in this extract): converts the tensor info's byte strides via GetStrides().
79 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
// Body fragment (declaration not in this extract): converts the tensor info's shape via GetShape().
84 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
// Copies the tensor's contents into the caller-supplied `memory`.
// Each case casts `memory` to the element type matching the arm_compute::DataType label(s) and
// calls CopyArmComputeITensorData with this handle's tensor as the source.
// The switch header, break statements and any default branch are not part of this extract.
89 void CopyOutTo(
void* memory)
const override 94 case arm_compute::DataType::F32:
95 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
96 static_cast<float*>(memory));
// Unsigned 8-bit formats share the uint8_t path.
98 case arm_compute::DataType::U8:
99 case arm_compute::DataType::QASYMM8:
100 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
101 static_cast<uint8_t*>(memory));
// Signed 8-bit quantised formats share the int8_t path.
103 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
104 case arm_compute::DataType::QASYMM8_SIGNED:
105 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
106 static_cast<int8_t*>(memory));
108 case arm_compute::DataType::F16:
109 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
110 static_cast<armnn::Half*>(memory));
// 16-bit signed formats share the int16_t path.
112 case arm_compute::DataType::S16:
113 case arm_compute::DataType::QSYMM16:
114 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
115 static_cast<int16_t*>(memory));
117 case arm_compute::DataType::S32:
118 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
119 static_cast<int32_t*>(memory));
// Copy-in dispatch: `memory` is cast to the element type matching each arm_compute::DataType
// label before being handed to CopyArmComputeITensorData.
// The enclosing signature, the second call argument, the break statements and any default
// branch are not part of this extract.
135 case arm_compute::DataType::F32:
136 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
// Unsigned 8-bit formats share the uint8_t path.
139 case arm_compute::DataType::U8:
140 case arm_compute::DataType::QASYMM8:
141 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
144 case arm_compute::DataType::F16:
145 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
// Signed 8-bit quantised formats are read through int8_t.
148 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
149 case arm_compute::DataType::QASYMM8_SIGNED:
150 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
// S16 belongs with QSYMM16 on the int16_t path, matching CopyOutTo. Grouping it with the
// int8_t cases would read the source buffer with the wrong element width.
153 case arm_compute::DataType::S16:
154 case arm_compute::DataType::QSYMM16:
155 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
158 case arm_compute::DataType::S32:
159 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
// The CL tensor wrapped by this handle, held by value.
170 arm_compute::CLTensor m_Tensor;
// Memory group set through SetMemoryGroup(); dereferenced at listing line 54, guarded by an assert.
171 std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
// Constructor fragment; the first and third parameters and the braces are not part of this extract.
// `shape` is forwarded to the sub-tensor constructed in the initialiser list.
178 const arm_compute::TensorShape& shape,
// The sub-tensor is built from the parent's tensor plus the requested shape and coordinates.
180 : m_Tensor(&parent->
GetTensor(), shape, coords)
// Keeps the parent handle pointer in parentHandle.
182 parentHandle = parent;
185 arm_compute::CLSubTensor&
GetTensor()
override {
return m_Tensor; }
186 arm_compute::CLSubTensor
const&
GetTensor()
const override {
return m_Tensor; }
// Maps the wrapped CLSubTensor and returns a read-only pointer to its first element.
// `blocking` (default true) is forwarded unchanged to CLSubTensor::map().
191 virtual const void*
Map(
bool blocking =
true)
const override 193 const_cast<arm_compute::CLSubTensor*
>(&m_Tensor)->map(blocking);
// Returned address = mapped buffer base + byte offset of the first element reported by the tensor info.
194 return static_cast<const void*
>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
196 virtual void Unmap()
const override {
const_cast<arm_compute::CLSubTensor*
>(&m_Tensor)->unmap(); }
// Body fragment (declaration not in this extract): yields the data type stored in the sub-tensor's info object.
202 return m_Tensor.info()->data_type();
205 virtual void SetMemoryGroup(
const std::shared_ptr<arm_compute::IMemoryGroup>&)
override {}
// Body fragment (declaration not in this extract): converts the sub-tensor info's byte strides via GetStrides().
209 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
// Body fragment (declaration not in this extract): converts the sub-tensor info's shape via GetShape().
214 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
// Copies the sub-tensor's contents into the caller-supplied `memory`.
// Each case casts `memory` to the element type matching the arm_compute::DataType label(s) and
// calls CopyArmComputeITensorData with this handle's tensor as the source.
// The switch header, break statements and any default branch are not part of this extract.
219 void CopyOutTo(
void* memory)
const override 224 case arm_compute::DataType::F32:
225 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
226 static_cast<float*>(memory));
// Unsigned 8-bit formats share the uint8_t path.
228 case arm_compute::DataType::U8:
229 case arm_compute::DataType::QASYMM8:
230 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
231 static_cast<uint8_t*>(memory));
233 case arm_compute::DataType::F16:
234 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
235 static_cast<armnn::Half*>(memory));
// Signed 8-bit quantised formats share the int8_t path.
237 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
238 case arm_compute::DataType::QASYMM8_SIGNED:
239 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
240 static_cast<int8_t*>(memory));
// 16-bit signed formats share the int16_t path.
242 case arm_compute::DataType::S16:
243 case arm_compute::DataType::QSYMM16:
244 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
245 static_cast<int16_t*>(memory));
247 case arm_compute::DataType::S32:
248 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
249 static_cast<int32_t*>(memory));
// Copy-in dispatch for the sub-tensor handle: `memory` is cast to the element type matching each
// arm_compute::DataType label before being handed to CopyArmComputeITensorData.
// The enclosing signature, the second call argument, the break statements and any default
// branch are not part of this extract.
265 case arm_compute::DataType::F32:
266 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
// Unsigned 8-bit formats share the uint8_t path.
269 case arm_compute::DataType::U8:
270 case arm_compute::DataType::QASYMM8:
271 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
274 case arm_compute::DataType::F16:
275 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
// Signed 8-bit quantised formats share the int8_t path.
278 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
279 case arm_compute::DataType::QASYMM8_SIGNED:
280 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
// 16-bit signed formats share the int16_t path.
283 case arm_compute::DataType::S16:
284 case arm_compute::DataType::QSYMM16:
285 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
288 case arm_compute::DataType::S32:
289 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
// The CL sub-tensor wrapped by this handle. Declared mutable, so const member functions
// (such as Map/Unmap) may call non-const operations on it.
300 mutable arm_compute::CLSubTensor m_Tensor;
virtual const void * Map(bool blocking=true) const override
arm_compute::CLSubTensor const & GetTensor() const override
ClSubTensorHandle(IClTensorHandle *parent, const arm_compute::TensorShape &shape, const arm_compute::Coordinates &coords)
TensorShape GetShape() const override
arm_compute::CLTensor const & GetTensor() const override
virtual const void * Map(bool blocking=true) const override
virtual arm_compute::DataType GetDataType() const override
arm_compute::CLTensor & GetTensor() override
TensorShape GetStrides() const override
ClTensorHandle(const TensorInfo &tensorInfo, DataLayout dataLayout)
virtual void Unmap() const override
Unmap the tensor data.
virtual const void * Map(bool blocking=true) const =0
virtual void SetMemoryGroup(const std::shared_ptr< arm_compute::IMemoryGroup > &memoryGroup)=0
virtual void CopyOutTo(void *memory) const =0
virtual void Manage() override
TensorShape GetShape() const override
ClTensorHandle(const TensorInfo &tensorInfo)
std::array< unsigned int, MaxNumOfTensorDimensions > Coordinates
virtual arm_compute::ICLTensor & GetTensor()=0
arm_compute::CLSubTensor & GetTensor() override
virtual void Unmap() const =0
Unmap the tensor data.
virtual void CopyInFrom(const void *memory)=0
TensorShape GetStrides() const override
virtual void Unmap() const override
Unmap the tensor data.
virtual void Manage() override
virtual arm_compute::DataType GetDataType() const =0
virtual void Allocate() override
virtual void SetMemoryGroup(const std::shared_ptr< arm_compute::IMemoryGroup > &memoryGroup) override
virtual void Allocate() override
virtual arm_compute::DataType GetDataType() const override
virtual ITensorHandle * GetParent() const override
virtual void SetMemoryGroup(const std::shared_ptr< arm_compute::IMemoryGroup > &) override
virtual ITensorHandle * GetParent() const override