ArmNN 23.05
ClTensorHandle.hpp
//
// Copyright © 2017, 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include <aclCommon/ArmComputeTensorHandle.hpp>
#include <aclCommon/ArmComputeTensorUtils.hpp>

#include <Half.hpp>

#include <armnn/utility/PolymorphicDowncast.hpp>

#include <arm_compute/runtime/CL/CLTensor.h>
#include <arm_compute/runtime/CL/CLSubTensor.h>
#include <arm_compute/runtime/IMemoryGroup.h>
#include <arm_compute/runtime/MemoryGroup.h>
#include <arm_compute/core/TensorShape.h>
#include <arm_compute/core/Coordinates.h>

#include <cl/IClTensorHandle.hpp>

namespace armnn
{

class ClTensorHandle : public IClTensorHandle
{
public:
    ClTensorHandle(const TensorInfo& tensorInfo)
        : m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Undefined)),
          m_Imported(false),
          m_IsImportEnabled(false)
    {
        armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
    }

    ClTensorHandle(const TensorInfo& tensorInfo,
                   DataLayout dataLayout,
                   MemorySourceFlags importFlags = static_cast<MemorySourceFlags>(MemorySource::Undefined))
        : m_ImportFlags(importFlags),
          m_Imported(false),
          m_IsImportEnabled(false)
    {
        armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
    }

    arm_compute::CLTensor& GetTensor() override { return m_Tensor; }
    arm_compute::CLTensor const& GetTensor() const override { return m_Tensor; }
    virtual void Allocate() override
    {
        // If we have enabled Importing, don't allocate the tensor
        if (m_IsImportEnabled)
        {
            throw MemoryImportException("ClTensorHandle::Attempting to allocate memory when importing");
        }
        else
        {
            armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
        }
    }

    virtual void Manage() override
    {
        // If we have enabled Importing, don't manage the tensor
        if (m_IsImportEnabled)
        {
            throw MemoryImportException("ClTensorHandle::Attempting to manage memory when importing");
        }
        else
        {
            assert(m_MemoryGroup != nullptr);
            m_MemoryGroup->manage(&m_Tensor);
        }
    }

    virtual const void* Map(bool blocking = true) const override
    {
        const_cast<arm_compute::CLTensor*>(&m_Tensor)->map(blocking);
        return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
    }

    virtual void Unmap() const override { const_cast<arm_compute::CLTensor*>(&m_Tensor)->unmap(); }

    virtual ITensorHandle* GetParent() const override { return nullptr; }

    virtual arm_compute::DataType GetDataType() const override
    {
        return m_Tensor.info()->data_type();
    }

    virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
    {
        m_MemoryGroup = PolymorphicPointerDowncast<arm_compute::MemoryGroup>(memoryGroup);
    }

    TensorShape GetStrides() const override
    {
        return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
    }

    TensorShape GetShape() const override
    {
        return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
    }

    void SetImportFlags(MemorySourceFlags importFlags)
    {
        m_ImportFlags = importFlags;
    }

    MemorySourceFlags GetImportFlags() const override
    {
        return m_ImportFlags;
    }

    void SetImportEnabledFlag(bool importEnabledFlag)
    {
        m_IsImportEnabled = importEnabledFlag;
    }

    virtual bool Import(void* memory, MemorySource source) override
    {
        armnn::IgnoreUnused(memory);
        if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
        {
            throw MemoryImportException("ClTensorHandle::Incorrect import flag");
        }
        m_Imported = false;
        return false;
    }

    virtual bool CanBeImported(void* memory, MemorySource source) override
    {
        // This TensorHandle can never import.
        armnn::IgnoreUnused(memory, source);
        return false;
    }

private:
    // Only used for testing
    void CopyOutTo(void* memory) const override
    {
        const_cast<armnn::ClTensorHandle*>(this)->Map(true);
        switch(this->GetDataType())
        {
            case arm_compute::DataType::F32:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<float*>(memory));
                break;
            case arm_compute::DataType::U8:
            case arm_compute::DataType::QASYMM8:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<uint8_t*>(memory));
                break;
            case arm_compute::DataType::QSYMM8:
            case arm_compute::DataType::QSYMM8_PER_CHANNEL:
            case arm_compute::DataType::QASYMM8_SIGNED:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int8_t*>(memory));
                break;
            case arm_compute::DataType::F16:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<armnn::Half*>(memory));
                break;
            case arm_compute::DataType::S16:
            case arm_compute::DataType::QSYMM16:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int16_t*>(memory));
                break;
            case arm_compute::DataType::S32:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int32_t*>(memory));
                break;
            default:
            {
                throw armnn::UnimplementedException();
            }
        }
        const_cast<armnn::ClTensorHandle*>(this)->Unmap();
    }

    // Only used for testing
    void CopyInFrom(const void* memory) override
    {
        this->Map(true);
        switch(this->GetDataType())
        {
            case arm_compute::DataType::F32:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::U8:
            case arm_compute::DataType::QASYMM8:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::F16:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::S16:
            case arm_compute::DataType::QSYMM8:
            case arm_compute::DataType::QSYMM8_PER_CHANNEL:
            case arm_compute::DataType::QASYMM8_SIGNED:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::QSYMM16:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::S32:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
                                                                 this->GetTensor());
                break;
            default:
            {
                throw armnn::UnimplementedException();
            }
        }
        this->Unmap();
    }

    arm_compute::CLTensor m_Tensor;
    std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
    MemorySourceFlags m_ImportFlags;
    bool m_Imported;
    bool m_IsImportEnabled;
};

class ClSubTensorHandle : public IClTensorHandle
{
public:
    ClSubTensorHandle(IClTensorHandle* parent,
                      const arm_compute::TensorShape& shape,
                      const arm_compute::Coordinates& coords)
        : m_Tensor(&parent->GetTensor(), shape, coords)
    {
        parentHandle = parent;
    }

    arm_compute::CLSubTensor& GetTensor() override { return m_Tensor; }
    arm_compute::CLSubTensor const& GetTensor() const override { return m_Tensor; }

    virtual void Allocate() override {}
    virtual void Manage() override {}

    virtual const void* Map(bool blocking = true) const override
    {
        const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->map(blocking);
        return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
    }
    virtual void Unmap() const override { const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->unmap(); }

    virtual ITensorHandle* GetParent() const override { return parentHandle; }

    virtual arm_compute::DataType GetDataType() const override
    {
        return m_Tensor.info()->data_type();
    }

    virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}

    TensorShape GetStrides() const override
    {
        return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
    }

    TensorShape GetShape() const override
    {
        return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
    }

private:
    // Only used for testing
    void CopyOutTo(void* memory) const override
    {
        const_cast<ClSubTensorHandle*>(this)->Map(true);
        switch(this->GetDataType())
        {
            case arm_compute::DataType::F32:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<float*>(memory));
                break;
            case arm_compute::DataType::U8:
            case arm_compute::DataType::QASYMM8:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<uint8_t*>(memory));
                break;
            case arm_compute::DataType::F16:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<armnn::Half*>(memory));
                break;
            case arm_compute::DataType::QSYMM8:
            case arm_compute::DataType::QSYMM8_PER_CHANNEL:
            case arm_compute::DataType::QASYMM8_SIGNED:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int8_t*>(memory));
                break;
            case arm_compute::DataType::S16:
            case arm_compute::DataType::QSYMM16:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int16_t*>(memory));
                break;
            case arm_compute::DataType::S32:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int32_t*>(memory));
                break;
            default:
            {
                throw armnn::UnimplementedException();
            }
        }
        const_cast<ClSubTensorHandle*>(this)->Unmap();
    }

    // Only used for testing
    void CopyInFrom(const void* memory) override
    {
        this->Map(true);
        switch(this->GetDataType())
        {
            case arm_compute::DataType::F32:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::U8:
            case arm_compute::DataType::QASYMM8:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::F16:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::QSYMM8:
            case arm_compute::DataType::QSYMM8_PER_CHANNEL:
            case arm_compute::DataType::QASYMM8_SIGNED:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::S16:
            case arm_compute::DataType::QSYMM16:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::S32:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
                                                                 this->GetTensor());
                break;
            default:
            {
                throw armnn::UnimplementedException();
            }
        }
        this->Unmap();
    }

    mutable arm_compute::CLSubTensor m_Tensor;
    ITensorHandle* parentHandle = nullptr;
};

} // namespace armnn
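
A short usage sketch can make the handle lifecycle above easier to follow. The code below is illustrative only and is not part of ClTensorHandle.hpp: it assumes an OpenCL device and the Arm Compute Library CL scheduler have already been initialised by the CL backend, and the include path, the FillExample function name and the chosen tensor shape are assumptions made for the example rather than anything taken from this file. It builds a ClTensorHandle from a TensorInfo, allocates the underlying CL buffer, then maps, fills and unmaps it.

// Illustrative sketch only; not part of ClTensorHandle.hpp.
// Assumes the CL runtime has been set up elsewhere (as the ArmNN CL backend does)
// so that CLTensor allocation and map/unmap calls are valid.
#include <cl/ClTensorHandle.hpp>   // include path assumed for illustration
#include <armnn/Tensor.hpp>

void FillExample()
{
    using namespace armnn;

    // Describe a small float tensor.
    TensorInfo info(TensorShape({ 1, 2, 2, 3 }), DataType::Float32);

    // The single-argument constructor leaves importing disabled, so Allocate() is permitted.
    ClTensorHandle handle(info);
    handle.Allocate();                          // backs the CLTensor with device memory

    // Map to obtain a host-visible pointer, write the data, then unmap before GPU use.
    const void* mapped = handle.Map(/*blocking=*/true);
    auto* data = static_cast<float*>(const_cast<void*>(mapped));
    for (unsigned int i = 0; i < info.GetNumElements(); ++i)
    {
        data[i] = static_cast<float>(i);
    }
    handle.Unmap();
}

Allocate() is only one of the two ownership paths shown in the class: alternatively the handle can be registered with a memory group via SetMemoryGroup() and Manage(), in which case the group allocates and recycles the CL buffer.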