ArmNN
 21.02
NeonTensorHandle.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include <BFloat16.hpp>
8 #include <Half.hpp>
9 
10 #include <armnn/utility/Assert.hpp>
11 
15 
16 #include <arm_compute/runtime/MemoryGroup.h>
17 #include <arm_compute/runtime/IMemoryGroup.h>
18 #include <arm_compute/runtime/Tensor.h>
19 #include <arm_compute/runtime/SubTensor.h>
20 #include <arm_compute/core/TensorShape.h>
21 #include <arm_compute/core/Coordinates.h>
22 
23 namespace armnn
24 {
25 
27 {
28 public:
29  NeonTensorHandle(const TensorInfo& tensorInfo)
30  : m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc)),
31  m_Imported(false),
32  m_IsImportEnabled(false)
33  {
34  armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
35  }
36 
37  NeonTensorHandle(const TensorInfo& tensorInfo,
38  DataLayout dataLayout,
39  MemorySourceFlags importFlags = static_cast<MemorySourceFlags>(MemorySource::Malloc))
40  : m_ImportFlags(importFlags),
41  m_Imported(false),
42  m_IsImportEnabled(false)
43 
44  {
45  armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
46  }
47 
    /// Access the underlying arm_compute tensor.
    arm_compute::ITensor& GetTensor() override { return m_Tensor; }
    arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
50 
51  virtual void Allocate() override
52  {
53  // If we have enabled Importing, don't Allocate the tensor
54  if (!m_IsImportEnabled)
55  {
56  armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
57  }
58  };
59 
60  virtual void Manage() override
61  {
62  // If we have enabled Importing, don't manage the tensor
63  if (!m_IsImportEnabled)
64  {
65  ARMNN_ASSERT(m_MemoryGroup != nullptr);
66  m_MemoryGroup->manage(&m_Tensor);
67  }
68  }
69 
70  virtual ITensorHandle* GetParent() const override { return nullptr; }
71 
    /// @return the arm_compute data type of the underlying tensor.
    virtual arm_compute::DataType GetDataType() const override
    {
        return m_Tensor.info()->data_type();
    }
76 
    /// Store the memory group that Manage() will register this tensor with.
    /// The interface pointer is downcast to the concrete arm_compute::MemoryGroup.
    virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
    {
        m_MemoryGroup = PolymorphicPointerDowncast<arm_compute::MemoryGroup>(memoryGroup);
    }
81 
    /// Map the tensor data for CPU access. Neon buffers are host-visible, so
    /// this just returns a pointer to the first element; `blocking` is unused.
    virtual const void* Map(bool /* blocking = true */) const override
    {
        return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
    }
86 
87  virtual void Unmap() const override {}
88 
    /// @return per-dimension strides in bytes, ordered largest to smallest.
    TensorShape GetStrides() const override
    {
        return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
    }
93 
    /// @return the tensor's dimensions, slowest-iterating dimension first.
    TensorShape GetShape() const override
    {
        return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
    }
98 
100  {
101  m_ImportFlags = importFlags;
102  }
103 
105  {
106  return m_ImportFlags;
107  }
108 
    /// Enable or disable memory import. When enabled, Allocate() and Manage()
    /// become no-ops and the backing buffer must be provided via Import().
    void SetImportEnabledFlag(bool importEnabledFlag)
    {
        m_IsImportEnabled = importEnabledFlag;
    }
113 
    /// Import an externally allocated buffer as this tensor's backing memory.
    /// @param memory pointer to the caller-owned buffer (must stay alive while
    ///        the tensor is in use and satisfy the alignment check below)
    /// @param source where the memory came from; must be present in
    ///        m_ImportFlags, and only MemorySource::Malloc is actually handled
    /// @return true when the import succeeded
    /// @throws MemoryImportException on misalignment, disabled import, an
    ///         unsupported source flag, or an already-Allocate()d tensor.
    virtual bool Import(void* memory, MemorySource source) override
    {
        if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
        {
            if (source == MemorySource::Malloc && m_IsImportEnabled)
            {
                // Reject pointers not aligned to sizeof(size_t) bytes.
                // NOTE(review): the original comment claimed a "16 byte" check,
                // but sizeof(size_t) is typically 8 on 64-bit targets - confirm
                // which alignment is actually intended.
                constexpr uintptr_t alignment = sizeof(size_t);
                if (reinterpret_cast<uintptr_t>(memory) % alignment)
                {
                    throw MemoryImportException("NeonTensorHandle::Import Attempting to import unaligned memory");
                }

                // m_Tensor not yet Allocated: first-time import.
                if (!m_Imported && !m_Tensor.buffer())
                {
                    arm_compute::Status status = m_Tensor.allocator()->import_memory(memory);
                    // Use the overloaded bool operator of Status to check if it worked, if not throw an exception
                    // with the Status error message
                    m_Imported = bool(status);
                    if (!m_Imported)
                    {
                        throw MemoryImportException(status.error_description());
                    }
                    return m_Imported;
                }

                // m_Tensor.buffer() initially allocated with Allocate() -
                // importing over an owned allocation is not allowed.
                if (!m_Imported && m_Tensor.buffer())
                {
                    throw MemoryImportException(
                        "NeonTensorHandle::Import Attempting to import on an already allocated tensor");
                }

                // m_Tensor.buffer() previously imported: re-import is allowed
                // and simply repoints the allocator at the new buffer.
                if (m_Imported)
                {
                    arm_compute::Status status = m_Tensor.allocator()->import_memory(memory);
                    // Use the overloaded bool operator of Status to check if it worked, if not throw an exception
                    // with the Status error message
                    m_Imported = bool(status);
                    if (!m_Imported)
                    {
                        throw MemoryImportException(status.error_description());
                    }
                    return m_Imported;
                }
            }
            else
            {
                throw MemoryImportException("NeonTensorHandle::Import is disabled");
            }
        }
        else
        {
            throw MemoryImportException("NeonTensorHandle::Incorrect import flag");
        }
        return false;
    }
