ArmNN 22.02
ClTensorHandle.hpp
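
This header defines the OpenCL backend's tensor handles: ClTensorHandle, which owns an
arm_compute::CLTensor and can participate in Arm Compute Library memory-group management,
and ClSubTensorHandle, which wraps an arm_compute::CLSubTensor viewing a region of a
parent tensor's storage. Both implement the IClTensorHandle interface.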
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include <aclCommon/ArmComputeTensorHandle.hpp>
#include <aclCommon/ArmComputeTensorUtils.hpp>

#include <Half.hpp>

#include <armnn/utility/PolymorphicDowncast.hpp>

#include <arm_compute/runtime/CL/CLTensor.h>
#include <arm_compute/runtime/CL/CLSubTensor.h>
#include <arm_compute/runtime/IMemoryGroup.h>
#include <arm_compute/runtime/MemoryGroup.h>
#include <arm_compute/core/TensorShape.h>
#include <arm_compute/core/Coordinates.h>

#include <cl/IClTensorHandle.hpp>
namespace armnn
{

class ClTensorHandle : public IClTensorHandle
{
public:
    ClTensorHandle(const TensorInfo& tensorInfo)
        : m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Undefined)),
          m_Imported(false),
          m_IsImportEnabled(false)
    {
        armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
    }

    ClTensorHandle(const TensorInfo& tensorInfo,
                   DataLayout dataLayout,
                   MemorySourceFlags importFlags = static_cast<MemorySourceFlags>(MemorySource::Undefined))
        : m_ImportFlags(importFlags),
          m_Imported(false),
          m_IsImportEnabled(false)
    {
        armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
    }

    arm_compute::CLTensor& GetTensor() override { return m_Tensor; }
    arm_compute::CLTensor const& GetTensor() const override { return m_Tensor; }

    virtual void Allocate() override
    {
        // If we have enabled Importing, don't allocate the tensor
        if (m_IsImportEnabled)
        {
            throw MemoryImportException("ClTensorHandle::Attempting to allocate memory when importing");
        }
        else
        {
            armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
        }
    }

    virtual void Manage() override
    {
        // If we have enabled Importing, don't manage the tensor
        if (m_IsImportEnabled)
        {
            throw MemoryImportException("ClTensorHandle::Attempting to manage memory when importing");
        }
        else
        {
            assert(m_MemoryGroup != nullptr);
            m_MemoryGroup->manage(&m_Tensor);
        }
    }

    virtual const void* Map(bool blocking = true) const override
    {
        const_cast<arm_compute::CLTensor*>(&m_Tensor)->map(blocking);
        return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
    }

    virtual void Unmap() const override { const_cast<arm_compute::CLTensor*>(&m_Tensor)->unmap(); }

    virtual ITensorHandle* GetParent() const override { return nullptr; }

    virtual arm_compute::DataType GetDataType() const override
    {
        return m_Tensor.info()->data_type();
    }

    virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
    {
        m_MemoryGroup = PolymorphicPointerDowncast<arm_compute::MemoryGroup>(memoryGroup);
    }

    TensorShape GetStrides() const override
    {
        return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
    }

    TensorShape GetShape() const override
    {
        return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
    }

    void SetImportFlags(MemorySourceFlags importFlags)
    {
        m_ImportFlags = importFlags;
    }

    MemorySourceFlags GetImportFlags() const override
    {
        return m_ImportFlags;
    }

    void SetImportEnabledFlag(bool importEnabledFlag)
    {
        m_IsImportEnabled = importEnabledFlag;
    }

    virtual bool Import(void* memory, MemorySource source) override
    {
        armnn::IgnoreUnused(memory);
        if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
        {
            throw MemoryImportException("ClTensorHandle::Incorrect import flag");
        }
        m_Imported = false;
        return false;
    }

    virtual bool CanBeImported(void* memory, MemorySource source) override
    {
        // This TensorHandle can never import.
        armnn::IgnoreUnused(memory, source);
        return false;
    }
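
    // Note: this handle never performs zero-copy import of external memory.
    // CanBeImported() always reports false, and Import() either throws (when the
    // requested source matches m_ImportFlags) or returns false, so callers must
    // copy data in and out through Map()/Unmap() instead.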

private:
    // Only used for testing
    void CopyOutTo(void* memory) const override
    {
        const_cast<armnn::ClTensorHandle*>(this)->Map(true);
        switch(this->GetDataType())
        {
            case arm_compute::DataType::F32:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<float*>(memory));
                break;
            case arm_compute::DataType::U8:
            case arm_compute::DataType::QASYMM8:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<uint8_t*>(memory));
                break;
            case arm_compute::DataType::QSYMM8:
            case arm_compute::DataType::QSYMM8_PER_CHANNEL:
            case arm_compute::DataType::QASYMM8_SIGNED:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int8_t*>(memory));
                break;
            case arm_compute::DataType::F16:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<armnn::Half*>(memory));
                break;
            case arm_compute::DataType::S16:
            case arm_compute::DataType::QSYMM16:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int16_t*>(memory));
                break;
            case arm_compute::DataType::S32:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int32_t*>(memory));
                break;
            default:
            {
                throw armnn::UnimplementedException();
            }
        }
        const_cast<armnn::ClTensorHandle*>(this)->Unmap();
    }

    // Only used for testing
    void CopyInFrom(const void* memory) override
    {
        this->Map(true);
        switch(this->GetDataType())
        {
            case arm_compute::DataType::F32:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::U8:
            case arm_compute::DataType::QASYMM8:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::F16:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::S16:
            case arm_compute::DataType::QSYMM8:
            case arm_compute::DataType::QSYMM8_PER_CHANNEL:
            case arm_compute::DataType::QASYMM8_SIGNED:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::QSYMM16:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::S32:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
                                                                 this->GetTensor());
                break;
            default:
            {
                throw armnn::UnimplementedException();
            }
        }
        this->Unmap();
    }

    arm_compute::CLTensor m_Tensor;
    std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
    MemorySourceFlags m_ImportFlags;
    bool m_Imported;
    bool m_IsImportEnabled;
};
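
// Illustrative usage sketch (not part of the original header): create a handle from a
// TensorInfo, allocate its backing CLTensor, then map it for host access. The shape
// and data type below are made-up example values.
//
//     armnn::TensorInfo info({ 1, 3, 224, 224 }, armnn::DataType::Float32);
//     armnn::ClTensorHandle handle(info);
//     handle.Allocate();                    // initialises the CLTensor storage
//     const void* data = handle.Map(true);  // blocking map; points at the first element
//     // ... read tensor data on the host ...
//     handle.Unmap();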

class ClSubTensorHandle : public IClTensorHandle
{
public:
    ClSubTensorHandle(IClTensorHandle* parent,
                      const arm_compute::TensorShape& shape,
                      const arm_compute::Coordinates& coords)
        : m_Tensor(&parent->GetTensor(), shape, coords)
    {
        parentHandle = parent;
    }

    arm_compute::CLSubTensor& GetTensor() override { return m_Tensor; }
    arm_compute::CLSubTensor const& GetTensor() const override { return m_Tensor; }

    virtual void Allocate() override {}
    virtual void Manage() override {}

    virtual const void* Map(bool blocking = true) const override
    {
        const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->map(blocking);
        return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
    }

    virtual void Unmap() const override { const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->unmap(); }

    virtual ITensorHandle* GetParent() const override { return parentHandle; }

    virtual arm_compute::DataType GetDataType() const override
    {
        return m_Tensor.info()->data_type();
    }

    virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}

    TensorShape GetStrides() const override
    {
        return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
    }

    TensorShape GetShape() const override
    {
        return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
    }

private:
    // Only used for testing
    void CopyOutTo(void* memory) const override
    {
        const_cast<ClSubTensorHandle*>(this)->Map(true);
        switch(this->GetDataType())
        {
            case arm_compute::DataType::F32:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<float*>(memory));
                break;
            case arm_compute::DataType::U8:
            case arm_compute::DataType::QASYMM8:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<uint8_t*>(memory));
                break;
            case arm_compute::DataType::F16:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<armnn::Half*>(memory));
                break;
            case arm_compute::DataType::QSYMM8:
            case arm_compute::DataType::QSYMM8_PER_CHANNEL:
            case arm_compute::DataType::QASYMM8_SIGNED:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int8_t*>(memory));
                break;
            case arm_compute::DataType::S16:
            case arm_compute::DataType::QSYMM16:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int16_t*>(memory));
                break;
            case arm_compute::DataType::S32:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int32_t*>(memory));
                break;
            default:
            {
                throw armnn::UnimplementedException();
            }
        }
        const_cast<ClSubTensorHandle*>(this)->Unmap();
    }

    // Only used for testing
    void CopyInFrom(const void* memory) override
    {
        this->Map(true);
        switch(this->GetDataType())
        {
            case arm_compute::DataType::F32:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::U8:
            case arm_compute::DataType::QASYMM8:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::F16:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::QSYMM8:
            case arm_compute::DataType::QSYMM8_PER_CHANNEL:
            case arm_compute::DataType::QASYMM8_SIGNED:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::S16:
            case arm_compute::DataType::QSYMM16:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::S32:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
                                                                 this->GetTensor());
                break;
            default:
            {
                throw armnn::UnimplementedException();
            }
        }
        this->Unmap();
    }

    mutable arm_compute::CLSubTensor m_Tensor;
    ITensorHandle* parentHandle = nullptr;
};

} // namespace armnn
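
A minimal sketch of how a sub-tensor view might be created over an existing handle,
assuming a parent ClTensorHandle as sketched above; the 4-D shape and coordinates are
made-up example values. The CLSubTensor shares the parent's storage, which is why
GetParent() returns the original handle and Allocate()/Manage() are deliberately no-ops:

    // Hypothetical example, not part of the header.
    armnn::TensorInfo parentInfo({ 1, 4, 8, 8 }, armnn::DataType::Float32);
    armnn::ClTensorHandle parentHandle(parentInfo);
    parentHandle.Allocate();

    // View the first two channels of the parent tensor. Note that
    // arm_compute::TensorShape/Coordinates are ordered fastest-moving first.
    arm_compute::TensorShape subShape(8, 8, 2, 1);
    arm_compute::Coordinates subCoords(0, 0, 0, 0);
    armnn::ClSubTensorHandle subHandle(&parentHandle, subShape, subCoords);

    const void* subData = subHandle.Map(true);   // maps the underlying parent buffer
    // ... use the sub-tensor data ...
    subHandle.Unmap();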