16 #include <arm_compute/runtime/MemoryGroup.h>
17 #include <arm_compute/runtime/IMemoryGroup.h>
18 #include <arm_compute/runtime/Tensor.h>
19 #include <arm_compute/runtime/SubTensor.h>
20 #include <arm_compute/core/TensorShape.h>
21 #include <arm_compute/core/Coordinates.h>
32 m_IsImportEnabled(false),
35 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
41 : m_ImportFlags(importFlags),
43 m_IsImportEnabled(false),
48 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
// Mutable accessor: returns the m_Tensor member through the arm_compute::ITensor interface.
51 arm_compute::ITensor&
GetTensor()
override {
return m_Tensor; }
// Const overload: returns the m_Tensor member as a read-only arm_compute::ITensor reference.
52 arm_compute::ITensor
const&
GetTensor()
const override {
return m_Tensor; }
57 if (!m_IsImportEnabled)
59 armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
66 if (!m_IsImportEnabled)
69 m_MemoryGroup->manage(&m_Tensor);
77 return m_Tensor.info()->data_type();
// Stores the supplied memory group on this handle.
// The IMemoryGroup pointer is converted to arm_compute::MemoryGroup with
// PolymorphicPointerDowncast and assigned to m_MemoryGroup.
// NOTE(review): what happens when memoryGroup is null or of a different concrete type
// depends on PolymorphicPointerDowncast, which is not visible here - confirm.
80 virtual void SetMemoryGroup(
const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup)
override
82 m_MemoryGroup = PolymorphicPointerDowncast<arm_compute::MemoryGroup>(memoryGroup);
// Returns a read-only pointer into the tensor's buffer: m_Tensor.buffer() advanced by
// the byte offset reported by m_Tensor.info()->offset_first_element_in_bytes().
// The bool parameter is unnamed and is not used in the visible return expression.
// NOTE(review): no check on buffer() is visible before the pointer arithmetic - confirm
// callers only map after the tensor has a buffer.
85 virtual const void*
Map(
bool )
const override
87 return static_cast<const void*
>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
// Empty body: Unmap() does nothing for this handle.
90 virtual void Unmap()
const override {}
94 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
99 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
104 m_ImportFlags = importFlags;
109 return m_ImportFlags;
114 m_IsImportEnabled = importEnabledFlag;
138 if (!m_Imported && !m_Tensor.buffer())
143 m_Imported = bool(status);
152 if (!m_Imported && m_Tensor.buffer())
155 "NeonTensorHandle::Import Attempting to import on an already allocated tensor");
164 m_Imported = bool(status);
// Copies data from the tensor returned by GetTensor() into the caller-supplied buffer `memory`.
// Each visible case label is an arm_compute::DataType value; the branch casts `memory` to a
// pointer of the matching element type and passes it, together with the tensor, to
// armcomputetensorutils::CopyArmComputeITensorData.
// NOTE(review): the switch expression is not visible in this excerpt - presumably the
// tensor's data type; confirm.
// NOTE(review): QSYMM8_PER_CHANNEL appears among the visible CopyInFrom cases but not among
// the cases visible here - confirm whether that asymmetry is intended.
186 void CopyOutTo(
void* memory)
const override
// Destination treated as float*.
190 case arm_compute::DataType::F32:
191 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
192 static_cast<float*
>(memory));
// Destination treated as uint8_t*.
194 case arm_compute::DataType::U8:
195 case arm_compute::DataType::QASYMM8:
196 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
197 static_cast<uint8_t*
>(memory));
// Destination treated as int8_t*.
199 case arm_compute::DataType::QSYMM8:
200 case arm_compute::DataType::QASYMM8_SIGNED:
201 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
202 static_cast<int8_t*
>(memory));
// Destination cast for BFLOAT16 is not visible in this excerpt.
204 case arm_compute::DataType::BFLOAT16:
205 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
// Destination cast for F16 is not visible in this excerpt.
208 case arm_compute::DataType::F16:
209 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
// Destination treated as int16_t*.
212 case arm_compute::DataType::S16:
213 case arm_compute::DataType::QSYMM16:
214 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
215 static_cast<int16_t*
>(memory));
// Destination treated as int32_t*.
217 case arm_compute::DataType::S32:
218 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
219 static_cast<int32_t*
>(memory));
// Copies data from the caller-supplied read-only buffer `memory` via
// armcomputetensorutils::CopyArmComputeITensorData.
// Each visible case label is an arm_compute::DataType value; the branch casts `memory` to a
// const pointer of the matching element type and passes it as the first argument.
// NOTE(review): the switch expression and the second (destination) argument of each call are
// not visible in this excerpt - presumably the tensor's data type and this handle's tensor;
// confirm.
229 void CopyInFrom(
const void* memory)
override
// Source treated as const float*.
233 case arm_compute::DataType::F32:
234 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const float*
>(memory),
// Source treated as const uint8_t*.
237 case arm_compute::DataType::U8:
238 case arm_compute::DataType::QASYMM8:
239 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const uint8_t*
>(memory),
// Source treated as const int8_t* (includes the per-channel QSYMM8 variant).
242 case arm_compute::DataType::QSYMM8:
243 case arm_compute::DataType::QASYMM8_SIGNED:
244 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
245 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int8_t*
>(memory),
// Source treated as const armnn::BFloat16*.
248 case arm_compute::DataType::BFLOAT16:
249 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const armnn::BFloat16*
>(memory),
// Source treated as const armnn::Half*.
252 case arm_compute::DataType::F16:
253 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const armnn::Half*
>(memory),
// Source treated as const int16_t*.
256 case arm_compute::DataType::S16:
257 case arm_compute::DataType::QSYMM16:
258 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int16_t*
>(memory),
// Source treated as const int32_t*.
261 case arm_compute::DataType::S32:
262 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int32_t*
>(memory),
// The tensor object exposed through GetTensor().
272 arm_compute::Tensor m_Tensor;
// Assigned in SetMemoryGroup(); the visible code calls manage(&m_Tensor) through it.
273 std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
// Initialised to false in the visible constructor initialiser lists and later assigned from
// an importEnabledFlag value; tested (negated) in the visible `if` statements.
276 bool m_IsImportEnabled;
// NOTE(review): initialisation and use of this constant are not visible in this excerpt.
277 const uintptr_t m_TypeAlignment;
284 const arm_compute::TensorShape& shape,
286 : m_Tensor(&parent->
GetTensor(), shape, coords)
288 parentHandle = parent;
// Mutable accessor: returns the m_Tensor member through the arm_compute::ITensor interface.
291 arm_compute::ITensor&
GetTensor()
override {
return m_Tensor; }
// Const overload: returns the m_Tensor member as a read-only arm_compute::ITensor reference.
292 arm_compute::ITensor
const&
GetTensor()
const override {
return m_Tensor; }
301 return m_Tensor.info()->data_type();
// Empty override: the memory-group argument is unnamed and ignored.
304 virtual void SetMemoryGroup(
const std::shared_ptr<arm_compute::IMemoryGroup>&)
override {}
// Returns a read-only pointer into the tensor's buffer: m_Tensor.buffer() advanced by
// the byte offset reported by m_Tensor.info()->offset_first_element_in_bytes().
// The bool parameter is unnamed and is not used in the visible return expression.
306 virtual const void*
Map(
bool )
const override
308 return static_cast<const void*
>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
// Empty body: Unmap() does nothing for this handle.
310 virtual void Unmap()
const override {}
314 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
319 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
// Copies data from the tensor returned by GetTensor() into the caller-supplied buffer `memory`.
// Each visible case label is an arm_compute::DataType value; the branch casts `memory` to a
// pointer of the matching element type and passes it, together with the tensor, to
// armcomputetensorutils::CopyArmComputeITensorData.
// NOTE(review): the switch expression is not visible in this excerpt - presumably the
// tensor's data type; confirm.
// NOTE(review): no BFLOAT16 or F16 cases are visible here, unlike the other CopyOutTo in this
// file - confirm whether that is intended.
324 void CopyOutTo(
void* memory)
const override
// Destination treated as float*.
328 case arm_compute::DataType::F32:
329 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
330 static_cast<float*
>(memory));
// Destination treated as uint8_t*.
332 case arm_compute::DataType::U8:
333 case arm_compute::DataType::QASYMM8:
334 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
335 static_cast<uint8_t*
>(memory));
// Destination treated as int8_t*.
337 case arm_compute::DataType::QSYMM8:
338 case arm_compute::DataType::QASYMM8_SIGNED:
339 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
340 static_cast<int8_t*
>(memory));
// Destination treated as int16_t*.
342 case arm_compute::DataType::S16:
343 case arm_compute::DataType::QSYMM16:
344 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
345 static_cast<int16_t*
>(memory));
// Destination treated as int32_t*.
347 case arm_compute::DataType::S32:
348 armcomputetensorutils::CopyArmComputeITensorData(this->
GetTensor(),
349 static_cast<int32_t*
>(memory));
// Copies data from the caller-supplied read-only buffer `memory` via
// armcomputetensorutils::CopyArmComputeITensorData.
// Each visible case label is an arm_compute::DataType value; the branch casts `memory` to a
// const pointer of the matching element type and passes it as the first argument.
// NOTE(review): the switch expression and the second (destination) argument of each call are
// not visible in this excerpt - presumably the tensor's data type and this handle's tensor;
// confirm.
359 void CopyInFrom(
const void* memory)
override
// Source treated as const float*.
363 case arm_compute::DataType::F32:
364 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const float*
>(memory),
// Source treated as const uint8_t*.
367 case arm_compute::DataType::U8:
368 case arm_compute::DataType::QASYMM8:
369 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const uint8_t*
>(memory),
// Source treated as const int8_t*.
372 case arm_compute::DataType::QSYMM8:
373 case arm_compute::DataType::QASYMM8_SIGNED:
374 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int8_t*
>(memory),
// Source treated as const int16_t*.
377 case arm_compute::DataType::S16:
378 case arm_compute::DataType::QSYMM16:
379 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int16_t*
>(memory),
// Source treated as const int32_t*.
382 case arm_compute::DataType::S32:
383 armcomputetensorutils::CopyArmComputeITensorData(
static_cast<const int32_t*
>(memory),
// Sub-tensor exposed through GetTensor(); the visible constructor initialises it from the
// parent's tensor, a shape and coords.
393 arm_compute::SubTensor m_Tensor;
// Defaults to nullptr; the visible constructor code assigns it from `parent`.
394 ITensorHandle* parentHandle =
nullptr;