14 #include <arm_compute/runtime/CL/CLTensor.h>
15 #include <arm_compute/runtime/CL/CLSubTensor.h>
16 #include <arm_compute/runtime/IMemoryGroup.h>
17 #include <arm_compute/runtime/MemoryGroup.h>
18 #include <arm_compute/core/TensorShape.h>
19 #include <arm_compute/core/Coordinates.h>
32 m_IsImportEnabled(false)
34 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
40 : m_ImportFlags(importFlags),
42 m_IsImportEnabled(false)
44 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
47 arm_compute::CLTensor&
GetTensor()
override {
return m_Tensor; }
48 arm_compute::CLTensor
const&
GetTensor()
const override {
return m_Tensor; }
52 if (m_IsImportEnabled)
58 armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
66 if (m_IsImportEnabled)
72 assert(m_MemoryGroup !=
nullptr);
73 m_MemoryGroup->manage(&m_Tensor);
77 virtual const void*
Map(
bool blocking =
true)
const override
79 const_cast<arm_compute::CLTensor*
>(&m_Tensor)->map(blocking);
80 return static_cast<const void*
>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
83 virtual void Unmap()
const override {
const_cast<arm_compute::CLTensor*
>(&m_Tensor)->unmap(); }
89 return m_Tensor.info()->data_type();
92 virtual void SetMemoryGroup(
const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup)
override
94 m_MemoryGroup = PolymorphicPointerDowncast<arm_compute::MemoryGroup>(memoryGroup);
99 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
104 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
109 m_ImportFlags = importFlags;
114 return m_ImportFlags;
119 m_IsImportEnabled = importEnabledFlag;
142 void CopyOutTo(
void* memory)
const override
147 case arm_compute::DataType::F32:
148 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
149 static_cast<float*
>(memory));
151 case arm_compute::DataType::U8:
152 case arm_compute::DataType::QASYMM8:
153 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
154 static_cast<uint8_t*
>(memory));
156 case arm_compute::DataType::QSYMM8:
157 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
158 case arm_compute::DataType::QASYMM8_SIGNED:
159 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
160 static_cast<int8_t*
>(memory));
162 case arm_compute::DataType::F16:
163 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
166 case arm_compute::DataType::S16:
167 case arm_compute::DataType::QSYMM16:
168 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
169 static_cast<int16_t*
>(memory));
171 case arm_compute::DataType::S32:
172 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
173 static_cast<int32_t*
>(memory));
184 void CopyInFrom(
const void* memory)
override
189 case arm_compute::DataType::F32:
190 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const float*
>(memory),
193 case arm_compute::DataType::U8:
194 case arm_compute::DataType::QASYMM8:
195 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const uint8_t*
>(memory),
198 case arm_compute::DataType::F16:
199 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const armnn::Half*
>(memory),
202 case arm_compute::DataType::S16:
203 case arm_compute::DataType::QSYMM8:
204 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
205 case arm_compute::DataType::QASYMM8_SIGNED:
206 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int8_t*
>(memory),
209 case arm_compute::DataType::QSYMM16:
210 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int16_t*
>(memory),
213 case arm_compute::DataType::S32:
214 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int32_t*
>(memory),
225 arm_compute::CLTensor m_Tensor;
226 std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
229 bool m_IsImportEnabled;
236 const arm_compute::TensorShape& shape,
238 : m_Tensor(&parent->
GetTensor(), shape, coords)
240 parentHandle = parent;
243 arm_compute::CLSubTensor&
GetTensor()
override {
return m_Tensor; }
244 arm_compute::CLSubTensor
const&
GetTensor()
const override {
return m_Tensor; }
249 virtual const void*
Map(
bool blocking =
true)
const override
251 const_cast<arm_compute::CLSubTensor*
>(&m_Tensor)->map(blocking);
252 return static_cast<const void*
>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
254 virtual void Unmap()
const override {
const_cast<arm_compute::CLSubTensor*
>(&m_Tensor)->unmap(); }
260 return m_Tensor.info()->data_type();
263 virtual void SetMemoryGroup(
const std::shared_ptr<arm_compute::IMemoryGroup>&)
override {}
267 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
272 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
277 void CopyOutTo(
void* memory)
const override
282 case arm_compute::DataType::F32:
283 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
284 static_cast<float*
>(memory));
286 case arm_compute::DataType::U8:
287 case arm_compute::DataType::QASYMM8:
288 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
289 static_cast<uint8_t*
>(memory));
291 case arm_compute::DataType::F16:
292 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
295 case arm_compute::DataType::QSYMM8:
296 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
297 case arm_compute::DataType::QASYMM8_SIGNED:
298 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
299 static_cast<int8_t*
>(memory));
301 case arm_compute::DataType::S16:
302 case arm_compute::DataType::QSYMM16:
303 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
304 static_cast<int16_t*
>(memory));
306 case arm_compute::DataType::S32:
307 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
308 static_cast<int32_t*
>(memory));
319 void CopyInFrom(
const void* memory)
override
324 case arm_compute::DataType::F32:
325 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const float*
>(memory),
328 case arm_compute::DataType::U8:
329 case arm_compute::DataType::QASYMM8:
330 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const uint8_t*
>(memory),
333 case arm_compute::DataType::F16:
334 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const armnn::Half*
>(memory),
337 case arm_compute::DataType::QSYMM8:
338 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
339 case arm_compute::DataType::QASYMM8_SIGNED:
340 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int8_t*
>(memory),
343 case arm_compute::DataType::S16:
344 case arm_compute::DataType::QSYMM16:
345 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int16_t*
>(memory),
348 case arm_compute::DataType::S32:
349 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int32_t*
>(memory),
360 mutable arm_compute::CLSubTensor m_Tensor;
361 ITensorHandle* parentHandle =
nullptr;