173 
174 private:
175  // Only used for testing
176  void CopyOutTo(void* memory) const override
177  {
178  switch (this->GetDataType())
179  {
180  case arm_compute::DataType::F32:
181  armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
182  static_cast<float*>(memory));
183  break;
184  case arm_compute::DataType::U8:
185  case arm_compute::DataType::QASYMM8:
186  armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
187  static_cast<uint8_t*>(memory));
188  break;
189  case arm_compute::DataType::QASYMM8_SIGNED:
190  armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
191  static_cast<int8_t*>(memory));
192  break;
193  case arm_compute::DataType::BFLOAT16:
194  armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
195  static_cast<armnn::BFloat16*>(memory));
196  break;
197  case arm_compute::DataType::F16:
198  armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
199  static_cast<armnn::Half*>(memory));
200  break;
201  case arm_compute::DataType::S16:
202  case arm_compute::DataType::QSYMM16:
203  armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
204  static_cast<int16_t*>(memory));
205  break;
206  case arm_compute::DataType::S32:
207  armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
208  static_cast<int32_t*>(memory));
209  break;
210  default:
211  {
213  }
214  }
215  }
216 
217  // Only used for testing
    // Only used for testing
    // Copy data from `memory` into the tensor, dispatching on the ACL data
    // type so the element-wise copy uses the matching C++ type.
    void CopyInFrom(const void* memory) override
    {
        switch (this->GetDataType())
        {
            case arm_compute::DataType::F32:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::U8:
            case arm_compute::DataType::QASYMM8:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::QASYMM8_SIGNED:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::BFLOAT16:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::BFloat16*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::F16:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::S16:
            case arm_compute::DataType::QSYMM16:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::S32:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
                                                                 this->GetTensor());
                break;
            default:
            {
                // NOTE(review): the default-case body was elided in this view;
                // as shown, unsupported data types are silently ignored -
                // confirm whether an assert/throw was dropped here.
            }
        }
    }
258 
    // The ACL tensor owned by this handle.
    arm_compute::Tensor m_Tensor;
    // Memory group assigned via SetMemoryGroup(); Manage() asserts it is non-null.
    std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
    // Bitmask of MemorySource values accepted by Import().
    MemorySourceFlags m_ImportFlags;
    // True once an external buffer has been successfully imported.
    bool m_Imported;
    // When true, Allocate()/Manage() are no-ops and memory comes from Import().
    bool m_IsImportEnabled;
264 };
265 
267 {
268 public:
270  const arm_compute::TensorShape& shape,
271  const arm_compute::Coordinates& coords)
272  : m_Tensor(&parent->GetTensor(), shape, coords)
273  {
274  parentHandle = parent;
275  }
276 
    /// Access the underlying arm_compute sub-tensor.
    arm_compute::ITensor& GetTensor() override { return m_Tensor; }
    arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
279 
    /// No-ops: a sub-tensor has no memory of its own - the parent tensor's
    /// handle allocates and manages the storage this view aliases.
    virtual void Allocate() override {}
    virtual void Manage() override {}
282 
283  virtual ITensorHandle* GetParent() const override { return parentHandle; }
284 
    /// @return the arm_compute data type of the underlying sub-tensor.
    virtual arm_compute::DataType GetDataType() const override
    {
        return m_Tensor.info()->data_type();
    }
289 
290  virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
291 
    /// Map the sub-tensor data for CPU access; `blocking` is unused.
    virtual const void* Map(bool /* blocking = true */) const override
    {
        return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
    }
    /// Unmap the tensor data - a no-op, as Map() acquires no resource.
    virtual void Unmap() const override {}
297 
    /// @return per-dimension strides in bytes, ordered largest to smallest.
    TensorShape GetStrides() const override
    {
        return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
    }
