// ArmNN 21.05 — ClTensorHandle.hpp (extracted from the Doxygen documentation page
// "Go to the documentation of this file"; line-number prefixes below are Doxygen artifacts).
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
#pragma once

#include <aclCommon/ArmComputeTensorHandle.hpp>
#include <aclCommon/ArmComputeTensorUtils.hpp>

#include <Half.hpp>

#include <armnn/utility/PolymorphicDowncast.hpp>

#include <arm_compute/runtime/CL/CLTensor.h>
#include <arm_compute/runtime/CL/CLSubTensor.h>
#include <arm_compute/runtime/IMemoryGroup.h>
#include <arm_compute/runtime/MemoryGroup.h>
#include <arm_compute/core/TensorShape.h>
#include <arm_compute/core/Coordinates.h>

21 namespace armnn
22 {
23 
24 
26 {
27 public:
28  virtual arm_compute::ICLTensor& GetTensor() = 0;
29  virtual arm_compute::ICLTensor const& GetTensor() const = 0;
30  virtual arm_compute::DataType GetDataType() const = 0;
31  virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) = 0;
32 };
33 
35 {
36 public:
37  ClTensorHandle(const TensorInfo& tensorInfo)
38  : m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Undefined)),
39  m_Imported(false),
40  m_IsImportEnabled(false)
41  {
42  armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
43  }
44 
45  ClTensorHandle(const TensorInfo& tensorInfo,
46  DataLayout dataLayout,
47  MemorySourceFlags importFlags = static_cast<MemorySourceFlags>(MemorySource::Undefined))
48  : m_ImportFlags(importFlags),
49  m_Imported(false),
50  m_IsImportEnabled(false)
51  {
52  armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
53  }
54 
55  arm_compute::CLTensor& GetTensor() override { return m_Tensor; }
56  arm_compute::CLTensor const& GetTensor() const override { return m_Tensor; }
57  virtual void Allocate() override
58  {
59  // If we have enabled Importing, don't allocate the tensor
60  if (m_IsImportEnabled)
61  {
62  throw MemoryImportException("ClTensorHandle::Attempting to allocate memory when importing");
63  }
64  else
65  {
66  armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
67  }
68 
69  }
70 
71  virtual void Manage() override
72  {
73  // If we have enabled Importing, don't manage the tensor
74  if (m_IsImportEnabled)
75  {
76  throw MemoryImportException("ClTensorHandle::Attempting to manage memory when importing");
77  }
78  else
79  {
80  assert(m_MemoryGroup != nullptr);
81  m_MemoryGroup->manage(&m_Tensor);
82  }
83  }
84 
85  virtual const void* Map(bool blocking = true) const override
86  {
87  const_cast<arm_compute::CLTensor*>(&m_Tensor)->map(blocking);
88  return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
89  }
90 
91  virtual void Unmap() const override { const_cast<arm_compute::CLTensor*>(&m_Tensor)->unmap(); }
92 
93  virtual ITensorHandle* GetParent() const override { return nullptr; }
94 
95  virtual arm_compute::DataType GetDataType() const override
96  {
97  return m_Tensor.info()->data_type();
98  }
99 
100  virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
101  {
102  m_MemoryGroup = PolymorphicPointerDowncast<arm_compute::MemoryGroup>(memoryGroup);
103  }
104 
105  TensorShape GetStrides() const override
106  {
107  return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
108  }
109 
110  TensorShape GetShape() const override
111  {
112  return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
113  }
114 
116  {
117  m_ImportFlags = importFlags;
118  }
119 
121  {
122  return m_ImportFlags;
123  }
124 
125  void SetImportEnabledFlag(bool importEnabledFlag)
126  {
127  m_IsImportEnabled = importEnabledFlag;
128  }
129 
130  virtual bool Import(void* memory, MemorySource source) override
131  {
132  armnn::IgnoreUnused(memory);
133  if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
134  {
135  throw MemoryImportException("ClTensorHandle::Incorrect import flag");
136  }
137  m_Imported = false;
138  return false;
139  }
140 
141 private:
142  // Only used for testing
143  void CopyOutTo(void* memory) const override
144  {
145  const_cast<armnn::ClTensorHandle*>(this)->Map(true);
146  switch(this->GetDataType())
147  {
148  case arm_compute::DataType::F32:
149  armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
150  static_cast<float*>(memory));
151  break;
152  case arm_compute::DataType::U8:
153  case arm_compute::DataType::QASYMM8:
154  armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
155  static_cast<uint8_t*>(memory));
156  break;
157  case arm_compute::DataType::QSYMM8:
158  case arm_compute::DataType::QSYMM8_PER_CHANNEL:
159  case arm_compute::DataType::QASYMM8_SIGNED:
160  armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
161  static_cast<int8_t*>(memory));
162  break;
163  case arm_compute::DataType::F16:
164  armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
165  static_cast<armnn::Half*>(memory));
166  break;
167  case arm_compute::DataType::S16:
168  case arm_compute::DataType::QSYMM16:
169  armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
170  static_cast<int16_t*>(memory));
171  break;
172  case arm_compute::DataType::S32:
173  armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
174  static_cast<int32_t*>(memory));
175  break;
176  default:
177  {
179  }
180  }
181  const_cast<armnn::ClTensorHandle*>(this)->Unmap();
182  }
183 
184  // Only used for testing
185  void CopyInFrom(const void* memory) override
186  {
187  this->Map(true);
188  switch(this->GetDataType())
189  {
190  case arm_compute::DataType::F32:
191  armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
192  this->GetTensor());
193  break;
194  case arm_compute::DataType::U8:
195  case arm_compute::DataType::QASYMM8:
196  armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
197  this->GetTensor());
198  break;
199  case arm_compute::DataType::F16:
200  armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
201  this->GetTensor());
202  break;
203  case arm_compute::DataType::S16:
204  case arm_compute::DataType::QSYMM8:
205  case arm_compute::DataType::QSYMM8_PER_CHANNEL:
206  case arm_compute::DataType::QASYMM8_SIGNED:
207  armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
208  this->GetTensor());
209  break;
210  case arm_compute::DataType::QSYMM16:
211  armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
212  this->GetTensor());
213  break;
214  case arm_compute::DataType::S32:
215  armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
216  this->GetTensor());
217  break;
218  default:
219  {
221  }
222  }
223  this->Unmap();
224  }
225 
226  arm_compute::CLTensor m_Tensor;
227  std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
228  MemorySourceFlags m_ImportFlags;
229  bool m_Imported;
230  bool m_IsImportEnabled;
231 };
232 
234 {
235 public:
237  const arm_compute::TensorShape& shape,
238  const arm_compute::Coordinates& coords)
239  : m_Tensor(&parent->GetTensor(), shape, coords)
240  {
241  parentHandle = parent;
242  }
243 
244  arm_compute::CLSubTensor& GetTensor() override { return m_Tensor; }
245  arm_compute::CLSubTensor const& GetTensor() const override { return m_Tensor; }
246 
247  virtual void Allocate() override {}
248  virtual void Manage() override {}
249 
250  virtual const void* Map(bool blocking = true) const override
251  {
252  const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->map(blocking);
253  return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
254  }
255  virtual void Unmap() const override { const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->unmap(); }
256 
257  virtual ITensorHandle* GetParent() const override { return parentHandle; }
258 
259  virtual arm_compute::DataType GetDataType() const override
260  {
261  return m_Tensor.info()->data_type();
262  }
263 
264  virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
265 
266  TensorShape GetStrides() const override
267  {
268  return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
269  }
270 
271  TensorShape GetShape() const override
272  {
273  return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
274  }
275 
276 private:
277  // Only used for testing
278  void CopyOutTo(void* memory) const override
279  {
280  const_cast<ClSubTensorHandle*>(this)->Map(true);
281  switch(this->GetDataType())
282  {
283  case arm_compute::DataType::F32:
284  armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
285  static_cast<float*>(memory));
286  break;
287  case arm_compute::DataType::U8:
288  case arm_compute::DataType::QASYMM8:
289  armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
290  static_cast<uint8_t*>(memory));
291  break;
292  case arm_compute::DataType::F16:
293  armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
294  static_cast<armnn::Half*>(memory));
295  break;
296  case arm_compute::DataType::QSYMM8:
297  case arm_compute::DataType::QSYMM8_PER_CHANNEL:
298  case arm_compute::DataType::QASYMM8_SIGNED:
299  armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
300  static_cast<int8_t*>(memory));
301  break;
302  case arm_compute::DataType::S16:
303  case arm_compute::DataType::QSYMM16:
304  armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
305  static_cast<int16_t*>(memory));
306  break;
307  case arm_compute::DataType::S32:
308  armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
309  static_cast<int32_t*>(memory));
310  break;
311  default:
312  {
314  }
315  }
316  const_cast<ClSubTensorHandle*>(this)->Unmap();
317  }
318 
319  // Only used for testing
320  void CopyInFrom(const void* memory) override
321  {
322  this->Map(true);
323  switch(this->GetDataType())
324  {
325  case arm_compute::DataType::F32:
326  armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
327  this->GetTensor());
328  break;
329  case arm_compute::DataType::U8:
330  case arm_compute::DataType::QASYMM8:
331  armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
332  this->GetTensor());
333  break;
334  case arm_compute::DataType::F16:
335  armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
336  this->GetTensor());
337  break;
338  case arm_compute::DataType::QSYMM8:
339  case arm_compute::DataType::QSYMM8_PER_CHANNEL:
340  case arm_compute::DataType::QASYMM8_SIGNED:
341  armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
342  this->GetTensor());
343  break;
344  case arm_compute::DataType::S16:
345  case arm_compute::DataType::QSYMM16:
346  armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
347  this->GetTensor());
348  break;
349  case arm_compute::DataType::S32:
350  armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
351  this->GetTensor());
352  break;
353  default:
354  {
356  }
357  }
358  this->Unmap();
359  }
360 
361  mutable arm_compute::CLSubTensor m_Tensor;
362  ITensorHandle* parentHandle = nullptr;
363 };
364 
365 } // namespace armnn
// --- Doxygen member-index residue (auto-generated navigation text, not source) ---
// The extracted page appended its cross-reference index here: brief descriptions of
// GetShape/GetStrides/GetImportFlags/Map/Unmap/Allocate/Manage/Import and the
// constructor signatures (e.g. "ClSubTensorHandle(IClTensorHandle *parent,
// const arm_compute::TensorShape &shape, const arm_compute::Coordinates &coords)"),
// plus Types.hpp definition anchors for DataLayout, DataType and MemorySource.
// Kept as a comment so the trailing prose cannot be mistaken for code.