ArmNN
 21.11
TensorHandle.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
10 
11 #include <armnn/TypesUtils.hpp>
12 
13 #include <CompatibleTypes.hpp>
14 
15 #include <algorithm>
16 
17 #include <armnn/utility/Assert.hpp>
18 
19 namespace armnn
20 {
21 
22 // Get a TensorShape representing the strides (in bytes) for each dimension
23 // of a tensor, assuming fully packed data with no padding
24 TensorShape GetUnpaddedTensorStrides(const TensorInfo& tensorInfo);
25 
26 // Abstract tensor handle wrapping a readable region of memory and interpreting it as tensor data.
28 {
29 public:
// Returns the wrapped read-only memory reinterpreted as a pointer to const T.
// ARMNN_ASSERT checks that T is compatible with the data type recorded in this
// handle's TensorInfo; the pointer is only reinterpreted, the data is not converted.
30  template <typename T>
31  const T* GetConstTensor() const
32  {
33  ARMNN_ASSERT(CompatibleTypes<T>(GetTensorInfo().GetDataType()));
34  return reinterpret_cast<const T*>(m_Memory);
35  }
36 
// Returns a const reference to the TensorInfo stored in this handle.
37  const TensorInfo& GetTensorInfo() const
38  {
39  return m_TensorInfo;
40  }
41 
// No-op override: the body is empty, so calling Manage() on this handle does nothing.
42  virtual void Manage() override {}
43 
// Always returns nullptr, i.e. this handle never reports a parent tensor.
44  virtual ITensorHandle* GetParent() const override { return nullptr; }
45 
// Returns the wrapped memory pointer directly. The `blocking` argument is unnamed and unused.
46  virtual const void* Map(bool /* blocking = true */) const override { return m_Memory; }
// No-op override: Map() hands out the stored pointer as-is, and this body releases nothing.
47  virtual void Unmap() const override {}
48 
// Returns the result of GetUnpaddedTensorStrides() for this handle's TensorInfo,
// i.e. per-dimension strides in bytes assuming fully packed data with no padding.
49  TensorShape GetStrides() const override
50  {
51  return GetUnpaddedTensorStrides(m_TensorInfo);
52  }
// Returns (by value) the shape held by this handle's TensorInfo.
53  TensorShape GetShape() const override { return m_TensorInfo.GetShape(); }
54 
55 protected:
56  ConstTensorHandle(const TensorInfo& tensorInfo);
57 
// Records `mem` as the read-only memory this handle wraps. Only the pointer is stored;
// nothing is copied.
58  void SetConstMemory(const void* mem) { m_Memory = mem; }
59 
60 private:
61  // Only used for testing
// Not implemented for this class: the argument is ignored and the body is just
// ARMNN_ASSERT_MSG(false, ...), so calling it trips that assertion.
62  void CopyOutTo(void *) const override { ARMNN_ASSERT_MSG(false, "Unimplemented"); }
// Not implemented for this class: the argument is ignored and the body is just
// ARMNN_ASSERT_MSG(false, ...), so calling it trips that assertion.
63  void CopyInFrom(const void*) override { ARMNN_ASSERT_MSG(false, "Unimplemented"); }
64 
65  ConstTensorHandle(const ConstTensorHandle& other) = delete;
66  ConstTensorHandle& operator=(const ConstTensorHandle& other) = delete;
67 
68  TensorInfo m_TensorInfo;
69  const void* m_Memory;
70 };
71 
72 template<>
73 const void* ConstTensorHandle::GetConstTensor<void>() const;
74 
75 // Abstract specialization of ConstTensorHandle that allows write access to the same data.
77 {
78 public:
// Returns the wrapped writable memory reinterpreted as a pointer to T.
// ARMNN_ASSERT checks that T is compatible with the data type recorded in the
// TensorInfo. Note the method is const but returns a non-const pointer.
79  template <typename T>
80  T* GetTensor() const
81  {
82  ARMNN_ASSERT(CompatibleTypes<T>(GetTensorInfo().GetDataType()));
83  return reinterpret_cast<T*>(m_MutableMemory);
84  }
85 
86 protected:
87  TensorHandle(const TensorInfo& tensorInfo);
88 
// Records `mem` as the writable memory this handle wraps and forwards the same pointer
// to SetConstMemory(), so the const and mutable accessors refer to the same region.
89  void SetMemory(void* mem)
90  {
91  m_MutableMemory = mem;
92  SetConstMemory(m_MutableMemory);
93  }
94 
95 private:
96 
97  TensorHandle(const TensorHandle& other) = delete;
98  TensorHandle& operator=(const TensorHandle& other) = delete;
99  void* m_MutableMemory;
100 };
101 
102 template <>
103 void* TensorHandle::GetTensor<void>() const;
104 
105 // A TensorHandle that owns the wrapped memory region.
107 {
108 public:
109  explicit ScopedTensorHandle(const TensorInfo& tensorInfo);
110 
111  // Copies contents from Tensor.
112  explicit ScopedTensorHandle(const ConstTensor& tensor);
113 
114  // Copies contents from ConstTensorHandle
115  explicit ScopedTensorHandle(const ConstTensorHandle& tensorHandle);
116 
118  ScopedTensorHandle& operator=(const ScopedTensorHandle& other);
120 
121  virtual void Allocate() override;
122 
123 private:
124  // Only used for testing
125  void CopyOutTo(void* memory) const override;
126  void CopyInFrom(const void* memory) override;
127 
128  void CopyFrom(const ScopedTensorHandle& other);
129  void CopyFrom(const void* srcMemory, unsigned int numBytes);
130 };
131 
132 // A TensorHandle that wraps an already allocated memory region.
133 //
134 // Clients must make sure the passed in memory region stays alive for the lifetime of
135 // the PassthroughTensorHandle instance.
136 //
137 // Note there is no polymorphism to/from ConstPassthroughTensorHandle.
139 {
140 public:
// Constructs the handle for `tensorInfo` and records `mem` as the wrapped memory via
// SetMemory(). Only the pointer is stored here; nothing is allocated or copied.
141  PassthroughTensorHandle(const TensorInfo& tensorInfo, void* mem)
142  : TensorHandle(tensorInfo)
143  {
144  SetMemory(mem);
145  }
146 
147  virtual void Allocate() override;
148 };
149 
150 // A ConstTensorHandle that wraps an already allocated memory region.
151 //
152 // This allows users to pass in const memory to a network.
153 // Clients must make sure the passed in memory region stays alive for the lifetime of
154 // the ConstPassthroughTensorHandle instance.
155 //
156 // Note there is no polymorphism to/from PassthroughTensorHandle.
158 {
159 public:
// Constructs the handle for `tensorInfo` and records `mem` as the wrapped read-only memory
// via SetConstMemory(). Only the pointer is stored here; nothing is allocated or copied.
160  ConstPassthroughTensorHandle(const TensorInfo& tensorInfo, const void* mem)
161  : ConstTensorHandle(tensorInfo)
162  {
163  SetConstMemory(mem);
164  }
165 
166  virtual void Allocate() override;
167 };
168 
169 
170 // Template specializations for T = void (these repeat the declarations that follow each class above).
171 
172 template <>
173 const void* ConstTensorHandle::GetConstTensor() const;
174 
175 template <>
176 void* TensorHandle::GetTensor() const;
177 
179 {
180 
181 public:
// Takes shared ownership of `ptr` (moved into m_TensorHandle) and starts in the unmapped
// state. `ptr` may be null; Map() throws in that case.
182  explicit ManagedConstTensorHandle(std::shared_ptr<ConstTensorHandle> ptr)
183  : m_Mapped(false)
184  , m_TensorHandle(std::move(ptr)) {}
185 
186  /// Maps the wrapped tensor handle, records that it is mapped, and returns the mapped pointer.
/// RAII-managed resource: the memory area is unmapped again once this object goes out of scope.
/// Throws armnn::Exception if the wrapped handle is null.
187  const void* Map(bool blocking = true)
188  {
189  if (m_TensorHandle)
190  {
191  auto pRet = m_TensorHandle->Map(blocking);
192  m_Mapped = true;
193  return pRet;
194  }
195  else
196  {
197  throw armnn::Exception("Attempting to Map null TensorHandle");
198  }
199 
200  }
201 
202  // Delete copy constructor as it's unnecessary.
// The parameter is this class's own type so that this really is the copy constructor
// (a `const ConstTensorHandle&` parameter would only delete a converting constructor).
203  ManagedConstTensorHandle(const ManagedConstTensorHandle& other) = delete;
204 
205  // Delete copy assignment as it's unnecessary
206  ManagedConstTensorHandle& operator=(const ManagedConstTensorHandle& other) = delete;
207 
208  // Delete move assignment as it's unnecessary
209  ManagedConstTensorHandle& operator=(ManagedConstTensorHandle&& other) noexcept = delete;
210 
212  {
213  // Bias tensor handles need to be initialized empty before entering scope of if statement checking if enabled
214  if (m_TensorHandle)
215  {
216  Unmap();
217  }
218  }
219 
// Unmaps the wrapped handle if it is currently recorded as mapped, then clears the flag.
// Safe to call repeatedly or on a null handle: in those cases it does nothing.
220  void Unmap()
221  {
222  // Only unmap if mapped and TensorHandle exists.
223  if (m_Mapped && m_TensorHandle)
224  {
225  m_TensorHandle->Unmap();
226  m_Mapped = false;
227  }
228  }
229 
// Returns the TensorInfo of the wrapped handle.
// NOTE(review): m_TensorHandle is dereferenced without the null check that Map() performs,
// so this must not be called on an instance holding a null handle.
230  const TensorInfo& GetTensorInfo() const
231  {
232  return m_TensorHandle->GetTensorInfo();
233  }
234 
// Returns true after a successful Map() and until the next effective Unmap().
235  bool IsMapped() const
236  {
237  return m_Mapped;
238  }
239 
240 private:
241  bool m_Mapped;
242  std::shared_ptr<ConstTensorHandle> m_TensorHandle;
243 };
244 
// Deprecated aliases for the old *CpuTensorHandle names. Each one maps to the
// correspondingly named *TensorHandle type and is annotated with a deprecation message
// and the removal date string "22.05".
245 using ConstCpuTensorHandle ARMNN_DEPRECATED_MSG_REMOVAL_DATE("ConstCpuTensorHandle is deprecated, "
246  "use ConstTensorHandle instead", "22.05") = ConstTensorHandle;
247 using CpuTensorHandle ARMNN_DEPRECATED_MSG_REMOVAL_DATE("CpuTensorHandle is deprecated, "
248  "use TensorHandle instead", "22.05") = TensorHandle;
249 using ScopedCpuTensorHandle ARMNN_DEPRECATED_MSG_REMOVAL_DATE("ScopedCpuTensorHandle is deprecated, "
250  "use ScopedTensorHandle instead", "22.05") = ScopedTensorHandle;
251 using PassthroughCpuTensorHandle ARMNN_DEPRECATED_MSG_REMOVAL_DATE("PassthroughCpuTensorHandle is deprecated, use "
252  "PassthroughTensorHandle instead",
253  "22.05") = PassthroughTensorHandle;
254 using ConstPassthroughCpuTensorHandle ARMNN_DEPRECATED_MSG_REMOVAL_DATE("ConstPassthroughCpuTensorHandle is "
255  "deprecated, use ConstPassthroughTensorHandle "
256  "instead", "22.05") = ConstPassthroughTensorHandle;
257 
258 } // namespace armnn
virtual const void * Map(bool) const override
Map the tensor data for access.
const TensorShape & GetShape() const
Definition: Tensor.hpp:191
ConstPassthroughTensorHandle(const TensorInfo &tensorInfo, const void *mem)
void SetMemory(void *mem)
virtual void Allocate()=0
Indicate to the memory manager that this resource is no longer active.
void SetConstMemory(const void *mem)
TensorShape GetShape() const override
Get the number of elements for each dimension ordered from slowest iterating dimension to fastest ite...
const TensorInfo & GetTensorInfo() const
Copyright (c) 2021 ARM Limited and Contributors.
const TensorInfo & GetTensorInfo() const
ConstTensorHandle(const TensorInfo &tensorInfo)
#define ARMNN_ASSERT_MSG(COND, MSG)
Definition: Assert.hpp:15
A tensor defined by a TensorInfo (shape and data type) and an immutable backing store.
Definition: Tensor.hpp:327
#define ARMNN_ASSERT(COND)
Definition: Assert.hpp:14
ManagedConstTensorHandle(std::shared_ptr< ConstTensorHandle > ptr)
class ARMNN_DEPRECATED_MSG_REMOVAL_DATE("Use ABI stable IStrategy instead.", "22.05") ILayerVisitor
T * GetTensor() const
TensorShape GetUnpaddedTensorStrides(const TensorInfo &tensorInfo)
virtual void Manage() override
Indicate to the memory manager that this resource is active.
virtual void Unmap() const override
Unmap the tensor data.
Base class for all ArmNN exceptions so that users can filter to just those.
Definition: Exceptions.hpp:46
TensorShape GetStrides() const override
Get the strides for each dimension ordered from largest to smallest where the smallest value is the s...
PassthroughTensorHandle(const TensorInfo &tensorInfo, void *mem)
virtual ITensorHandle * GetParent() const override
Get the parent tensor if this is a subtensor.
const void * Map(bool blocking=true)
RAII Managed resource Unmaps MemoryArea once out of scope.
const T * GetConstTensor() const