302 
    /// @return the sub-tensor's dimensions, slowest-iterating dimension first.
    TensorShape GetShape() const override
    {
        return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
    }
307 
308 private:
309  // Only used for testing
310  void CopyOutTo(void* memory) const override
311  {
312  switch (this->GetDataType())
313  {
314  case arm_compute::DataType::F32:
315  armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
316  static_cast<float*>(memory));
317  break;
318  case arm_compute::DataType::U8:
319  case arm_compute::DataType::QASYMM8:
320  armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
321  static_cast<uint8_t*>(memory));
322  break;
323  case arm_compute::DataType::QASYMM8_SIGNED:
324  armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
325  static_cast<int8_t*>(memory));
326  break;
327  case arm_compute::DataType::S16:
328  case arm_compute::DataType::QSYMM16:
329  armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
330  static_cast<int16_t*>(memory));
331  break;
332  case arm_compute::DataType::S32:
333  armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
334  static_cast<int32_t*>(memory));
335  break;
336  default:
337  {
339  }
340  }
341  }
342 
343  // Only used for testing
344  void CopyInFrom(const void* memory) override
345  {
346  switch (this->GetDataType())
347  {
348  case arm_compute::DataType::F32:
349  armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
350  this->GetTensor());
351  break;
352  case arm_compute::DataType::U8:
353  case arm_compute::DataType::QASYMM8:
354  armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
355  this->GetTensor());
356  break;
357  case arm_compute::DataType::QASYMM8_SIGNED:
358  armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
359  this->GetTensor());
360  break;
361  case arm_compute::DataType::S16:
362  case arm_compute::DataType::QSYMM16:
363  armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
364  this->GetTensor());
365  break;
366  case arm_compute::DataType::S32:
367  armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
368  this->GetTensor());
369  break;
370  default:
371  {
373  }
374  }
375  }
376 
    // View into the parent handle's ACL tensor; owns no storage itself.
    arm_compute::SubTensor m_Tensor;
    // Non-owning pointer to the parent handle, returned by GetParent().
    ITensorHandle* parentHandle = nullptr;
379 };
380 
381 } // namespace armnn
TensorShape GetShape() const override
Get the number of elements for each dimension ordered from slowest iterating dimension to fastest ite...
TensorShape GetStrides() const override
Get the strides for each dimension ordered from largest to smallest where the smallest value is the s...
virtual void Allocate() override
Indicate to the memory manager that this resource is no longer active.
virtual void SetMemoryGroup(const std::shared_ptr< arm_compute::IMemoryGroup > &) override
DataLayout
Definition: Types.hpp:50
virtual arm_compute::DataType GetDataType() const override
DataLayout::NCHW false
virtual void SetMemoryGroup(const std::shared_ptr< arm_compute::IMemoryGroup > &memoryGroup) override
virtual arm_compute::DataType GetDataType() const override
std::array< unsigned int, MaxNumOfTensorDimensions > Coordinates
virtual void Manage() override
Indicate to the memory manager that this resource is active.
virtual void Unmap() const override
Unmap the tensor data.
arm_compute::ITensor const & GetTensor() const override
virtual ITensorHandle * GetParent() const override
Get the parent tensor if this is a subtensor.
unsigned int MemorySourceFlags
arm_compute::ITensor const & GetTensor() const override
virtual void Unmap() const override
Unmap the tensor data.
Copyright (c) 2021 ARM Limited and Contributors.
virtual bool Import(void *memory, MemorySource source) override
Import externally allocated memory.
DataType
Definition: Types.hpp:32
TensorShape GetShape() const override
Get the number of elements for each dimension ordered from slowest iterating dimension to fastest ite...
NeonSubTensorHandle(IAclTensorHandle *parent, const arm_compute::TensorShape &shape, const arm_compute::Coordinates &coords)
Status
enumeration
Definition: Types.hpp:26
#define ARMNN_ASSERT(COND)
Definition: Assert.hpp:14
arm_compute::ITensor & GetTensor() override
MemorySourceFlags GetImportFlags() const override
Get flags describing supported import sources.
TensorShape GetStrides() const override
Get the strides for each dimension ordered from largest to smallest where the smallest value is the s...
virtual const void * Map(bool) const override
Map the tensor data for access.
NeonTensorHandle(const TensorInfo &tensorInfo, DataLayout dataLayout, MemorySourceFlags importFlags=static_cast< MemorySourceFlags >(MemorySource::Malloc))
virtual ITensorHandle * GetParent() const override
Get the parent tensor if this is a subtensor.
virtual void Allocate() override
Indicate to the memory manager that this resource is no longer active.
NeonTensorHandle(const TensorInfo &tensorInfo)
virtual const void * Map(bool) const override
Map the tensor data for access.
void SetImportFlags(MemorySourceFlags importFlags)
arm_compute::ITensor & GetTensor() override
void SetImportEnabledFlag(bool importEnabledFlag)
virtual void Manage() override
Indicate to the memory manager that this resource is active.